Skip to content

Commit

Permalink
Improve SYCL kernels in CSR table (#2794)
Browse files Browse the repository at this point in the history
  • Loading branch information
Vika-F authored May 24, 2024
1 parent 99570f8 commit f06b35a
Showing 1 changed file with 26 additions and 36 deletions.
62 changes: 26 additions & 36 deletions cpp/oneapi/dal/table/backend/csr_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,7 @@ bool is_sorted(sycl::queue& queue,
// number of pairs of the subsequent elements in the data array that are sorted in descending order,
// i.e. for which data[i] > data[i + 1] is true.
std::int64_t count_descending_pairs = 0L;

sycl::buffer<std::int64_t, 1> count_buf(&count_descending_pairs, sycl::range<1>(1));

// count the number of pairs of the subsequent elements in the data array that are sorted
Expand All @@ -418,10 +419,9 @@ bool is_sorted(sycl::queue& queue,
auto count_descending_reduction =
sycl::reduction(count_buf, cgh, sycl::ext::oneapi::plus<std::int64_t>());

cgh.parallel_for(sycl::nd_range<1>{ count - 1, 1 },
cgh.parallel_for(sycl::range<1>{ dal::detail::integral_cast<std::size_t>(count - 1) },
count_descending_reduction,
[=](sycl::nd_item<1> idx, auto& count_descending) {
const auto i = idx.get_global_id(0);
[=](sycl::id<1> i, auto& count_descending) {
if (data[i] > data[i + 1])
count_descending.combine(1);
});
Expand Down Expand Up @@ -485,39 +485,29 @@ out_of_bound_type check_bounds(const array<T>& arr,
sycl::buffer<std::int64_t, 1> count_lt_buf(&count_lt_min, sycl::range<1>(1));
sycl::buffer<std::int64_t, 1> count_gt_buf(&count_gt_max, sycl::range<1>(1));

// count the number of elements which are less than min_value using sycl::reduction
auto event_count_lt_min = queue.submit([&](sycl::handler& cgh) {
cgh.depends_on(dependencies);
auto count_lt_reduction =
sycl::reduction(count_lt_buf, cgh, sycl::ext::oneapi::plus<std::int64_t>());

cgh.parallel_for(sycl::nd_range<1>{ count, 1 },
count_lt_reduction,
[=](sycl::nd_item<1> idx, auto& count_lt) {
const auto i = idx.get_global_id(0);
if (data[i] < min_value) {
count_lt.combine(1);
}
});
});

// count the number of elements which are greater than max_value using sycl::reduction
auto event_count_gt_max = queue.submit([&](sycl::handler& cgh) {
cgh.depends_on(dependencies);
auto count_gt_reduction =
sycl::reduction(count_gt_buf, cgh, sycl::ext::oneapi::plus<std::int64_t>());

cgh.parallel_for(sycl::nd_range<1>{ count, 1 },
count_gt_reduction,
[=](sycl::nd_item<1> idx, auto& count_gt) {
const auto i = idx.get_global_id(0);
if (data[i] > max_value) {
count_gt.combine(1);
}
});
});

sycl::event::wait_and_throw({ event_count_lt_min, event_count_gt_max });
// count the number of elements which are less than min_value and
// the number of elements which are greater than max_value using sycl::reduction
queue
.submit([&](sycl::handler& cgh) {
cgh.depends_on(dependencies);
auto count_lt_reduction =
sycl::reduction(count_lt_buf, cgh, sycl::ext::oneapi::plus<std::int64_t>());
auto count_gt_reduction =
sycl::reduction(count_gt_buf, cgh, sycl::ext::oneapi::plus<std::int64_t>());

cgh.parallel_for(sycl::range<1>{ dal::detail::integral_cast<std::size_t>(count) },
count_lt_reduction,
count_gt_reduction,
[=](sycl::id<1> i, auto& count_lt, auto& count_gt) {
if (data[i] < min_value) {
count_lt.combine(1);
}
if (data[i] > max_value) {
count_gt.combine(1);
}
});
})
.wait_and_throw();

out_of_bound_type result{ out_of_bound_type::within_bounds };
if (count_lt_min > 0)
Expand Down

0 comments on commit f06b35a

Please sign in to comment.