Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[fix] Fix online SPMD algorithms finalize call #2882

Merged
merged 2 commits into from
Aug 29, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@

#include "oneapi/dal/algo/basic_statistics/backend/basic_statistics_interop.hpp"

#ifdef ONEDAL_DATA_PARALLEL

namespace oneapi::dal::basic_statistics::backend {

namespace bk = dal::backend;
Expand Down Expand Up @@ -174,23 +176,44 @@ result_t finalize_compute_kernel_dense_impl<Float>::operator()(const descriptor_
}

if (res_op_partial.test(result_options::sum)) {
const auto sums_nd =
auto sums_nd =
pr::table2ndarray_1d<Float>(q, input.get_partial_sum(), sycl::usm::alloc::device);
if (comm_.get_rank_count() > 1) {
auto sums_nd_copy =
pr::ndarray<Float, 1>::empty(q, { column_count }, sycl::usm::alloc::device);
auto copy_event = copy(q, sums_nd_copy, sums_nd, {});
copy_event.wait_and_throw();
sums_nd = sums_nd_copy;
}

{
ONEDAL_PROFILER_TASK(allreduce_sums, q);
comm_.allreduce(sums_nd.flatten(q, {}), spmd::reduce_op::sum).wait();
}
const auto sums2_nd = pr::table2ndarray_1d<Float>(q,
input.get_partial_sum_squares(),
sycl::usm::alloc::device);
auto sums2_nd = pr::table2ndarray_1d<Float>(q,
input.get_partial_sum_squares(),
sycl::usm::alloc::device);
if (comm_.get_rank_count() > 1) {
auto sums2_nd_copy =
pr::ndarray<Float, 1>::empty(q, { column_count }, sycl::usm::alloc::device);
auto copy_event = copy(q, sums2_nd_copy, sums2_nd, {});
copy_event.wait_and_throw();
sums2_nd = sums2_nd_copy;
}
{
ONEDAL_PROFILER_TASK(allreduce_sums, q);
comm_.allreduce(sums2_nd.flatten(q, {}), spmd::reduce_op::sum).wait();
}
const auto sums2cent_nd =
pr::table2ndarray_1d<Float>(q,
input.get_partial_sum_squares_centered(),
sycl::usm::alloc::device);
auto sums2cent_nd = pr::table2ndarray_1d<Float>(q,
input.get_partial_sum_squares_centered(),
sycl::usm::alloc::device);
if (comm_.get_rank_count() > 1) {
auto sums2cent_nd_copy =
pr::ndarray<Float, 1>::empty(q, { column_count }, sycl::usm::alloc::device);
auto copy_event = copy(q, sums2cent_nd_copy, sums2cent_nd, {});
copy_event.wait_and_throw();
sums2cent_nd = sums2cent_nd_copy;
}
{
ONEDAL_PROFILER_TASK(allreduce_sums, q);
comm_.allreduce(sums2cent_nd.flatten(q, {}), spmd::reduce_op::sum).wait();
Expand All @@ -210,18 +233,20 @@ result_t finalize_compute_kernel_dense_impl<Float>::operator()(const descriptor_

if (res_op.test(result_options::sum)) {
ONEDAL_ASSERT(input.get_partial_sum().get_column_count() == column_count);
res.set_sum(input.get_partial_sum());
res.set_sum(homogen_table::wrap(sums_nd.flatten(q, { update_event }), 1, column_count));
}

if (res_op.test(result_options::sum_squares)) {
ONEDAL_ASSERT(input.get_partial_sum_squares().get_column_count() == column_count);
res.set_sum_squares(input.get_partial_sum_squares());
res.set_sum_squares(
homogen_table::wrap(sums2_nd.flatten(q, { update_event }), 1, column_count));
}

if (res_op.test(result_options::sum_squares_centered)) {
ONEDAL_ASSERT(input.get_partial_sum_squares_centered().get_column_count() ==
column_count);
res.set_sum_squares_centered(input.get_partial_sum_squares_centered());
res.set_sum_squares_centered(
homogen_table::wrap(sums2cent_nd.flatten(q, { update_event }), 1, column_count));
}

if (res_op.test(result_options::mean)) {
Expand Down Expand Up @@ -264,3 +289,5 @@ template class finalize_compute_kernel_dense_impl<float>;
template class finalize_compute_kernel_dense_impl<double>;

} // namespace oneapi::dal::basic_statistics::backend

#endif // ONEDAL_DATA_PARALLEL
10 changes: 6 additions & 4 deletions cpp/oneapi/dal/algo/basic_statistics/test/online_spmd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,9 @@ class basic_statistics_online_spmd_test
}
partial_results.push_back(partial_result);
}
const auto compute_result = this->finalize_compute_override(bs_desc, partial_results);

auto compute_result = this->finalize_compute_override(bs_desc, partial_results);
base_t::check_compute_result(compute_mode, data, weights, compute_result);
compute_result = this->finalize_compute_override(bs_desc, partial_results);
base_t::check_compute_result(compute_mode, data, weights, compute_result);
}
else {
Expand All @@ -103,8 +104,9 @@ class basic_statistics_online_spmd_test
}
partial_results.push_back(partial_result);
}
const auto compute_result = this->finalize_compute_override(bs_desc, partial_results);

