From e48e3ae0e1312d8e58575e42b46cfeb193e5fa68 Mon Sep 17 00:00:00 2001 From: Victoriya Fedotova Date: Tue, 21 May 2024 04:47:58 -0700 Subject: [PATCH 1/4] Improve SYCL kernels in CSR table --- cpp/oneapi/dal/table/backend/csr_kernels.cpp | 63 +++++++++----------- 1 file changed, 27 insertions(+), 36 deletions(-) diff --git a/cpp/oneapi/dal/table/backend/csr_kernels.cpp b/cpp/oneapi/dal/table/backend/csr_kernels.cpp index 58a73b74549..3a406acef07 100644 --- a/cpp/oneapi/dal/table/backend/csr_kernels.cpp +++ b/cpp/oneapi/dal/table/backend/csr_kernels.cpp @@ -18,6 +18,7 @@ #include "oneapi/dal/table/backend/convert.hpp" #include +#include namespace oneapi::dal::backend { @@ -408,6 +409,7 @@ bool is_sorted(sycl::queue& queue, // number of pairs of the subsequent elements in the data array that are sorted in desccending order, // i.e. for which data[i] > data[i + 1] is true. std::int64_t count_descending_pairs = 0L; + sycl::buffer count_buf(&count_descending_pairs, sycl::range<1>(1)); // count the number of pairs of the subsequent elements in the data array that are sorted @@ -418,10 +420,9 @@ bool is_sorted(sycl::queue& queue, auto count_descending_reduction = sycl::reduction(count_buf, cgh, sycl::ext::oneapi::plus()); - cgh.parallel_for(sycl::nd_range<1>{ count - 1, 1 }, + cgh.parallel_for(sycl::range<1>{ dal::detail::integral_cast(count - 1) }, count_descending_reduction, - [=](sycl::nd_item<1> idx, auto& count_descending) { - const auto i = idx.get_global_id(0); + [=](sycl::id<1> i, auto& count_descending) { if (data[i] > data[i + 1]) count_descending.combine(1); }); @@ -485,39 +486,29 @@ out_of_bound_type check_bounds(const array& arr, sycl::buffer count_lt_buf(&count_lt_min, sycl::range<1>(1)); sycl::buffer count_gt_buf(&count_gt_max, sycl::range<1>(1)); - // count the number of elements which are less than min_vaule using sycl::reduction - auto event_count_lt_min = queue.submit([&](sycl::handler& cgh) { - cgh.depends_on(dependencies); - auto count_lt_reduction = 
- sycl::reduction(count_lt_buf, cgh, sycl::ext::oneapi::plus()); - - cgh.parallel_for(sycl::nd_range<1>{ count, 1 }, - count_lt_reduction, - [=](sycl::nd_item<1> idx, auto& count_lt) { - const auto i = idx.get_global_id(0); - if (data[i] < min_value) { - count_lt.combine(1); - } - }); - }); - - // count the number of elements which are greater than max_vaule using sycl::reduction - auto event_count_gt_max = queue.submit([&](sycl::handler& cgh) { - cgh.depends_on(dependencies); - auto count_gt_reduction = - sycl::reduction(count_gt_buf, cgh, sycl::ext::oneapi::plus()); - - cgh.parallel_for(sycl::nd_range<1>{ count, 1 }, - count_gt_reduction, - [=](sycl::nd_item<1> idx, auto& count_gt) { - const auto i = idx.get_global_id(0); - if (data[i] > max_value) { - count_gt.combine(1); - } - }); - }); - - sycl::event::wait_and_throw({ event_count_lt_min, event_count_gt_max }); + // count the number of elements which are less than min_value and + // the number of elements which are greater than max_value using sycl::reduction + queue + .submit([&](sycl::handler& cgh) { + cgh.depends_on(dependencies); + auto count_lt_reduction = + sycl::reduction(count_lt_buf, cgh, sycl::ext::oneapi::plus()); + auto count_gt_reduction = + sycl::reduction(count_gt_buf, cgh, sycl::ext::oneapi::plus()); + + cgh.parallel_for(sycl::range<1>{ dal::detail::integral_cast(count) }, + count_lt_reduction, + count_gt_reduction, + [=](sycl::id<1> i, auto& count_lt, auto &count_gt) { + if (data[i] < min_value) { + count_lt.combine(1); + } + if (data[i] > max_value) { + count_gt.combine(1); + } + }); + }) + .wait_and_throw(); out_of_bound_type result{ out_of_bound_type::within_bounds }; if (count_lt_min > 0) From db43d087a7fa77c1155247b126cdb3e8a9051637 Mon Sep 17 00:00:00 2001 From: Victoriya Fedotova Date: Tue, 21 May 2024 04:52:50 -0700 Subject: [PATCH 2/4] Enable sparse primitives testing on GPU --- .../dal/backend/primitives/sparse_blas/test/gemm_dpc.cpp | 3 +-- 
.../dal/backend/primitives/sparse_blas/test/gemv_dpc.cpp | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/cpp/oneapi/dal/backend/primitives/sparse_blas/test/gemm_dpc.cpp b/cpp/oneapi/dal/backend/primitives/sparse_blas/test/gemm_dpc.cpp index 3faac011d35..cea8206928c 100644 --- a/cpp/oneapi/dal/backend/primitives/sparse_blas/test/gemm_dpc.cpp +++ b/cpp/oneapi/dal/backend/primitives/sparse_blas/test/gemm_dpc.cpp @@ -32,8 +32,7 @@ TEMPLATE_LIST_TEST_M(sparse_blas_test, "ones matrix sparse CSR gemm", "[csr][gem SKIP_IF(this->get_policy().is_cpu()); // Test takes too long time if HW emulates float64 - // Temporary workaround: skip tests on architectures that do not support native float64 - SKIP_IF(!this->get_policy().has_native_float64()); + SKIP_IF(!this->not_float64_friendly()); this->generate_dimensions(); this->test_gemm(); diff --git a/cpp/oneapi/dal/backend/primitives/sparse_blas/test/gemv_dpc.cpp b/cpp/oneapi/dal/backend/primitives/sparse_blas/test/gemv_dpc.cpp index 4695095b2a4..da256eb5554 100644 --- a/cpp/oneapi/dal/backend/primitives/sparse_blas/test/gemv_dpc.cpp +++ b/cpp/oneapi/dal/backend/primitives/sparse_blas/test/gemv_dpc.cpp @@ -31,8 +31,7 @@ TEMPLATE_LIST_TEST_M(sparse_blas_test, "ones matrix sparse CSR gemv", "[csr][gem SKIP_IF(this->get_policy().is_cpu()); // Test takes too long time if HW emulates float64 - // Temporary workaround: skip tests on architectures that do not support native float64 - SKIP_IF(!this->get_policy().has_native_float64()); + SKIP_IF(!this->not_float64_friendly()); this->generate_dimensions_gemv(); this->test_gemv(); From afa97ac8e9c09b52bed729bfe6e1322b6762c48f Mon Sep 17 00:00:00 2001 From: Victoriya Fedotova Date: Tue, 21 May 2024 04:55:06 -0700 Subject: [PATCH 3/4] Remove unneeded include --- cpp/oneapi/dal/table/backend/csr_kernels.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/cpp/oneapi/dal/table/backend/csr_kernels.cpp b/cpp/oneapi/dal/table/backend/csr_kernels.cpp index 
3a406acef07..6505cd662f1 100644 --- a/cpp/oneapi/dal/table/backend/csr_kernels.cpp +++ b/cpp/oneapi/dal/table/backend/csr_kernels.cpp @@ -18,7 +18,6 @@ #include "oneapi/dal/table/backend/convert.hpp" #include -#include namespace oneapi::dal::backend { From feade7c7b31bf23f8295dd1287316acacd6ff7eb Mon Sep 17 00:00:00 2001 From: Victoriya Fedotova Date: Wed, 22 May 2024 01:34:36 -0700 Subject: [PATCH 4/4] clang-format; revert changes in tests --- .../primitives/sparse_blas/test/gemm_dpc.cpp | 3 ++- .../primitives/sparse_blas/test/gemv_dpc.cpp | 3 ++- cpp/oneapi/dal/table/backend/csr_kernels.cpp | 20 +++++++++---------- 3 files changed, 14 insertions(+), 12 deletions(-) diff --git a/cpp/oneapi/dal/backend/primitives/sparse_blas/test/gemm_dpc.cpp b/cpp/oneapi/dal/backend/primitives/sparse_blas/test/gemm_dpc.cpp index cea8206928c..3faac011d35 100644 --- a/cpp/oneapi/dal/backend/primitives/sparse_blas/test/gemm_dpc.cpp +++ b/cpp/oneapi/dal/backend/primitives/sparse_blas/test/gemm_dpc.cpp @@ -32,7 +32,8 @@ TEMPLATE_LIST_TEST_M(sparse_blas_test, "ones matrix sparse CSR gemm", "[csr][gem SKIP_IF(this->get_policy().is_cpu()); // Test takes too long time if HW emulates float64 - SKIP_IF(!this->not_float64_friendly()); + // Temporary workaround: skip tests on architectures that do not support native float64 + SKIP_IF(!this->get_policy().has_native_float64()); this->generate_dimensions(); this->test_gemm(); diff --git a/cpp/oneapi/dal/backend/primitives/sparse_blas/test/gemv_dpc.cpp b/cpp/oneapi/dal/backend/primitives/sparse_blas/test/gemv_dpc.cpp index da256eb5554..4695095b2a4 100644 --- a/cpp/oneapi/dal/backend/primitives/sparse_blas/test/gemv_dpc.cpp +++ b/cpp/oneapi/dal/backend/primitives/sparse_blas/test/gemv_dpc.cpp @@ -31,7 +31,8 @@ TEMPLATE_LIST_TEST_M(sparse_blas_test, "ones matrix sparse CSR gemv", "[csr][gem SKIP_IF(this->get_policy().is_cpu()); // Test takes too long time if HW emulates float64 - SKIP_IF(!this->not_float64_friendly()); + // Temporary workaround: 
skip tests on architectures that do not support native float64 + SKIP_IF(!this->get_policy().has_native_float64()); this->generate_dimensions_gemv(); this->test_gemv(); diff --git a/cpp/oneapi/dal/table/backend/csr_kernels.cpp b/cpp/oneapi/dal/table/backend/csr_kernels.cpp index 6505cd662f1..8e5aef236b2 100644 --- a/cpp/oneapi/dal/table/backend/csr_kernels.cpp +++ b/cpp/oneapi/dal/table/backend/csr_kernels.cpp @@ -496,16 +496,16 @@ out_of_bound_type check_bounds(const array& arr, sycl::reduction(count_gt_buf, cgh, sycl::ext::oneapi::plus()); cgh.parallel_for(sycl::range<1>{ dal::detail::integral_cast(count) }, - count_lt_reduction, - count_gt_reduction, - [=](sycl::id<1> i, auto& count_lt, auto &count_gt) { - if (data[i] < min_value) { - count_lt.combine(1); - } - if (data[i] > max_value) { - count_gt.combine(1); - } - }); + count_lt_reduction, + count_gt_reduction, + [=](sycl::id<1> i, auto& count_lt, auto& count_gt) { + if (data[i] < min_value) { + count_lt.combine(1); + } + if (data[i] > max_value) { + count_gt.combine(1); + } + }); }) .wait_and_throw();