Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding correlation distance metric in oneDAL primitives #3059

Open
wants to merge 27 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
4fade17
Adding correlation distance metric
Feb 3, 2025
bb61721
Update metrics.hpp
richardnorth3 Feb 3, 2025
cd7c506
Update metrics.hpp with correlation_metric for GPU
richardnorth3 Feb 14, 2025
2039675
Add correlation distance test
richardnorth3 Feb 20, 2025
b4fd66a
Add supporting files for correlation distance
richardnorth3 Feb 24, 2025
22f03a8
Update metrics.hpp
richardnorth3 Feb 24, 2025
93a1c98
Update distance.hpp with correlation prototypes
richardnorth3 Feb 24, 2025
aec5d45
Merge branch 'dev/correlation-metric' of https://github.com/richardno…
richardnorth3 Feb 24, 2025
a53164d
Update metrics.hpp
richardnorth3 Feb 24, 2025
7ade161
Update correlation_distance_dpc.cpp
richardnorth3 Feb 24, 2025
c56b209
Update correlation_distance_dpc.cpp
richardnorth3 Feb 24, 2025
0b45447
Update correlation_distance_dpc.cpp
richardnorth3 Feb 24, 2025
9b09822
Update correlation_distance_dpc.cpp
richardnorth3 Feb 24, 2025
cc52792
Update bazel BUILD file with cov deps
richardnorth3 Feb 24, 2025
8232af4
Update bazel BUILD file
richardnorth3 Feb 24, 2025
8670a5a
Update bazel BUILD file
richardnorth3 Feb 24, 2025
7d88579
Update correlation_distance_dpc.cpp
richardnorth3 Feb 24, 2025
4740ad6
Update correlation_distance_dpc.cpp
richardnorth3 Feb 24, 2025
069a787
Update correlation_distance_dpc.cpp
richardnorth3 Feb 24, 2025
325e01b
Update cpp/oneapi/dal/backend/primitives/distance/correlation_distanc…
richardnorth3 Feb 25, 2025
01c4b72
Update cpp/oneapi/dal/backend/primitives/distance/correlation_distanc…
richardnorth3 Feb 25, 2025
f888e7a
Update correlation distance test and header
Feb 25, 2025
71609fc
Update correlation distance test
richardnorth3 Feb 25, 2025
8008137
Update correlation_distance_misc.hpp
richardnorth3 Feb 25, 2025
dbbbc3e
Update correlation_distance_misc_dpc.cpp
richardnorth3 Feb 25, 2025
a862a74
Update correlation distance test
Feb 25, 2025
68e8c4f
Update correlation_distance_misc_dpc.cpp
richardnorth3 Feb 25, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cpp/oneapi/dal/backend/primitives/distance/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ dal_module(
"@onedal//cpp/oneapi/dal/backend/primitives:blas",
"@onedal//cpp/oneapi/dal/backend/primitives:common",
"@onedal//cpp/oneapi/dal/backend/primitives:reduction",
"@onedal//cpp/oneapi/dal/backend/primitives:stat",
],
)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
/*******************************************************************************
* Copyright 2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "oneapi/dal/backend/primitives/distance/distance.hpp"
#include "oneapi/dal/backend/primitives/distance/correlation_distance_misc.hpp"

#include "oneapi/dal/backend/primitives/blas.hpp"
#include "oneapi/dal/backend/primitives/reduction/reduction.hpp"
#include "oneapi/dal/backend/primitives/ndarray.hpp"
#include "oneapi/dal/backend/primitives/stat/cov.hpp"


#include <sycl/sycl.hpp>

namespace oneapi::dal::backend::primitives {

/// Returns the array of inverse L2 norms of `inp` together with the event that
/// signals when the array has been filled.
///
/// @param inp  Matrix whose norms are computed.
/// @param deps Events that must complete before the computation starts.
template <typename Float>
template <ndorder order>
auto distance<Float, correlation_metric<Float>>::get_inversed_norms(
    const ndview<Float, 2, order>& inp,
    const event_vector& deps) const -> inv_norms_res_t {
    // Thin forwarder: the allocating overload of the free function does all the work.
    auto norms_with_event = compute_inversed_l2_norms(q_, inp, deps);
    return norms_with_event;
}

/// Computes the distance matrix when the per-row scale factors are already known.
///
/// First fills `out` via compute_correlation_inner_product, then rescales it
/// in place via finalize_correlation using `inp1_norms` and `inp2_norms`.
///
/// @param inp1       First input matrix.
/// @param inp2       Second input matrix.
/// @param out        Output matrix, overwritten in place.
/// @param inp1_norms Per-row factors for `inp1`.
/// @param inp2_norms Per-row factors for `inp2`.
/// @param deps       Events that must complete before the computation starts.
/// @return Event of the finalisation step.
template <typename Float>
template <ndorder order1, ndorder order2>
sycl::event distance<Float, correlation_metric<Float>>::operator()(
    const ndview<Float, 2, order1>& inp1,
    const ndview<Float, 2, order2>& inp2,
    ndview<Float, 2>& out,
    const ndview<Float, 1>& inp1_norms,
    const ndview<Float, 1>& inp2_norms,
    const event_vector& deps) const {
    // The inner products have to be in `out` before they can be rescaled.
    const event_vector product_ready{
        compute_correlation_inner_product(q_, inp1, inp2, out, deps)
    };
    return finalize_correlation(q_, inp1_norms, inp2_norms, out, product_ready);
}

/// Writes `inp(row, col) - row_means[row]` into `centered(row, col)`.
///
/// @param q         Queue the kernel is submitted to.
/// @param inp       Source matrix, addressed according to `order`.
/// @param row_means One value per row of `inp`.
/// @param centered  Destination with the same logical shape as `inp`.
/// @param deps      Events that must complete before the kernel starts.
/// @return Event of the centering kernel.
template <typename Float, ndorder order>
static sycl::event subtract_row_means(sycl::queue& q,
                                      const ndview<Float, 2, order>& inp,
                                      const ndview<Float, 1>& row_means,
                                      ndview<Float, 2>& centered,
                                      const event_vector& deps) {
    ONEDAL_ASSERT(inp.has_data());
    ONEDAL_ASSERT(row_means.has_data());
    ONEDAL_ASSERT(centered.has_mutable_data());

    const std::int64_t row_count = inp.get_dimension(0);
    const std::int64_t col_count = inp.get_dimension(1);
    const std::int64_t inp_stride = inp.get_leading_stride();
    const std::int64_t out_stride = centered.get_leading_stride();

    // get_data() yields raw pointers, so elements are addressed by explicit offsets.
    const Float* const inp_ptr = inp.get_data();
    const Float* const mean_ptr = row_means.get_data();
    Float* const out_ptr = centered.get_mutable_data();

    return q.submit([&](sycl::handler& h) {
        h.depends_on(deps);
        h.parallel_for(sycl::range<2>(row_count, col_count), [=](sycl::id<2> idx) {
            const std::int64_t row = idx[0];
            const std::int64_t col = idx[1];
            // NOTE(review): assumes the leading stride separates rows for ndorder::c
            // and columns for ndorder::f - confirm against ndview.
            std::int64_t inp_offset = 0;
            if constexpr (order == ndorder::c) {
                inp_offset = row * inp_stride + col;
            }
            else {
                inp_offset = col * inp_stride + row;
            }
            out_ptr[row * out_stride + col] = inp_ptr[inp_offset] - mean_ptr[row];
        });
    });
}

/// Computes the correlation distance between every row of `inp1` and every row
/// of `inp2` without precomputed norms.
///
/// Each input is centered by its own row means, the inverse L2 norms of the
/// centered rows are computed, and both the centered data and those norms are
/// forwarded to the overload that takes precomputed norms.
///
/// Fixes relative to the previous revision:
///  * the centered matrices (not the raw inputs) are used for the inner product;
///  * buffers for `inp2` are sized by `inp2`'s own row count;
///  * raw pointers are no longer indexed with `(row, col)` call syntax;
///  * the two per-input chains no longer depend on each other;
///  * the function waits for the final event before its temporaries are released.
///
/// @param inp1 First input matrix.
/// @param inp2 Second input matrix; must have as many columns as `inp1`.
/// @param out  Output matrix, overwritten with the distances.
/// @param deps Events that must complete before the computation starts.
/// @return Event of the last submitted step (already completed on return).
template <typename Float>
template <ndorder order1, ndorder order2>
sycl::event distance<Float, correlation_metric<Float>>::operator()(
    const ndview<Float, 2, order1>& inp1,
    const ndview<Float, 2, order2>& inp2,
    ndview<Float, 2>& out,
    const event_vector& deps) const {
    const std::int64_t n1 = inp1.get_dimension(0);
    const std::int64_t n2 = inp2.get_dimension(0);
    const std::int64_t p = inp1.get_dimension(1);
    ONEDAL_ASSERT(p == inp2.get_dimension(1));

    auto inp1_sum = ndarray<Float, 1>::empty(q_, { n1 });
    auto inp2_sum = ndarray<Float, 1>::empty(q_, { n2 });
    auto inp1_mean = ndarray<Float, 1>::empty(q_, { n1 });
    auto inp2_mean = ndarray<Float, 1>::empty(q_, { n2 });
    auto centered_inp1 = ndarray<Float, 2>::empty(q_, { n1, p });
    auto centered_inp2 = ndarray<Float, 2>::empty(q_, { n2, p });

    // NOTE(review): the `{}` functor arguments are kept from the previous revision;
    // confirm they select a plain sum reduction for reduce_by_rows.
    sycl::event sum1_event = reduce_by_rows(q_, inp1, inp1_sum, {}, {}, deps);
    sycl::event mean1_event = means(q_, p, inp1_sum, inp1_mean, { sum1_event });
    sycl::event center1_event =
        subtract_row_means<Float, order1>(q_, inp1, inp1_mean, centered_inp1, { mean1_event });

    // The second chain only depends on the caller's events, not on the first chain.
    sycl::event sum2_event = reduce_by_rows(q_, inp2, inp2_sum, {}, {}, deps);
    sycl::event mean2_event = means(q_, p, inp2_sum, inp2_mean, { sum2_event });
    sycl::event center2_event =
        subtract_row_means<Float, order2>(q_, inp2, inp2_mean, centered_inp2, { mean2_event });

    auto [inv_norms1_array, inv_norms1_event] =
        get_inversed_norms(centered_inp1, { center1_event });
    auto [inv_norms2_array, inv_norms2_event] =
        get_inversed_norms(centered_inp2, { center2_event });

    // Inner products must be taken over the centered data, otherwise the result
    // mixes uncentered products with centered norms.
    sycl::event result = this->operator()(centered_inp1,
                                          centered_inp2,
                                          out,
                                          inv_norms1_array,
                                          inv_norms2_array,
                                          { inv_norms1_event, inv_norms2_event });

    // Every temporary above is owned by this scope; block until the kernels that
    // read them have finished so they are not released while still in use.
    result.wait_and_throw();
    return result;
}

// Explicit instantiations of the correlation distance for float/double and all
// combinations of row-major / column-major inputs.
// INSTANTIATE   - both operator() overloads for one (order1, order2) pair.
// INSTANTIATE_B - the above for a fixed order1 plus get_inversed_norms for that order.
// INSTANTIATE_F - everything for one floating-point type, including the class itself.
#define INSTANTIATE(F, A, B)                                                                     \
    template sycl::event distance<F, correlation_metric<F>>::operator()(const ndview<F, 2, A>&,  \
                                                                        const ndview<F, 2, B>&,  \
                                                                        ndview<F, 2>&,           \
                                                                        const ndview<F, 1>&,     \
                                                                        const ndview<F, 1>&,     \
                                                                        const event_vector&) const; \
    template sycl::event distance<F, correlation_metric<F>>::operator()(const ndview<F, 2, A>&,  \
                                                                        const ndview<F, 2, B>&,  \
                                                                        ndview<F, 2>&,           \
                                                                        const event_vector&) const;

#define INSTANTIATE_B(F, A)                                                                  \
    INSTANTIATE(F, A, ndorder::c)                                                            \
    INSTANTIATE(F, A, ndorder::f)                                                            \
    template std::tuple<ndarray<F, 1>, sycl::event>                                          \
    distance<F, correlation_metric<F>>::get_inversed_norms(const ndview<F, 2, A>& inp,       \
                                                           const event_vector& deps) const;

// The class instantiated here must be the correlation specialization defined in
// this file (previously this named squared_l2_metric by mistake).
#define INSTANTIATE_F(F)         \
    INSTANTIATE_B(F, ndorder::c) \
    INSTANTIATE_B(F, ndorder::f) \
    template class distance<F, correlation_metric<F>>;

INSTANTIATE_F(float);
INSTANTIATE_F(double);

#undef INSTANTIATE_F
#undef INSTANTIATE_B
#undef INSTANTIATE

} // namespace oneapi::dal::backend::primitives
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*******************************************************************************
* Copyright 2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#pragma once

#include "oneapi/dal/backend/primitives/common.hpp"
#include "oneapi/dal/backend/primitives/ndarray.hpp"

#include "oneapi/dal/backend/primitives/distance/distance.hpp"

namespace oneapi::dal::backend::primitives {

#ifdef ONEDAL_DATA_PARALLEL

/// Fills `out` with inverse L2 norms of `inp`: squared L2 norms are computed
/// into `out` and each element is then replaced by its reciprocal square root.
///
/// NOTE(review): the implementation file includes cosine_distance_misc.hpp and
/// defines this same function, so that header presumably declares it too. If it
/// does, including both headers in one translation unit would repeat the default
/// argument - confirm.
///
/// @param q    Queue the work is submitted to.
/// @param inp  Input matrix.
/// @param out  Destination buffer, overwritten.
/// @param deps Events that must complete before the computation starts.
/// @return Event of the final (reciprocal square root) kernel.
template <typename Float, ndorder order>
sycl::event compute_inversed_l2_norms(sycl::queue& q,
const ndview<Float, 2, order>& inp,
ndview<Float, 1>& out,
const event_vector& deps = {});

/// Allocating overload: creates a 1-D array with `inp.get_dimension(0)` elements
/// using `alloc`, fills it via the buffer-taking overload and returns both the
/// array and the completion event.
///
/// @param q     Queue the work is submitted to.
/// @param inp   Input matrix.
/// @param deps  Events that must complete before the computation starts.
/// @param alloc USM allocation kind for the result array.
/// @return Tuple of the result array and the event that signals it is filled.
template <typename Float, ndorder order>
std::tuple<ndarray<Float, 1>, sycl::event> compute_inversed_l2_norms(
sycl::queue& q,
const ndview<Float, 2, order>& inp,
const event_vector& deps = {},
const sycl::usm::alloc& alloc = sycl::usm::alloc::device);

/// First step of the correlation distance: called by
/// distance<Float, correlation_metric<Float>>::operator() to fill `out` before
/// finalize_correlation rescales it.
///
/// NOTE(review): presumably writes the pairwise inner products of the rows of
/// `inp1` and `inp2` into `out` - confirm against the definition.
///
/// @param q    Queue the work is submitted to.
/// @param inp1 First input matrix.
/// @param inp2 Second input matrix.
/// @param out  Output matrix, overwritten.
/// @param deps Events that must complete before the computation starts.
template <typename Float, ndorder order1, ndorder order2>
sycl::event compute_correlation_inner_product(sycl::queue& q,
const ndview<Float, 2, order1>& inp1,
const ndview<Float, 2, order2>& inp2,
ndview<Float, 2>& out,
const event_vector& deps = {});

/// Second step of the correlation distance: called by
/// distance<Float, correlation_metric<Float>>::operator() after
/// compute_correlation_inner_product, with the two per-row factor arrays and the
/// matrix to be updated in place.
///
/// NOTE(review): presumably computes `1 - out(i, j) * inp1[i] * inp2[j]` -
/// confirm against the definition.
///
/// @param q    Queue the work is submitted to.
/// @param inp1 Per-row factors for the first input.
/// @param inp2 Per-row factors for the second input.
/// @param out  Matrix updated in place.
/// @param deps Events that must complete before the computation starts.
template <typename Float>
sycl::event finalize_correlation(sycl::queue& q,
const ndview<Float, 1>& inp1,
const ndview<Float, 1>& inp2,
ndview<Float, 2>& out,
const event_vector& deps = {});

#endif

} // namespace oneapi::dal::backend::primitives
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
/*******************************************************************************
* Copyright 2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "oneapi/dal/detail/profiler.hpp"

#include "oneapi/dal/backend/primitives/distance/cosine_distance_misc.hpp"
#include "oneapi/dal/backend/primitives/distance/squared_l2_distance_misc.hpp"

#include "oneapi/dal/backend/primitives/blas.hpp"
#include "oneapi/dal/backend/primitives/reduction.hpp"

namespace oneapi::dal::backend::primitives {

/// Replaces every element of `out` with its reciprocal square root, in place.
///
/// @param q    Queue the kernel is submitted to.
/// @param out  Buffer that is read and overwritten.
/// @param deps Events that must complete before the kernel starts.
/// @return Event of the submitted kernel.
template <typename Float>
inline sycl::event inverse_l2_norms(sycl::queue& q,
                                    ndview<Float, 1>& out,
                                    const event_vector& deps) {
    ONEDAL_PROFILER_TASK(distance.inverse_l2_norms, q);

    ONEDAL_ASSERT(out.has_mutable_data());
    return q.submit([&](sycl::handler& cgh) {
        cgh.depends_on(deps);
        const auto element_count = out.get_count();
        auto* const values = out.get_mutable_data();
        cgh.parallel_for(make_range_1d(element_count), [=](sycl::id<1> item) {
            values[item] = sycl::rsqrt(values[item]);
        });
    });
}

/// Fills `out` with the inverse L2 norms of `inp`: squared norms are written to
/// `out` first and are then turned into reciprocal square roots in place.
///
/// @param q    Queue the work is submitted to.
/// @param inp  Input matrix.
/// @param out  Destination buffer, overwritten.
/// @param deps Events that must complete before the computation starts.
/// @return Event of the final kernel.
template <typename Float, ndorder order>
sycl::event compute_inversed_l2_norms(sycl::queue& q,
                                      const ndview<Float, 2, order>& inp,
                                      ndview<Float, 1>& out,
                                      const event_vector& deps) {
    ONEDAL_ASSERT(inp.has_data());
    ONEDAL_ASSERT(out.has_mutable_data());
    // The in-place inversion may only start once the squared norms are available.
    const event_vector squares_ready{ compute_squared_l2_norms(q, inp, out, deps) };
    return inverse_l2_norms(q, out, squares_ready);
}

/// Allocating variant: creates the result array, fills it with the inverse L2
/// norms of `inp` and returns it together with the completion event.
///
/// @param q     Queue the work is submitted to.
/// @param inp   Input matrix; the result has `inp.get_dimension(0)` elements.
/// @param deps  Events that must complete before the computation starts.
/// @param alloc USM allocation kind for the result array.
template <typename Float, ndorder order>
std::tuple<ndarray<Float, 1>, sycl::event> compute_inversed_l2_norms(
    sycl::queue& q,
    const ndview<Float, 2, order>& inp,
    const event_vector& deps,
    const sycl::usm::alloc& alloc) {
    auto inv_norms = ndarray<Float, 1>::empty(q, { inp.get_dimension(0) }, alloc);
    auto filled = compute_inversed_l2_norms(q, inp, inv_norms, deps);
    return { inv_norms, filled };
}

template <typename Float>
sycl::event finalize_cosine(sycl::queue& q,
const ndview<Float, 1>& inp1,
const ndview<Float, 1>& inp2,
ndview<Float, 2>& out,
const event_vector& deps) {
ONEDAL_PROFILER_TASK(distance.finalize_cosine, q);

ONEDAL_ASSERT(inp1.has_data());
ONEDAL_ASSERT(inp2.has_data());
ONEDAL_ASSERT(out.has_mutable_data());
const auto out_stride = out.get_leading_stride();
const auto n_samples1 = inp1.get_dimension(0);
const auto n_samples2 = inp2.get_dimension(0);
ONEDAL_ASSERT(n_samples1 <= out.get_dimension(0));
ONEDAL_ASSERT(n_samples2 <= out.get_dimension(1));
const auto* const inp1_ptr = inp1.get_data();
const auto* const inp2_ptr = inp2.get_data();
auto* const out_ptr = out.get_mutable_data();
const auto out_range = make_range_2d(n_samples1, n_samples2);
return q.submit([&](sycl::handler& h) {
h.depends_on(deps);
h.parallel_for(out_range, [=](sycl::id<2> idx) {
constexpr Float one = 1;
auto& out = *(out_ptr + out_stride * idx[0] + idx[1]);
out = one - out * inp1_ptr[idx[0]] * inp2_ptr[idx[1]];
});
});
}

template <typename Float, ndorder order1, ndorder order2>
sycl::event compute_cosine_inner_product(sycl::queue& q,
const ndview<Float, 2, order1>& inp1,
const ndview<Float, 2, order2>& inp2,
ndview<Float, 2>& out,
const event_vector& deps) {
check_inputs(inp1, inp2, out);
auto event = gemm(q, inp1, inp2.t(), out, Float(+1.0), Float(0.0), deps);
// Workaround for abort in async mode. Should be removed later.
event.wait_and_throw();
return event;
}

#define INSTANTIATE(F, A, B) \
template sycl::event compute_cosine_inner_product<F, A, B>(sycl::queue&, \
const ndview<F, 2, A>&, \
const ndview<F, 2, B>&, \
ndview<F, 2>&, \
const event_vector&);

#define INSTANTIATE_A(F, B) \
INSTANTIATE(F, ndorder::c, B) \
INSTANTIATE(F, ndorder::f, B) \
template sycl::event compute_inversed_l2_norms<F, B>(sycl::queue&, \
const ndview<F, 2, B>&, \
ndview<F, 1>&, \
const event_vector&); \
template std::tuple<ndarray<F, 1>, sycl::event> compute_inversed_l2_norms<F, B>( \
sycl::queue&, \
const ndview<F, 2, B>&, \
const event_vector&, \
const sycl::usm::alloc&);

#define INSTANTIATE_F(F) \
INSTANTIATE_A(F, ndorder::c) \
INSTANTIATE_A(F, ndorder::f) \
template sycl::event finalize_cosine<F>(sycl::queue & q, \
const ndview<F, 1>&, \
const ndview<F, 1>&, \
ndview<F, 2>&, \
const event_vector&);

INSTANTIATE_F(float);
INSTANTIATE_F(double);

#undef INSTANTIATE

} // namespace oneapi::dal::backend::primitives
33 changes: 33 additions & 0 deletions cpp/oneapi/dal/backend/primitives/distance/distance.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,36 @@ class distance<Float, cosine_metric<Float>> {
sycl::queue& q_;
};

/// Specialization of `distance` for the correlation metric.
///
/// Holds a reference to the SYCL queue it was constructed with; the queue must
/// outlive this object.
template <typename Float>
class distance<Float, correlation_metric<Float>> {
public:
/// Stores a reference to `q`; all work of this object is submitted to it.
distance(sycl::queue& q) : q_{ q } {};

/// Computes the distance matrix between the rows of `inp1` and `inp2` into
/// `out` without precomputed norms: the required per-row quantities are
/// derived internally and the call is forwarded to the overload below.
///
/// @param deps Events that must complete before the computation starts.
template <ndorder order1, ndorder order2>
sycl::event operator()(const ndview<Float, 2, order1>& inp1,
const ndview<Float, 2, order2>& inp2,
ndview<Float, 2>& out,
const event_vector& deps = {}) const;

/// Computes the distance matrix using caller-provided per-row factors.
/// The overload above passes the arrays returned by get_inversed_norms here,
/// so `inp1_norms` / `inp2_norms` are expected to hold inverse norms.
///
/// @param deps Events that must complete before the computation starts.
template <ndorder order1, ndorder order2>
sycl::event operator()(const ndview<Float, 2, order1>& inp1,
const ndview<Float, 2, order2>& inp2,
ndview<Float, 2>& out,
const ndview<Float, 1>& inp1_norms,
const ndview<Float, 1>& inp2_norms,
const event_vector& deps = {}) const;

protected:
/// Result of get_inversed_norms: the norms array and the event that signals
/// when it has been filled.
using inv_norms_res_t = std::tuple<ndarray<Float, 1>, sycl::event>;

/// Returns the inverse L2 norms of `inp` (forwards to compute_inversed_l2_norms).
template <ndorder order>
inv_norms_res_t get_inversed_norms(const ndview<Float, 2, order>& inp,
const event_vector& deps = {}) const;

private:
// Non-owning reference to the queue passed to the constructor.
sycl::queue& q_;
};

/// Shorthand for the `distance` specialization that uses `lp_metric<Float>`.
template <typename Float>
using lp_distance = distance<Float, lp_metric<Float>>;

Expand All @@ -115,6 +145,9 @@ using cosine_distance = distance<Float, cosine_metric<Float>>;
/// Shorthand for the `distance` specialization that uses `chebyshev_metric<Float>`.
template <typename Float>
using chebyshev_distance = distance<Float, chebyshev_metric<Float>>;

/// Shorthand for the `distance` specialization that uses `correlation_metric<Float>`.
template <typename Float>
using correlation_distance = distance<Float, correlation_metric<Float>>;

template <typename Float, ndorder order1, ndorder order2>
void check_inputs(const ndview<Float, 2, order1>& inp1,
const ndview<Float, 2, order2>& inp2,
Expand Down
Loading
Loading