auto compute_result = this->finalize_compute_override(bs_desc, partial_results);
base_t::check_compute_result(compute_mode, data, table{}, compute_result);
compute_result = this->finalize_compute_override(bs_desc, partial_results);
base_t::check_compute_result(compute_mode, data, table{}, compute_result);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,16 +73,29 @@ result_t finalize_compute_kernel_dense_impl<Float>::operator()(const descriptor_

ONEDAL_ASSERT(rows_count_global > 0);

const auto sums =
pr::table2ndarray_1d<Float>(q, input.get_partial_sum(), sycl::usm::alloc::device);

auto sums = pr::table2ndarray_1d<Float>(q, input.get_partial_sum(), sycl::usm::alloc::device);
if (comm_.get_rank_count() > 1) {
auto sums_copy =
pr::ndarray<Float, 1>::empty(q, { column_count }, sycl::usm::alloc::device);
auto copy_event = copy(q, sums_copy, sums, {});
copy_event.wait_and_throw();
sums = sums_copy;
}
{
ONEDAL_PROFILER_TASK(allreduce_sums, q);
comm_.allreduce(sums.flatten(q, {}), spmd::reduce_op::sum).wait();
}

const auto xtx =
auto xtx =
pr::table2ndarray<Float>(q, input.get_partial_crossproduct(), sycl::usm::alloc::device);
if (comm_.get_rank_count() > 1) {
auto xtx_copy = pr::ndarray<Float, 2>::empty(q,
{ column_count, column_count },
sycl::usm::alloc::device);
auto copy_event = copy(q, xtx_copy, xtx, {});
copy_event.wait_and_throw();
xtx = xtx_copy;
}

{
ONEDAL_PROFILER_TASK(allreduce_xtx, q);
Expand Down
5 changes: 3 additions & 2 deletions cpp/oneapi/dal/algo/covariance/test/online_spmd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,9 @@ class covariance_online_spmd_test
}
partial_results.push_back(partial_result);
}
const auto compute_result = this->finalize_compute_override(cov_desc, partial_results);

auto compute_result = this->finalize_compute_override(cov_desc, partial_results);
base_t::check_compute_result(cov_desc, data, compute_result);
compute_result = this->finalize_compute_override(cov_desc, partial_results);
base_t::check_compute_result(cov_desc, data, compute_result);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@

#include "oneapi/dal/backend/primitives/lapack.hpp"

#ifdef ONEDAL_DATA_PARALLEL

namespace oneapi::dal::linear_regression::backend {

namespace be = dal::backend;
Expand Down Expand Up @@ -47,15 +49,26 @@ train_result<Task> finalize_train_kernel_norm_eq_impl<Float, Task>::operator()(
const auto feature_count = ext_feature_count - compute_intercept;

const pr::ndshape<2> xtx_shape{ ext_feature_count, ext_feature_count };

const auto xtx_nd =
pr::table2ndarray<Float>(q, input.get_partial_xtx(), sycl::usm::alloc::device);
const auto xty_nd = pr::table2ndarray<Float, pr::ndorder::f>(q,
input.get_partial_xty(),
sycl::usm::alloc::device);

const pr::ndshape<2> betas_shape{ response_count, feature_count + 1 };

auto xtx_nd = pr::table2ndarray<Float>(q, input.get_partial_xtx(), sycl::usm::alloc::device);
if (comm_.get_rank_count() > 1) {
auto xtx_nd_copy = pr::ndarray<Float, 2>::empty(q, xtx_shape, sycl::usm::alloc::device);
auto copy_event = copy(q, xtx_nd_copy, xtx_nd, {});
copy_event.wait_and_throw();
xtx_nd = xtx_nd_copy;
}
auto xty_nd = pr::table2ndarray<Float, pr::ndorder::f>(q,
input.get_partial_xty(),
sycl::usm::alloc::device);
if (comm_.get_rank_count() > 1) {
auto xty_nd_copy =
pr::ndarray<Float, 2, pr::ndorder::f>::empty(q, betas_shape, sycl::usm::alloc::device);
auto copy_event = copy(q, xty_nd_copy, xty_nd, {});
copy_event.wait_and_throw();
xty_nd = xty_nd_copy;
}

const auto betas_size = check_mul_overflow(response_count, feature_count + 1);
auto betas_arr = array<Float>::zeros(q, betas_size, alloc);

Expand Down Expand Up @@ -125,3 +138,5 @@ template class finalize_train_kernel_norm_eq_impl<float, task::regression>;
template class finalize_train_kernel_norm_eq_impl<double, task::regression>;

} // namespace oneapi::dal::linear_regression::backend

#endif // ONEDAL_DATA_PARALLEL
14 changes: 13 additions & 1 deletion cpp/oneapi/dal/algo/linear_regression/test/online_spmd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ class lr_online_spmd_test : public lr_test<TestType, lr_online_spmd_test<TestTyp
partial_results.push_back(partial_result);
}

