From f06b35aa99ddb1f06de7f7b36ccffeb7b3d66ebc Mon Sep 17 00:00:00 2001 From: Victoriya Fedotova Date: Fri, 24 May 2024 13:09:19 +0200 Subject: [PATCH] Improve SYCL kernels in CSR table (#2794) --- cpp/oneapi/dal/table/backend/csr_kernels.cpp | 62 ++++++++------------ 1 file changed, 26 insertions(+), 36 deletions(-) diff --git a/cpp/oneapi/dal/table/backend/csr_kernels.cpp b/cpp/oneapi/dal/table/backend/csr_kernels.cpp index 58a73b74549..8e5aef236b2 100644 --- a/cpp/oneapi/dal/table/backend/csr_kernels.cpp +++ b/cpp/oneapi/dal/table/backend/csr_kernels.cpp @@ -408,6 +408,7 @@ bool is_sorted(sycl::queue& queue, // number of pairs of the subsequent elements in the data array that are sorted in desccending order, // i.e. for which data[i] > data[i + 1] is true. std::int64_t count_descending_pairs = 0L; + sycl::buffer count_buf(&count_descending_pairs, sycl::range<1>(1)); // count the number of pairs of the subsequent elements in the data array that are sorted @@ -418,10 +419,9 @@ bool is_sorted(sycl::queue& queue, auto count_descending_reduction = sycl::reduction(count_buf, cgh, sycl::ext::oneapi::plus()); - cgh.parallel_for(sycl::nd_range<1>{ count - 1, 1 }, + cgh.parallel_for(sycl::range<1>{ dal::detail::integral_cast(count - 1) }, count_descending_reduction, - [=](sycl::nd_item<1> idx, auto& count_descending) { - const auto i = idx.get_global_id(0); + [=](sycl::id<1> i, auto& count_descending) { if (data[i] > data[i + 1]) count_descending.combine(1); }); @@ -485,39 +485,29 @@ out_of_bound_type check_bounds(const array& arr, sycl::buffer count_lt_buf(&count_lt_min, sycl::range<1>(1)); sycl::buffer count_gt_buf(&count_gt_max, sycl::range<1>(1)); - // count the number of elements which are less than min_vaule using sycl::reduction - auto event_count_lt_min = queue.submit([&](sycl::handler& cgh) { - cgh.depends_on(dependencies); - auto count_lt_reduction = - sycl::reduction(count_lt_buf, cgh, sycl::ext::oneapi::plus()); - - cgh.parallel_for(sycl::nd_range<1>{ count, 1 
}, - count_lt_reduction, - [=](sycl::nd_item<1> idx, auto& count_lt) { - const auto i = idx.get_global_id(0); - if (data[i] < min_value) { - count_lt.combine(1); - } - }); - }); - - // count the number of elements which are greater than max_vaule using sycl::reduction - auto event_count_gt_max = queue.submit([&](sycl::handler& cgh) { - cgh.depends_on(dependencies); - auto count_gt_reduction = - sycl::reduction(count_gt_buf, cgh, sycl::ext::oneapi::plus()); - - cgh.parallel_for(sycl::nd_range<1>{ count, 1 }, - count_gt_reduction, - [=](sycl::nd_item<1> idx, auto& count_gt) { - const auto i = idx.get_global_id(0); - if (data[i] > max_value) { - count_gt.combine(1); - } - }); - }); - - sycl::event::wait_and_throw({ event_count_lt_min, event_count_gt_max }); + // count the number of elements which are less than min_value and + // the number of elements which are greater than max_value using sycl::reduction + queue + .submit([&](sycl::handler& cgh) { + cgh.depends_on(dependencies); + auto count_lt_reduction = + sycl::reduction(count_lt_buf, cgh, sycl::ext::oneapi::plus()); + auto count_gt_reduction = + sycl::reduction(count_gt_buf, cgh, sycl::ext::oneapi::plus()); + + cgh.parallel_for(sycl::range<1>{ dal::detail::integral_cast(count) }, + count_lt_reduction, + count_gt_reduction, + [=](sycl::id<1> i, auto& count_lt, auto& count_gt) { + if (data[i] < min_value) { + count_lt.combine(1); + } + if (data[i] > max_value) { + count_gt.combine(1); + } + }); + }) + .wait_and_throw(); out_of_bound_type result{ out_of_bound_type::within_bounds }; if (count_lt_min > 0)