Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding correlation distance metric in oneDAL primitives #3059

Open
wants to merge 34 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
4fade17
Adding correlation distance metric
Feb 3, 2025
bb61721
Update metrics.hpp
richardnorth3 Feb 3, 2025
cd7c506
Update metrics.hpp with correlation_metric for GPU
richardnorth3 Feb 14, 2025
2039675
Add correlation distance test
richardnorth3 Feb 20, 2025
b4fd66a
Add supporting files for correlation distance
richardnorth3 Feb 24, 2025
22f03a8
Update metrics.hpp
richardnorth3 Feb 24, 2025
93a1c98
Update distance.hpp with correlation prototypes
richardnorth3 Feb 24, 2025
aec5d45
Merge branch 'dev/correlation-metric' of https://github.com/richardno…
richardnorth3 Feb 24, 2025
a53164d
Update metrics.hpp
richardnorth3 Feb 24, 2025
7ade161
Update correlation_distance_dpc.cpp
richardnorth3 Feb 24, 2025
c56b209
Update correlation_distance_dpc.cpp
richardnorth3 Feb 24, 2025
0b45447
Update correlation_distance_dpc.cpp
richardnorth3 Feb 24, 2025
9b09822
Update correlation_distance_dpc.cpp
richardnorth3 Feb 24, 2025
cc52792
Update bazel BUILD file with cov deps
richardnorth3 Feb 24, 2025
8232af4
Update bazel BUILD file
richardnorth3 Feb 24, 2025
8670a5a
Update bazel BUILD file
richardnorth3 Feb 24, 2025
7d88579
Update correlation_distance_dpc.cpp
richardnorth3 Feb 24, 2025
4740ad6
Update correlation_distance_dpc.cpp
richardnorth3 Feb 24, 2025
069a787
Update correlation_distance_dpc.cpp
richardnorth3 Feb 24, 2025
325e01b
Update cpp/oneapi/dal/backend/primitives/distance/correlation_distanc…
richardnorth3 Feb 25, 2025
01c4b72
Update cpp/oneapi/dal/backend/primitives/distance/correlation_distanc…
richardnorth3 Feb 25, 2025
f888e7a
Update correlation distance test and header
Feb 25, 2025
71609fc
Update correlation distance test
richardnorth3 Feb 25, 2025
8008137
Update correlation_distance_misc.hpp
richardnorth3 Feb 25, 2025
dbbbc3e
Update correlation_distance_misc_dpc.cpp
richardnorth3 Feb 25, 2025
a862a74
Update correlation distance test
Feb 25, 2025
68e8c4f
Update correlation_distance_misc_dpc.cpp
richardnorth3 Feb 25, 2025
012830c
Update correleation distance with bux fixes
Feb 28, 2025
aba32ac
Update correlation_distance_dpc.cpp
richardnorth3 Feb 28, 2025
af5245b
Update correlation_distance_misc.hpp
richardnorth3 Feb 28, 2025
94d0b5c
Update correlation_distance_misc_dpc.cpp
richardnorth3 Feb 28, 2025
bf56cf3
Update correlation_distance_misc_dpc.cpp
richardnorth3 Feb 28, 2025
50b56e6
Update distance.hpp
richardnorth3 Feb 28, 2025
ba83212
Update distance.hpp
richardnorth3 Feb 28, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cpp/oneapi/dal/backend/primitives/distance/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ dal_module(
"@onedal//cpp/oneapi/dal/backend/primitives:blas",
"@onedal//cpp/oneapi/dal/backend/primitives:common",
"@onedal//cpp/oneapi/dal/backend/primitives:reduction",
"@onedal//cpp/oneapi/dal/backend/primitives:stat",
],
)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
/*******************************************************************************
* Copyright contributors to the oneDAL project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "oneapi/dal/backend/primitives/distance/distance.hpp"
#include "oneapi/dal/backend/primitives/distance/correlation_distance_misc.hpp"

#include "oneapi/dal/backend/primitives/blas.hpp"
#include "oneapi/dal/backend/primitives/reduction/reduction.hpp"

namespace oneapi::dal::backend::primitives {

template <typename Float>
template <ndorder order>
auto distance<Float, correlation_metric<Float>>::get_inversed_norms(const ndview<Float, 2, order>& inp,
const event_vector& deps) const
-> inv_norms_res_t {
return compute_inversed_l2_norms(q_, inp, deps);
}

template <typename Float>
template <ndorder order>
auto distance<Float, correlation_metric<Float>>::get_deviation(const ndview<Float, 2, order>& inp,
const event_vector& deps) const
-> comp_dev_res_t {
return compute_deviation(q_, inp, deps);
}

template <typename Float>
template <ndorder order1, ndorder order2>
sycl::event distance<Float, correlation_metric<Float>>::operator()(const ndview<Float, 2, order1>& inp1,
const ndview<Float, 2, order2>& inp2,
ndview<Float, 2>& out,
const ndview<Float, 1>& inp1_norms,
const ndview<Float, 1>& inp2_norms,
const event_vector& deps) const {
auto ip_event = compute_correlation_inner_product(q_, inp1, inp2, out, deps);
return finalize_correlation(q_, inp1_norms, inp2_norms, out, { ip_event });
}

template <typename Float>
template <ndorder order1, ndorder order2>
sycl::event distance<Float, correlation_metric<Float>>::operator()(const ndview<Float, 2, order1>& inp1,
const ndview<Float, 2, order2>& inp2,
ndview<Float, 2>& out,
const event_vector& deps) const {
auto [centered_inp1, comp_dev1_event] = get_deviation(inp1, deps);
auto [centered_inp2, comp_dev2_event] = get_deviation(inp2, deps);

auto [inv_norms1_array, inv_norms1_event] = get_inversed_norms(centered_inp1, { comp_dev1_event });
auto [inv_norms2_array, inv_norms2_event] = get_inversed_norms(centered_inp2, { comp_dev2_event });
return this->operator()(centered_inp1,
centered_inp2,
out,
inv_norms1_array,
inv_norms2_array,
{ inv_norms1_event, inv_norms2_event });
}

#define INSTANTIATE(F, A, B) \
template sycl::event distance<F, correlation_metric<F>>::operator()(const ndview<F, 2, A>&, \
const ndview<F, 2, B>&, \
ndview<F, 2>&, \
const ndview<F, 1>&, \
const ndview<F, 1>&, \
const event_vector&) const; \
template sycl::event distance<F, correlation_metric<F>>::operator()(const ndview<F, 2, A>&, \
const ndview<F, 2, B>&, \
ndview<F, 2>&, \
const event_vector&) const;

#define INSTANTIATE_B(F, A) \
INSTANTIATE(F, A, ndorder::c) \
INSTANTIATE(F, A, ndorder::f) \
template std::tuple<ndarray<F, 1>, sycl::event> \
distance<F, correlation_metric<F>>::get_inversed_norms(const ndview<F, 2, A>& inp, \
const event_vector& deps) const; \
template std::tuple<ndarray<F, 2>, sycl::event> \
distance<F, correlation_metric<F>>::get_deviation(const ndview<F, 2, A>& inp, \
const event_vector& deps) const;

#define INSTANTIATE_F(F) \
INSTANTIATE_B(F, ndorder::c) \
INSTANTIATE_B(F, ndorder::f) \
template class distance<F, squared_l2_metric<F>>;

INSTANTIATE_F(float);
INSTANTIATE_F(double);

#undef INSTANTIATE

} // namespace oneapi::dal::backend::primitives
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
/*******************************************************************************
* Copyright contributors to the oneDAL project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#pragma once

#include "oneapi/dal/backend/primitives/common.hpp"
#include "oneapi/dal/backend/primitives/ndarray.hpp"

#include "oneapi/dal/backend/primitives/distance/distance.hpp"

namespace oneapi::dal::backend::primitives {

#ifdef ONEDAL_DATA_PARALLEL

template <typename Float, ndorder order>
sycl::event compute_inversed_l2_norms(sycl::queue& q,
const ndview<Float, 2, order>& inp,
ndview<Float, 1>& out,
const event_vector& deps = {});

template <typename Float, ndorder order>
std::tuple<ndarray<Float, 1>, sycl::event> compute_inversed_l2_norms(
sycl::queue& q,
const ndview<Float, 2, order>& inp,
const event_vector& deps = {},
const sycl::usm::alloc& alloc = sycl::usm::alloc::device);

template <typename Float, ndorder order1, ndorder order2>
sycl::event compute_correlation_inner_product(sycl::queue& q,
const ndview<Float, 2, order1>& inp1,
const ndview<Float, 2, order2>& inp2,
ndview<Float, 2>& out,
const event_vector& deps = {});

template <typename Float>
sycl::event finalize_correlation(sycl::queue& q,
const ndview<Float, 1>& inp1,
const ndview<Float, 1>& inp2,
ndview<Float, 2>& out,
const event_vector& deps = {});

template <typename Float, ndorder order>
sycl::event compute_deviation(sycl::queue& q,
const ndview<Float, 2, order>& inp,
ndview<Float, 2>& out,
const event_vector& deps = {});

template <typename Float, ndorder order>
std::tuple<ndarray<Float, 2>, sycl::event> compute_deviation(
sycl::queue& q,
const ndview<Float, 2, order>& inp,
const event_vector& deps = {},
const sycl::usm::alloc& alloc = sycl::usm::alloc::device);

#endif

} // namespace oneapi::dal::backend::primitives
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
/*******************************************************************************
* Copyright contributors to the oneDAL project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "oneapi/dal/detail/profiler.hpp"

#include "oneapi/dal/backend/primitives/distance/correlation_distance_misc.hpp"
#include "oneapi/dal/backend/primitives/distance/squared_l2_distance_misc.hpp"

#include "oneapi/dal/backend/primitives/blas.hpp"
#include "oneapi/dal/backend/primitives/reduction.hpp"
#include "oneapi/dal/backend/primitives/stat/cov.hpp"

namespace oneapi::dal::backend::primitives {

template <typename Float>
inline sycl::event inverse_l2_norms(sycl::queue& q,
ndview<Float, 1>& out,
const event_vector& deps) {
ONEDAL_PROFILER_TASK(distance.inverse_l2_norms, q);

ONEDAL_ASSERT(out.has_mutable_data());
return q.submit([&](sycl::handler& h) {
h.depends_on(deps);
const auto count = out.get_count();
const auto range = make_range_1d(count);
auto* const ptr = out.get_mutable_data();
h.parallel_for(range, [=](sycl::id<1> idx) {
auto& ref = ptr[idx];
ref = sycl::rsqrt(ref);
});
});
}

template <typename Float, ndorder order>
sycl::event compute_inversed_l2_norms(sycl::queue& q,
const ndview<Float, 2, order>& inp,
ndview<Float, 1>& out,
const event_vector& deps) {
ONEDAL_ASSERT(inp.has_data());
ONEDAL_ASSERT(out.has_mutable_data());
auto sq_event = compute_squared_l2_norms(q, inp, out, deps);
return inverse_l2_norms(q, out, { sq_event });
}

template <typename Float, ndorder order>
std::tuple<ndarray<Float, 1>, sycl::event> compute_inversed_l2_norms(
sycl::queue& q,
const ndview<Float, 2, order>& inp,
const event_vector& deps,
const sycl::usm::alloc& alloc) {
const auto n_samples = inp.get_dimension(0);
auto res_array = ndarray<Float, 1>::empty(q, { n_samples }, alloc);
return { res_array, compute_inversed_l2_norms(q, inp, res_array, deps) };
}

template <typename Float>
sycl::event finalize_correlation(sycl::queue& q,
const ndview<Float, 1>& inp1,
const ndview<Float, 1>& inp2,
ndview<Float, 2>& out,
const event_vector& deps) {
ONEDAL_PROFILER_TASK(distance.finalize_correlation, q);

ONEDAL_ASSERT(inp1.has_data());
ONEDAL_ASSERT(inp2.has_data());
ONEDAL_ASSERT(out.has_mutable_data());
const auto out_stride = out.get_leading_stride();
const auto n_samples1 = inp1.get_dimension(0);
const auto n_samples2 = inp2.get_dimension(0);
ONEDAL_ASSERT(n_samples1 <= out.get_dimension(0));
ONEDAL_ASSERT(n_samples2 <= out.get_dimension(1));
const auto* const inp1_ptr = inp1.get_data();
const auto* const inp2_ptr = inp2.get_data();
auto* const out_ptr = out.get_mutable_data();
const auto out_range = make_range_2d(n_samples1, n_samples2);
return q.submit([&](sycl::handler& h) {
h.depends_on(deps);
h.parallel_for(out_range, [=](sycl::id<2> idx) {
constexpr Float one = 1;
auto& out = *(out_ptr + out_stride * idx[0] + idx[1]);
out = one - out * inp1_ptr[idx[0]] * inp2_ptr[idx[1]];
});
});
}

template <typename Float, ndorder order1, ndorder order2>
sycl::event compute_correlation_inner_product(sycl::queue& q,
const ndview<Float, 2, order1>& inp1,
const ndview<Float, 2, order2>& inp2,
ndview<Float, 2>& out,
const event_vector& deps) {
check_inputs(inp1, inp2, out);
auto event = gemm(q, inp1, inp2.t(), out, Float(+1.0), Float(0.0), deps);
// Workaround for abort in async mode. Should be removed later.
event.wait_and_throw();
return event;
}

template <typename Float, ndorder order>
sycl::event compute_deviation(sycl::queue& q,
const ndview<Float, 2, order>& inp,
ndview<Float, 2>& out,
const event_vector& deps) {
ONEDAL_ASSERT(inp.has_data());
ONEDAL_ASSERT(out.has_mutable_data());
const auto n = out.get_dimension(0);
const auto p = out.get_dimension(1);
ONEDAL_ASSERT(n == inp.get_dimension(0));
ONEDAL_ASSERT(p == inp.get_dimension(1));
auto* const inp_ptr = inp.get_mutable_data();
const auto out_stride = out.get_leading_stride();
auto* const out_ptr = out.get_mutable_data();
auto out_range = make_range_2d(n, p);
auto inp_sum = ndarray<Float, 1>::empty(q, { n });
auto inp_mean = ndarray<Float, 1>::empty(q, { n });
auto sum_event = reduce_by_rows(q, inp, inp_sum, sum<Float>{}, identity<Float>{}, deps);
auto mean_event = means(q, p, inp_sum, inp_mean, { sum_event });
return q.submit([&](sycl::handler& h) {
h.depends_on({ mean_event });
auto const inp_mean_acc = inp_mean.get_data();
h.parallel_for(out_range, [=](sycl::id<2> idx) {
const auto offset = idx[0] * out_stride + idx[1];
out_ptr[offset] = inp_ptr[offset] - inp_mean_acc[idx[0]];
});
});
}

template <typename Float, ndorder order>
std::tuple<ndarray<Float, 2>, sycl::event> compute_deviation(
sycl::queue& q,
const ndview<Float, 2, order>& inp,
const event_vector& deps,
const sycl::usm::alloc& alloc) {
const auto n = inp.get_dimension(0);
const auto p = inp.get_dimension(1);
auto res_array = ndarray<Float, 2>::empty(q, { n, p }, alloc);
return { res_array, compute_deviation(q, inp, res_array, deps) };
}

#define INSTANTIATE(F, A, B) \
template sycl::event compute_correlation_inner_product<F, A, B>(sycl::queue&, \
const ndview<F, 2, A>&, \
const ndview<F, 2, B>&, \
ndview<F, 2>&, \
const event_vector&);

#define INSTANTIATE_A(F, B) \
INSTANTIATE(F, ndorder::c, B) \
INSTANTIATE(F, ndorder::f, B) \
template sycl::event compute_inversed_l2_norms<F, B>(sycl::queue&, \
const ndview<F, 2, B>&, \
ndview<F, 1>&, \
const event_vector&); \
template std::tuple<ndarray<F, 1>, sycl::event> compute_inversed_l2_norms<F, B>( \
sycl::queue&, \
const ndview<F, 2, B>&, \
const event_vector&, \
const sycl::usm::alloc&); \
template sycl::event compute_deviation<F, B>(sycl::queue&, \
const ndview<F, 2, B>&, \
ndview<F, 2>&, \
const event_vector&); \
template std::tuple<ndarray<F, 2>, sycl::event> compute_deviation<F, B>( \
sycl::queue&, \
const ndview<F, 2, B>&, \
const event_vector&, \
const sycl::usm::alloc&);

#define INSTANTIATE_F(F) \
INSTANTIATE_A(F, ndorder::c) \
INSTANTIATE_A(F, ndorder::f) \
template sycl::event finalize_correlation<F>(sycl::queue & q,\
const ndview<F, 1>&, \
const ndview<F, 1>&, \
ndview<F, 2>&, \
const event_vector&);

INSTANTIATE_F(float);
INSTANTIATE_F(double);

#undef INSTANTIATE

} // namespace oneapi::dal::backend::primitives
Loading
Loading