Adding correlation distance metric in oneDAL primitives #3059

Open
wants to merge 27 commits into
base: main
Commits
4fade17
Adding correlation distance metric
Feb 3, 2025
bb61721
Update metrics.hpp
richardnorth3 Feb 3, 2025
cd7c506
Update metrics.hpp with correlation_metric for GPU
richardnorth3 Feb 14, 2025
2039675
Add correlation distance test
richardnorth3 Feb 20, 2025
b4fd66a
Add supporting files for correlation distance
richardnorth3 Feb 24, 2025
22f03a8
Update metrics.hpp
richardnorth3 Feb 24, 2025
93a1c98
Update distance.hpp with correlation prototypes
richardnorth3 Feb 24, 2025
aec5d45
Merge branch 'dev/correlation-metric' of https://github.com/richardno…
richardnorth3 Feb 24, 2025
a53164d
Update metrics.hpp
richardnorth3 Feb 24, 2025
7ade161
Update correlation_distance_dpc.cpp
richardnorth3 Feb 24, 2025
c56b209
Update correlation_distance_dpc.cpp
richardnorth3 Feb 24, 2025
0b45447
Update correlation_distance_dpc.cpp
richardnorth3 Feb 24, 2025
9b09822
Update correlation_distance_dpc.cpp
richardnorth3 Feb 24, 2025
cc52792
Update bazel BUILD file with cov deps
richardnorth3 Feb 24, 2025
8232af4
Update bazel BUILD file
richardnorth3 Feb 24, 2025
8670a5a
Update bazel BUILD file
richardnorth3 Feb 24, 2025
7d88579
Update correlation_distance_dpc.cpp
richardnorth3 Feb 24, 2025
4740ad6
Update correlation_distance_dpc.cpp
richardnorth3 Feb 24, 2025
069a787
Update correlation_distance_dpc.cpp
richardnorth3 Feb 24, 2025
325e01b
Update cpp/oneapi/dal/backend/primitives/distance/correlation_distanc…
richardnorth3 Feb 25, 2025
01c4b72
Update cpp/oneapi/dal/backend/primitives/distance/correlation_distanc…
richardnorth3 Feb 25, 2025
f888e7a
Update correlation distance test and header
Feb 25, 2025
71609fc
Update correlation distance test
richardnorth3 Feb 25, 2025
8008137
Update correlation_distance_misc.hpp
richardnorth3 Feb 25, 2025
dbbbc3e
Update correlation_distance_misc_dpc.cpp
richardnorth3 Feb 25, 2025
a862a74
Update correlation distance test
Feb 25, 2025
68e8c4f
Update correlation_distance_misc_dpc.cpp
richardnorth3 Feb 25, 2025
99 changes: 99 additions & 0 deletions cpp/oneapi/dal/backend/primitives/distance/metrics.hpp
@@ -116,4 +116,103 @@ struct chebyshev_metric : public metric_base<Float> {
}
};

template <typename Float>
struct correlation_metric : public metric_base<Float> {
public:
correlation_metric(){}
template <typename InputIt1, typename InputIt2>
Float operator()(InputIt1 first1, InputIt1 last1, InputIt2 first2) const {
constexpr Float zero = 0;
constexpr Float one = 1;
Float ip_acc = zero;
Float n1_acc = zero;
Float n2_acc = zero;
Float n1_sum = zero;
Float n2_sum = zero;
Float count = zero;
for (auto it1 = first1, it2 = first2; it1 != last1; ++it1, ++it2) {
n1_sum += *it1;
n2_sum += *it2;
++count;
}

if (count == zero)
return Float(zero);

const Float n1_mean = n1_sum / count;
const Float n2_mean = n2_sum / count;

for (auto it1 = first1, it2 = first2; it1 != last1; ++it1, ++it2) {
const Float v1 = *it1 - n1_mean;
const Float v2 = *it2 - n2_mean;
n1_acc += (v1 * v1);
n2_acc += (v2 * v2);
ip_acc += (v1 * v2);
}
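// Zero variance leaves the correlation undefined; treat it as 0, as the GPU path does.
if (n1_acc <= zero || n2_acc <= zero)
return one;
const Float rsqn1 = one / std::sqrt(n1_acc);
const Float rsqn2 = one / std::sqrt(n2_acc);
return one - ip_acc * rsqn1 * rsqn2;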
}

template <ndorder order>
sycl::event operator()(sycl::queue& q,
const ndview<Float, 2, order>& u,
const ndview<Float, 2, order>& v,
ndview<Float, 2>& out,
const event_vector& deps = {}) const {
const std::int64_t n = u.get_dimension(0);
auto u_sum = ndarray<Float, 1>::empty({ 1 });
Review comment (Contributor):

Thanks for writing this up, @richardnorth3! https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.distance.correlation.html should be a good reference for the means necessary. The mean doesn't need to be a single value, but rather a vector the size of a single sample, with each entry being the mean value for that feature (i.e. for each row). This will then be a vector subtraction. Let me know if I am misunderstanding things about your implementation. Good progress.

Review comment (Contributor):

Hey, sorry about that: the definition uses the mean of the 1-D array, so my recommendations have been a little bit wrong.
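
For reference, the definition this correction points to can be written out as below. This is a sketch of the unweighted SciPy `correlation` distance for two 1-D arrays $u$ and $v$ of length $n$; the symbols $\bar{u}$ and $\bar{v}$ are notation introduced here and do not appear in the PR.

```latex
\[
d(u, v) = 1 - \frac{(u - \bar{u}) \cdot (v - \bar{v})}
                   {\lVert u - \bar{u} \rVert_2 \, \lVert v - \bar{v} \rVert_2},
\qquad
\bar{u} = \frac{1}{n} \sum_{i=1}^{n} u_i,
\quad
\bar{v} = \frac{1}{n} \sum_{i=1}^{n} v_i
\]
```

Both means are scalars taken over the elements of a single vector, so subtracting them shifts every element of that vector by the same amount.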

auto v_sum = ndarray<Float, 1>::empty({ 1 });
auto u_mean = ndarray<Float, 1>::empty({ 1 });
auto v_mean = ndarray<Float, 1>::empty({ 1 });
sycl::event evt1 = reduce_by_columns(q, u, u_sum, {}, {}, deps, true);
sycl::event evt2 = reduce_by_columns(q, v, v_sum, {}, {}, { evt1 }, true);
sycl::event evt3 = means(q, n, u_sum, u_mean, { evt2 });
sycl::event evt4 = means(q, n, v_sum, v_mean, { evt3 });
auto temp = ndarray<Float, 1>::empty({ 3 });
q.fill(temp, Float(0)).wait();
sycl::event evt5 = q.submit([&](sycl::handler& h) {
h.depends_on({ evt4 });
h.parallel_for(sycl::range<1>(n), [=](sycl::id<1> idx) {
const std::int64_t i = idx[0];
const Float x = u.at(i, 0);
const Float y = v.at(i, 0);
const Float mu_x = u_mean.at(0);
const Float mu_y = v_mean.at(0);
const Float d1 = x - mu_x;
const Float d2 = y - mu_y;
sycl::atomic_ref<Float,
sycl::memory_order::relaxed,
sycl::memory_scope::device,
sycl::access::address_space::global_space>
atomic_dot(temp.get_mutable_data()[0]);
sycl::atomic_ref<Float,
sycl::memory_order::relaxed,
sycl::memory_scope::device,
sycl::access::address_space::global_space>
atomic_norm1(temp.get_mutable_data()[1]);
sycl::atomic_ref<Float,
sycl::memory_order::relaxed,
sycl::memory_scope::device,
sycl::access::address_space::global_space>
atomic_norm2(temp.get_mutable_data()[2]);
atomic_dot.fetch_add(d1 * d2);
atomic_norm1.fetch_add(d1 * d1);
atomic_norm2.fetch_add(d2 * d2);
});
});
evt5.wait_and_throw();
std::array<Float, 3> host_temp;
q.memcpy(host_temp.data(), temp.get_mutable_data(), 3 * sizeof(Float)).wait();
const Float dot = host_temp[0];
const Float norm1 = host_temp[1];
const Float norm2 = host_temp[2];
const Float corr =
(norm1 > 0 && norm2 > 0) ? dot / (std::sqrt(norm1) * std::sqrt(norm2)) : Float(0);
const Float distance = 1 - corr;
out.get_mutable_data()[0] = distance;
return evt5;
}
};

} // namespace oneapi::dal::backend::primitives