const auto train_result = this->finalize_train_override(desc, partial_results);
auto train_result = this->finalize_train_override(desc, partial_results);

SECTION("Checking intercept values") {
if (desc.get_result_options().test(result_options::intercept))
Expand All @@ -105,6 +105,18 @@ class lr_online_spmd_test : public lr_test<TestType, lr_online_spmd_test<TestTyp
if (desc.get_result_options().test(result_options::coefficients))
base_t::check_if_close(train_result.get_coefficients(), base_t::beta_, tol);
}

train_result = this->finalize_train_override(desc, partial_results);

SECTION("Checking intercept values after double finalize") {
if (desc.get_result_options().test(result_options::intercept))
base_t::check_if_close(train_result.get_intercept(), base_t::bias_, tol);
}

SECTION("Checking coefficient values after double finalize") {
if (desc.get_result_options().test(result_options::coefficients))
base_t::check_if_close(train_result.get_coefficients(), base_t::beta_, tol);
}
}

private:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
#include "oneapi/dal/algo/pca/backend/sign_flip.hpp"
#include "oneapi/dal/table/row_accessor.hpp"

#ifdef ONEDAL_DATA_PARALLEL

namespace oneapi::dal::pca::backend {

namespace bk = dal::backend;
Expand Down Expand Up @@ -62,8 +64,14 @@ result_t finalize_train_kernel_cov_impl<Float>::operator()(const descriptor_t& d
comm_.allreduce(rows_count_global, spmd::reduce_op::sum).wait();
}

const auto sums =
pr::table2ndarray_1d<Float>(q, input.get_partial_sum(), sycl::usm::alloc::device);
auto sums = pr::table2ndarray_1d<Float>(q, input.get_partial_sum(), sycl::usm::alloc::device);
if (comm_.get_rank_count() > 1) {
auto sums_copy =
pr::ndarray<Float, 1>::empty(q, { column_count }, sycl::usm::alloc::device);
auto copy_event = copy(q, sums_copy, sums, {});
copy_event.wait_and_throw();
sums = sums_copy;
}

{
ONEDAL_PROFILER_TASK(allreduce_sums, q);
Expand All @@ -75,8 +83,16 @@ result_t finalize_train_kernel_cov_impl<Float>::operator()(const descriptor_t& d
result.set_means(homogen_table::wrap(means.flatten(q, { means_event }), 1, column_count));
}

const auto xtx =
auto xtx =
pr::table2ndarray<Float>(q, input.get_partial_crossproduct(), sycl::usm::alloc::device);
if (comm_.get_rank_count() > 1) {
auto xtx_copy = pr::ndarray<Float, 2>::empty(q,
{ column_count, column_count },
sycl::usm::alloc::device);
auto copy_event = copy(q, xtx_copy, xtx, {});
copy_event.wait_and_throw();
xtx = xtx_copy;
}
{
ONEDAL_PROFILER_TASK(allreduce_xtx, q);
comm_.allreduce(xtx.flatten(q, {}), spmd::reduce_op::sum).wait();
Expand Down Expand Up @@ -144,3 +160,5 @@ template class finalize_train_kernel_cov_impl<float>;
template class finalize_train_kernel_cov_impl<double>;

} // namespace oneapi::dal::pca::backend

#endif // ONEDAL_DATA_PARALLEL
4 changes: 3 additions & 1 deletion cpp/oneapi/dal/algo/pca/test/online_spmd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,10 @@ class pca_online_spmd_test : public pca_test<TestType, pca_online_spmd_test<Test
}
partial_results.push_back(partial_result);
}
const auto train_result = this->finalize_train_override(pca_desc, partial_results);
auto train_result = this->finalize_train_override(pca_desc, partial_results);
base_t::check_train_result(pca_desc, data_fr, train_result);

train_result = this->finalize_train_override(pca_desc, partial_results);
base_t::check_train_result(pca_desc, data_fr, train_result);
}

Expand Down
Loading