diff --git a/cpp/oneapi/dal/algo/basic_statistics/backend/gpu/finalize_compute_kernel_dense_impl_dpc.cpp b/cpp/oneapi/dal/algo/basic_statistics/backend/gpu/finalize_compute_kernel_dense_impl_dpc.cpp index e0d3a6aa3a9..429b666279e 100644 --- a/cpp/oneapi/dal/algo/basic_statistics/backend/gpu/finalize_compute_kernel_dense_impl_dpc.cpp +++ b/cpp/oneapi/dal/algo/basic_statistics/backend/gpu/finalize_compute_kernel_dense_impl_dpc.cpp @@ -26,6 +26,8 @@ #include "oneapi/dal/algo/basic_statistics/backend/basic_statistics_interop.hpp" +#ifdef ONEDAL_DATA_PARALLEL + namespace oneapi::dal::basic_statistics::backend { namespace bk = dal::backend; @@ -151,16 +153,21 @@ result_t finalize_compute_kernel_dense_impl::operator()(const descriptor_ const auto nobs_nd = pr::table2ndarray_1d(q, input.get_partial_n_rows()); auto rows_count_global = nobs_nd.get_data()[0]; + auto is_distributed = (comm_.get_rank_count() > 1); { ONEDAL_PROFILER_TASK(allreduce_rows_count_global); - comm_.allreduce(rows_count_global, spmd::reduce_op::sum).wait(); + if (is_distributed) { + comm_.allreduce(rows_count_global, spmd::reduce_op::sum).wait(); + } } if (res_op.test(result_options::min)) { ONEDAL_ASSERT(input.get_partial_min().get_column_count() == column_count); const auto min = pr::table2ndarray_1d(q, input.get_partial_min(), sycl::usm::alloc::device); - { comm_.allreduce(min.flatten(q, {}), spmd::reduce_op::min).wait(); } + if (is_distributed) { + comm_.allreduce(min.flatten(q, {}), spmd::reduce_op::min).wait(); + } res.set_min(homogen_table::wrap(min.flatten(q, {}), 1, column_count)); } @@ -174,27 +181,48 @@ result_t finalize_compute_kernel_dense_impl::operator()(const descriptor_ } if (res_op_partial.test(result_options::sum)) { - const auto sums_nd = + auto sums_nd = pr::table2ndarray_1d(q, input.get_partial_sum(), sycl::usm::alloc::device); - { - ONEDAL_PROFILER_TASK(allreduce_sums, q); - comm_.allreduce(sums_nd.flatten(q, {}), spmd::reduce_op::sum).wait(); - } - const auto sums2_nd = pr::table2ndarray_1d(q, - input.get_partial_sum_squares(), - sycl::usm::alloc::device); - { - ONEDAL_PROFILER_TASK(allreduce_sums, q); - comm_.allreduce(sums2_nd.flatten(q, {}), spmd::reduce_op::sum).wait(); - } - const auto sums2cent_nd = - pr::table2ndarray_1d(q, - input.get_partial_sum_squares_centered(), - sycl::usm::alloc::device); - { - ONEDAL_PROFILER_TASK(allreduce_sums, q); - comm_.allreduce(sums2cent_nd.flatten(q, {}), spmd::reduce_op::sum).wait(); + auto sums2_nd = pr::table2ndarray_1d(q, + input.get_partial_sum_squares(), + sycl::usm::alloc::device); + + auto sums2cent_nd = pr::table2ndarray_1d(q, + input.get_partial_sum_squares_centered(), + sycl::usm::alloc::device); + if (is_distributed) { + auto sums_nd_copy = + pr::ndarray::empty(q, { column_count }, sycl::usm::alloc::device); + auto copy_event = copy(q, sums_nd_copy, sums_nd, {}); + copy_event.wait_and_throw(); + sums_nd = sums_nd_copy; + + { + ONEDAL_PROFILER_TASK(allreduce_sums, q); + comm_.allreduce(sums_nd.flatten(q, {}), spmd::reduce_op::sum).wait(); + } + + auto sums2_nd_copy = + pr::ndarray::empty(q, { column_count }, sycl::usm::alloc::device); + copy_event = copy(q, sums2_nd_copy, sums2_nd, {}); + copy_event.wait_and_throw(); + sums2_nd = sums2_nd_copy; + + { + ONEDAL_PROFILER_TASK(allreduce_sums, q); + comm_.allreduce(sums2_nd.flatten(q, {}), spmd::reduce_op::sum).wait(); + } + auto sums2cent_nd_copy = + pr::ndarray::empty(q, { column_count }, sycl::usm::alloc::device); + copy_event = copy(q, sums2cent_nd_copy, sums2cent_nd, {}); + copy_event.wait_and_throw(); + sums2cent_nd = sums2cent_nd_copy; + { + ONEDAL_PROFILER_TASK(allreduce_sums, q); + comm_.allreduce(sums2cent_nd.flatten(q, {}), spmd::reduce_op::sum).wait(); + } } + auto [result_means, result_variance, result_raw_moment, @@ -210,18 +238,20 @@ result_t finalize_compute_kernel_dense_impl::operator()(const descriptor_ if (res_op.test(result_options::sum)) { ONEDAL_ASSERT(input.get_partial_sum().get_column_count() == column_count); - res.set_sum(input.get_partial_sum()); + res.set_sum(homogen_table::wrap(sums_nd.flatten(q, { update_event }), 1, column_count)); } if (res_op.test(result_options::sum_squares)) { ONEDAL_ASSERT(input.get_partial_sum_squares().get_column_count() == column_count); - res.set_sum_squares(input.get_partial_sum_squares()); + res.set_sum_squares( + homogen_table::wrap(sums2_nd.flatten(q, { update_event }), 1, column_count)); } if (res_op.test(result_options::sum_squares_centered)) { ONEDAL_ASSERT(input.get_partial_sum_squares_centered().get_column_count() == column_count); - res.set_sum_squares_centered(input.get_partial_sum_squares_centered()); + res.set_sum_squares_centered( + homogen_table::wrap(sums2cent_nd.flatten(q, { update_event }), 1, column_count)); } if (res_op.test(result_options::mean)) { @@ -264,3 +294,5 @@ template class finalize_compute_kernel_dense_impl; template class finalize_compute_kernel_dense_impl; } // namespace oneapi::dal::basic_statistics::backend + +#endif // ONEDAL_DATA_PARALLEL diff --git a/cpp/oneapi/dal/algo/basic_statistics/test/online_spmd.cpp b/cpp/oneapi/dal/algo/basic_statistics/test/online_spmd.cpp index 892fb3e0813..cd57b7dce1e 100644 --- a/cpp/oneapi/dal/algo/basic_statistics/test/online_spmd.cpp +++ b/cpp/oneapi/dal/algo/basic_statistics/test/online_spmd.cpp @@ -88,8 +88,9 @@ class basic_statistics_online_spmd_test } partial_results.push_back(partial_result); } - const auto compute_result = this->finalize_compute_override(bs_desc, partial_results); - + auto compute_result = this->finalize_compute_override(bs_desc, partial_results); + base_t::check_compute_result(compute_mode, data, weights, compute_result); + compute_result = this->finalize_compute_override(bs_desc, partial_results); base_t::check_compute_result(compute_mode, data, weights, compute_result); } else { @@ -103,8 +104,9 @@ class basic_statistics_online_spmd_test } partial_results.push_back(partial_result); } - const auto compute_result = this->finalize_compute_override(bs_desc, partial_results); - + auto compute_result = this->finalize_compute_override(bs_desc, partial_results); + base_t::check_compute_result(compute_mode, data, table{}, compute_result); + compute_result = this->finalize_compute_override(bs_desc, partial_results); base_t::check_compute_result(compute_mode, data, table{}, compute_result); } } diff --git a/cpp/oneapi/dal/algo/covariance/backend/gpu/finalize_compute_kernel_dense_impl_dpc.cpp b/cpp/oneapi/dal/algo/covariance/backend/gpu/finalize_compute_kernel_dense_impl_dpc.cpp index 131b681a435..c3844b2e97e 100644 --- a/cpp/oneapi/dal/algo/covariance/backend/gpu/finalize_compute_kernel_dense_impl_dpc.cpp +++ b/cpp/oneapi/dal/algo/covariance/backend/gpu/finalize_compute_kernel_dense_impl_dpc.cpp @@ -66,28 +66,38 @@ result_t finalize_compute_kernel_dense_impl::operator()(const descriptor_ const auto nobs_host = pr::table2ndarray(q, input.get_partial_n_rows()); auto rows_count_global = nobs_host.get_data()[0]; - { - ONEDAL_PROFILER_TASK(allreduce_rows_count_global); - comm_.allreduce(rows_count_global, spmd::reduce_op::sum).wait(); - } - - ONEDAL_ASSERT(rows_count_global > 0); + auto sums = pr::table2ndarray_1d(q, input.get_partial_sum(), sycl::usm::alloc::device); + auto xtx = + pr::table2ndarray(q, input.get_partial_crossproduct(), sycl::usm::alloc::device); - const auto sums = - pr::table2ndarray_1d(q, input.get_partial_sum(), sycl::usm::alloc::device); + if (comm_.get_rank_count() > 1) { + { + ONEDAL_PROFILER_TASK(allreduce_rows_count_global); + comm_.allreduce(rows_count_global, spmd::reduce_op::sum).wait(); + } + auto sums_copy = + pr::ndarray::empty(q, { column_count }, sycl::usm::alloc::device); + auto copy_event = copy(q, sums_copy, sums, {}); + copy_event.wait_and_throw(); + sums = sums_copy; + { + ONEDAL_PROFILER_TASK(allreduce_sums, q); + comm_.allreduce(sums.flatten(q, {}), spmd::reduce_op::sum).wait(); + } - { - ONEDAL_PROFILER_TASK(allreduce_sums, q); - comm_.allreduce(sums.flatten(q, {}), spmd::reduce_op::sum).wait(); + auto xtx_copy = pr::ndarray::empty(q, + { column_count, column_count }, + sycl::usm::alloc::device); + copy_event = copy(q, xtx_copy, xtx, {}); + copy_event.wait_and_throw(); + xtx = xtx_copy; + { + ONEDAL_PROFILER_TASK(allreduce_xtx, q); + comm_.allreduce(xtx.flatten(q, {}), spmd::reduce_op::sum).wait(); + } } - const auto xtx = - pr::table2ndarray(q, input.get_partial_crossproduct(), sycl::usm::alloc::device); - - { - ONEDAL_PROFILER_TASK(allreduce_xtx, q); - comm_.allreduce(xtx.flatten(q, {}), spmd::reduce_op::sum).wait(); - } + ONEDAL_ASSERT(rows_count_global > 0); if (desc.get_result_options().test(result_options::cov_matrix)) { auto [cov, cov_event] = diff --git a/cpp/oneapi/dal/algo/covariance/test/online_spmd.cpp b/cpp/oneapi/dal/algo/covariance/test/online_spmd.cpp index 4a480869ab0..cd28a24a025 100644 --- a/cpp/oneapi/dal/algo/covariance/test/online_spmd.cpp +++ b/cpp/oneapi/dal/algo/covariance/test/online_spmd.cpp @@ -79,8 +79,9 @@ class covariance_online_spmd_test } partial_results.push_back(partial_result); } - const auto compute_result = this->finalize_compute_override(cov_desc, partial_results); - + auto compute_result = this->finalize_compute_override(cov_desc, partial_results); + base_t::check_compute_result(cov_desc, data, compute_result); + compute_result = this->finalize_compute_override(cov_desc, partial_results); base_t::check_compute_result(cov_desc, data, compute_result); } diff --git a/cpp/oneapi/dal/algo/linear_regression/backend/gpu/finalize_train_kernel_norm_eq_impl_dpc.cpp b/cpp/oneapi/dal/algo/linear_regression/backend/gpu/finalize_train_kernel_norm_eq_impl_dpc.cpp index c470f45403e..f6c6dd54091 100644 --- a/cpp/oneapi/dal/algo/linear_regression/backend/gpu/finalize_train_kernel_norm_eq_impl_dpc.cpp +++ b/cpp/oneapi/dal/algo/linear_regression/backend/gpu/finalize_train_kernel_norm_eq_impl_dpc.cpp @@ -20,6 +20,8 @@ #include "oneapi/dal/backend/primitives/lapack.hpp" +#ifdef ONEDAL_DATA_PARALLEL + namespace oneapi::dal::linear_regression::backend { namespace be = dal::backend; @@ -47,25 +49,32 @@ train_result finalize_train_kernel_norm_eq_impl::operator()( const auto feature_count = ext_feature_count - compute_intercept; const pr::ndshape<2> xtx_shape{ ext_feature_count, ext_feature_count }; - - const auto xtx_nd = - pr::table2ndarray(q, input.get_partial_xtx(), sycl::usm::alloc::device); - const auto xty_nd = pr::table2ndarray(q, - input.get_partial_xty(), - sycl::usm::alloc::device); - const pr::ndshape<2> betas_shape{ response_count, feature_count + 1 }; + auto xtx_nd = pr::table2ndarray(q, input.get_partial_xtx(), sycl::usm::alloc::device); + auto xty_nd = pr::table2ndarray(q, + input.get_partial_xty(), + sycl::usm::alloc::device); + const auto betas_size = check_mul_overflow(response_count, feature_count + 1); auto betas_arr = array::zeros(q, betas_size, alloc); if (comm_.get_rank_count() > 1) { + auto xtx_nd_copy = pr::ndarray::empty(q, xtx_shape, sycl::usm::alloc::device); + auto copy_event = copy(q, xtx_nd_copy, xtx_nd, {}); + copy_event.wait_and_throw(); + xtx_nd = xtx_nd_copy; { ONEDAL_PROFILER_TASK(xtx_allreduce); auto xtx_arr = dal::array::wrap(q, xtx_nd.get_mutable_data(), xtx_nd.get_count()); comm_.allreduce(xtx_arr).wait(); } + auto xty_nd_copy = + pr::ndarray::empty(q, betas_shape, sycl::usm::alloc::device); + copy_event = copy(q, xty_nd_copy, xty_nd, {}); + copy_event.wait_and_throw(); + xty_nd = xty_nd_copy; { ONEDAL_PROFILER_TASK(xty_allreduce); auto xty_arr = @@ -125,3 +134,5 @@ template class finalize_train_kernel_norm_eq_impl; template class finalize_train_kernel_norm_eq_impl; } // namespace oneapi::dal::linear_regression::backend + +#endif // ONEDAL_DATA_PARALLEL diff --git a/cpp/oneapi/dal/algo/linear_regression/test/online_spmd.cpp b/cpp/oneapi/dal/algo/linear_regression/test/online_spmd.cpp index c0f7968adfc..51e1ed18745 100644 --- a/cpp/oneapi/dal/algo/linear_regression/test/online_spmd.cpp +++ b/cpp/oneapi/dal/algo/linear_regression/test/online_spmd.cpp @@ -94,7 +94,7 @@ class lr_online_spmd_test : public lr_testfinalize_train_override(desc, partial_results); + auto train_result = this->finalize_train_override(desc, partial_results); SECTION("Checking intercept values") { if (desc.get_result_options().test(result_options::intercept)) @@ -105,6 +105,18 @@ class lr_online_spmd_test : public lr_testfinalize_train_override(desc, partial_results); + + SECTION("Checking intercept values after double finalize") { + if (desc.get_result_options().test(result_options::intercept)) + base_t::check_if_close(train_result.get_intercept(), base_t::bias_, tol); + } + + SECTION("Checking coefficient values after double finalize") { + if (desc.get_result_options().test(result_options::coefficients)) + base_t::check_if_close(train_result.get_coefficients(), base_t::beta_, tol); + } } private: diff --git a/cpp/oneapi/dal/algo/pca/backend/gpu/finalize_train_kernel_cov_impl_dpc.cpp b/cpp/oneapi/dal/algo/pca/backend/gpu/finalize_train_kernel_cov_impl_dpc.cpp index 31f6becf309..12862ab04ba 100644 --- a/cpp/oneapi/dal/algo/pca/backend/gpu/finalize_train_kernel_cov_impl_dpc.cpp +++ b/cpp/oneapi/dal/algo/pca/backend/gpu/finalize_train_kernel_cov_impl_dpc.cpp @@ -27,6 +27,8 @@ #include "oneapi/dal/algo/pca/backend/sign_flip.hpp" #include "oneapi/dal/table/row_accessor.hpp" +#ifdef ONEDAL_DATA_PARALLEL + namespace oneapi::dal::pca::backend { namespace bk = dal::backend; @@ -57,30 +59,42 @@ result_t finalize_train_kernel_cov_impl::operator()(const descriptor_t& d const auto nobs_host = pr::table2ndarray(q, input.get_partial_n_rows()); auto rows_count_global = nobs_host.get_data()[0]; - { - ONEDAL_PROFILER_TASK(allreduce_rows_count_global); - comm_.allreduce(rows_count_global, spmd::reduce_op::sum).wait(); - } - - const auto sums = - pr::table2ndarray_1d(q, input.get_partial_sum(), sycl::usm::alloc::device); - - { - ONEDAL_PROFILER_TASK(allreduce_sums, q); - comm_.allreduce(sums.flatten(q, {}), spmd::reduce_op::sum).wait(); + auto sums = pr::table2ndarray_1d(q, input.get_partial_sum(), sycl::usm::alloc::device); + auto xtx = + pr::table2ndarray(q, input.get_partial_crossproduct(), sycl::usm::alloc::device); + if (comm_.get_rank_count() > 1) { + { + ONEDAL_PROFILER_TASK(allreduce_rows_count_global); + comm_.allreduce(rows_count_global, spmd::reduce_op::sum).wait(); + } + auto sums_copy = + pr::ndarray::empty(q, { column_count }, sycl::usm::alloc::device); + auto copy_event = copy(q, sums_copy, sums, {}); + copy_event.wait_and_throw(); + sums = sums_copy; + + auto xtx_copy = pr::ndarray::empty(q, + { column_count, column_count }, + sycl::usm::alloc::device); + copy_event = copy(q, xtx_copy, xtx, {}); + copy_event.wait_and_throw(); + xtx = xtx_copy; + + { + ONEDAL_PROFILER_TASK(allreduce_sums, q); + comm_.allreduce(sums.flatten(q, {}), spmd::reduce_op::sum).wait(); + } + + { + ONEDAL_PROFILER_TASK(allreduce_xtx, q); + comm_.allreduce(xtx.flatten(q, {}), spmd::reduce_op::sum).wait(); + } } if (desc.get_result_options().test(result_options::means)) { auto [means, means_event] = compute_means(q, sums, rows_count_global, {}); result.set_means(homogen_table::wrap(means.flatten(q, { means_event }), 1, column_count)); } - - const auto xtx = - pr::table2ndarray(q, input.get_partial_crossproduct(), sycl::usm::alloc::device); - { - ONEDAL_PROFILER_TASK(allreduce_xtx, q); - comm_.allreduce(xtx.flatten(q, {}), spmd::reduce_op::sum).wait(); - } auto [cov, cov_event] = compute_covariance(q, rows_count_global, xtx, sums, {}); auto [vars, vars_event] = compute_variances(q, cov, { cov_event }); @@ -144,3 +158,5 @@ template class finalize_train_kernel_cov_impl; template class finalize_train_kernel_cov_impl; } // namespace oneapi::dal::pca::backend + +#endif // ONEDAL_DATA_PARALLEL diff --git a/cpp/oneapi/dal/algo/pca/test/online_spmd.cpp b/cpp/oneapi/dal/algo/pca/test/online_spmd.cpp index a75d288ccc8..82e19c667cc 100644 --- a/cpp/oneapi/dal/algo/pca/test/online_spmd.cpp +++ b/cpp/oneapi/dal/algo/pca/test/online_spmd.cpp @@ -79,8 +79,10 @@ class pca_online_spmd_test : public pca_testfinalize_train_override(pca_desc, partial_results); + auto train_result = this->finalize_train_override(pca_desc, partial_results); + base_t::check_train_result(pca_desc, data_fr, train_result); + train_result = this->finalize_train_override(pca_desc, partial_results); base_t::check_train_result(pca_desc, data_fr, train_result); }