From 0faf8894085155e1eabd13e20af5ccfcf22e363c Mon Sep 17 00:00:00 2001 From: rhdong Date: Wed, 2 Oct 2024 11:56:13 -0700 Subject: [PATCH 01/15] [Feat] CAGRA filtering with BFKNN when sparsity matching threshold --- cpp/include/cuvs/neighbors/cagra.hpp | 27 ++++++-- .../neighbors/detail/cagra/cagra_search.cuh | 56 +++++++++++++++ cpp/test/neighbors/ann_cagra.cuh | 68 +++++++++++++------ .../ann_cagra/test_float_uint32_t.cu | 3 + 4 files changed, 131 insertions(+), 23 deletions(-) diff --git a/cpp/include/cuvs/neighbors/cagra.hpp b/cpp/include/cuvs/neighbors/cagra.hpp index e48050756..5b7a5ab0f 100644 --- a/cpp/include/cuvs/neighbors/cagra.hpp +++ b/cpp/include/cuvs/neighbors/cagra.hpp @@ -403,6 +403,13 @@ struct index : cuvs::neighbors::index { RAFT_EXPECTS(dataset.extent(0) == knn_graph.extent(0), "Dataset and knn_graph must have equal number of rows"); update_graph(res, knn_graph); + if constexpr (raft::is_device_mdspan_v) { + contiguous_dataset_ = + raft::make_device_matrix_view(dataset.data_handle(), dataset.extent(0), dataset.extent(1)); + } else { + contiguous_dataset_ = + raft::make_host_matrix_view(dataset.data_handle(), dataset.extent(0), dataset.extent(1)); + } raft::resource::sync_stream(res); } @@ -417,13 +424,16 @@ struct index : cuvs::neighbors::index { void update_dataset(raft::resources const& res, raft::device_matrix_view dataset) { - dataset_ = make_aligned_dataset(res, dataset, 16); + contiguous_dataset_ = dataset; + dataset_ = make_aligned_dataset(res, dataset, 16); } /** Set the dataset reference explicitly to a device matrix view with padding. */ void update_dataset(raft::resources const& res, raft::device_matrix_view dataset) { + contiguous_dataset_ = + raft::make_device_matrix_view(dataset.data_handle(), dataset.extent(0), dataset.extent(1)); dataset_ = make_aligned_dataset(res, dataset, 16); } @@ -436,7 +446,8 @@ struct index : cuvs::neighbors::index { void update_dataset(raft::resources const& res, raft::host_matrix_view dataset) { - dataset_ = make_aligned_dataset(res, dataset, 16); + contiguous_dataset_ = dataset; + dataset_ = make_aligned_dataset(res, dataset, 16); } /** @@ -447,14 +458,16 @@ struct index : cuvs::neighbors::index { auto update_dataset(raft::resources const& res, DatasetT&& dataset) -> std::enable_if_t, DatasetT>> { - dataset_ = std::make_unique(std::move(dataset)); + contiguous_dataset_ = std::monostate{}; + dataset_ = std::make_unique(std::move(dataset)); } template auto update_dataset(raft::resources const& res, std::unique_ptr&& dataset) -> std::enable_if_t, DatasetT>> { - dataset_ = std::move(dataset); + contiguous_dataset_ = std::monostate{}; + dataset_ = std::move(dataset); } /** @@ -492,11 +505,17 @@ struct index : cuvs::neighbors::index { graph_view_ = graph_.view(); } + auto contiguous_dataset() const { return contiguous_dataset_; } + private: cuvs::distance::DistanceType metric_; raft::device_matrix graph_; raft::device_matrix_view graph_view_; std::unique_ptr> dataset_; + std::variant, + raft::host_matrix_view> + contiguous_dataset_ = std::monostate{}; }; /** * @} diff --git a/cpp/src/neighbors/detail/cagra/cagra_search.cuh b/cpp/src/neighbors/detail/cagra/cagra_search.cuh index 4c15b8e14..5a1b764d0 100644 --- a/cpp/src/neighbors/detail/cagra/cagra_search.cuh +++ b/cpp/src/neighbors/detail/cagra/cagra_search.cuh @@ -26,9 +26,11 @@ #include #include #include +#include #include +#include #include // TODO: Fix these when ivf methods are moved over @@ -140,6 +142,60 @@ void search_main(raft::resources const& res, raft::device_matrix_view distances, CagraSampleFilterT sample_filter = CagraSampleFilterT()) { + if constexpr (!std::is_same_v && + (std::is_same_v || std::is_same_v)) { + auto n_queries = queries.extent(0); + auto n_dataset = index.size(); + + auto bitset_filter_view = sample_filter.bitset_view_; + auto dataset_view = index.contiguous_dataset(); + + auto sparsity = bitset_filter_view.sparsity(res); + constexpr double threshold_to_bf = 0.9; + + // TODO: Support host dataset in `brute_force::build` + if (sparsity >= threshold_to_bf && + std::holds_alternative>( + dataset_view)) { + using bitmap_view_t = cuvs::core::bitmap_view; + + auto stream = raft::resource::get_cuda_stream(res); + auto bitmap_n_elements = + bitmap_view_t::eval_n_elements(bitset_filter_view.size() * n_queries); + + rmm::device_uvector raw_bitmap(bitmap_n_elements, stream); + rmm::device_uvector raw_neighbors(neighbors.size(), stream); + + bitset_filter_view.repeat(res, n_queries, raw_bitmap.data()); + + auto brute_force_filter = bitmap_view_t(raw_bitmap.data(), n_queries, n_dataset); + + auto brute_force_neighbors = raft::make_device_matrix_view( + raw_neighbors.data(), neighbors.extent(0), neighbors.extent(1)); + auto brute_force_dataset = + std::get_if>(&dataset_view); + + if (brute_force_dataset) { + auto brute_force_idx = + cuvs::neighbors::brute_force::build(res, *brute_force_dataset, index.metric()); + cuvs::neighbors::brute_force::search( + res, + brute_force_idx, + queries, + brute_force_neighbors, + distances, + cuvs::neighbors::filtering::bitmap_filter(brute_force_filter)); + raft::linalg::unaryOp(neighbors.data_handle(), + brute_force_neighbors.data_handle(), + neighbors.size(), + raft::cast_op(), + raft::resource::get_cuda_stream(res)); + return; + } + } + } + auto stream = raft::resource::get_cuda_stream(res); const auto& graph = index.graph(); auto graph_internal = raft::make_device_matrix_view( diff --git a/cpp/test/neighbors/ann_cagra.cuh b/cpp/test/neighbors/ann_cagra.cuh index 37d42dd1d..512e7a60d 100644 --- a/cpp/test/neighbors/ann_cagra.cuh +++ b/cpp/test/neighbors/ann_cagra.cuh @@ -51,12 +51,12 @@ namespace cuvs::neighbors::cagra { namespace { struct test_cagra_sample_filter { - static constexpr unsigned offset = 300; inline _RAFT_HOST_DEVICE auto operator()( // query index const uint32_t query_ix, // the index of the current sample inside the current inverted list - const uint32_t sample_ix) const + const uint32_t sample_ix, + const uint32_t offset) const { return sample_ix >= offset; } @@ -276,6 +276,7 @@ struct AnnCagraInputs { bool include_serialized_dataset; // std::optional double min_recall; // = std::nullopt; + uint32_t filter_offset = 300; std::optional ivf_pq_search_refine_ratio = std::nullopt; std::optional compression = std::nullopt; @@ -702,21 +703,20 @@ class AnnCagraFilterTest : public ::testing::TestWithParam { { rmm::device_uvector distances_naive_dev(queries_size, stream_); rmm::device_uvector indices_naive_dev(queries_size, stream_); - auto* database_filtered_ptr = database.data() + test_cagra_sample_filter::offset * ps.dim; - cuvs::neighbors::naive_knn( - handle_, - distances_naive_dev.data(), - indices_naive_dev.data(), - search_queries.data(), - database_filtered_ptr, - ps.n_queries, - ps.n_rows - test_cagra_sample_filter::offset, - ps.dim, - ps.k, - ps.metric); + auto* database_filtered_ptr = database.data() + ps.filter_offset * ps.dim; + cuvs::neighbors::naive_knn(handle_, + distances_naive_dev.data(), + indices_naive_dev.data(), + search_queries.data(), + database_filtered_ptr, + ps.n_queries, + ps.n_rows - ps.filter_offset, + ps.dim, + ps.k, + ps.metric); raft::linalg::addScalar(indices_naive_dev.data(), indices_naive_dev.data(), - IdxT(test_cagra_sample_filter::offset), + IdxT(ps.filter_offset), queries_size, stream_); raft::update_host(distances_naive.data(), distances_naive_dev.data(), queries_size, stream_); @@ -787,7 +787,7 @@ class AnnCagraFilterTest : public ::testing::TestWithParam { auto dists_out_view = raft::make_device_matrix_view( distances_dev.data(), ps.n_queries, ps.k); auto removed_indices = - raft::make_device_vector(handle_, test_cagra_sample_filter::offset); + raft::make_device_vector(handle_, ps.filter_offset); thrust::sequence( raft::resource::get_thrust_policy(handle_), thrust::device_pointer_cast(removed_indices.data_handle()), @@ -813,8 +813,9 @@ class AnnCagraFilterTest : public ::testing::TestWithParam { bool unacceptable_node = false; for (int q = 0; q < ps.n_queries; q++) { for (int i = 0; i < ps.k; i++) { - const auto n = indices_Cagra[q * ps.k + i]; - unacceptable_node = unacceptable_node | !test_cagra_sample_filter()(q, n); + const auto n = indices_Cagra[q * ps.k + i]; + unacceptable_node = + unacceptable_node | !test_cagra_sample_filter()(q, n, ps.filter_offset); } } EXPECT_FALSE(unacceptable_node); @@ -1002,6 +1003,7 @@ inline std::vector generate_inputs() {false, true}, {false}, {0.99}, + {uint32_t(300)}, {1.0f, 2.0f, 3.0f}); inputs.insert(inputs.end(), inputs2.begin(), inputs2.end()); @@ -1028,6 +1030,34 @@ inline std::vector generate_inputs() return inputs; } -const std::vector inputs = generate_inputs(); +inline std::vector generate_bf_inputs() +{ + // Add test cases for brute force as sparsity >= 0.9. + std::vector inputs_for_brute_force; + auto inputs_original = raft::util::itertools::product( + {100}, + {10000, 100000}, + {1, 8, 17}, + {1, 16, 256}, // k + {graph_build_algo::IVF_PQ, graph_build_algo::NN_DESCENT}, + {search_algo::SINGLE_CTA, search_algo::MULTI_CTA, search_algo::MULTI_KERNEL}, + {0, 1, 10, 100}, + {0}, + {256}, + {1}, + {cuvs::distance::DistanceType::L2Expanded}, + {false}, + {true}, + {1.0}); + for (auto input : inputs_original) { + input.filter_offset = 0.90 * input.n_rows; + inputs_for_brute_force.push_back(input); + } + + return inputs_for_brute_force; +} + +const std::vector inputs = generate_inputs(); +const std::vector inputs_brute_force = generate_bf_inputs(); } // namespace cuvs::neighbors::cagra diff --git a/cpp/test/neighbors/ann_cagra/test_float_uint32_t.cu b/cpp/test/neighbors/ann_cagra/test_float_uint32_t.cu index ca188d132..a98c31510 100644 --- a/cpp/test/neighbors/ann_cagra/test_float_uint32_t.cu +++ b/cpp/test/neighbors/ann_cagra/test_float_uint32_t.cu @@ -34,5 +34,8 @@ INSTANTIATE_TEST_CASE_P(AnnCagraAddNodesTest, AnnCagraAddNodesTestF_U32, ::testing::ValuesIn(inputs)); INSTANTIATE_TEST_CASE_P(AnnCagraFilterTest, AnnCagraFilterTestF_U32, ::testing::ValuesIn(inputs)); +INSTANTIATE_TEST_CASE_P(AnnCagraFilterToBruteForceTest, + AnnCagraFilterTestF_U32, + ::testing::ValuesIn(inputs_brute_force)); } // namespace cuvs::neighbors::cagra From f14be712214b975a6e99896611677e184fc2454d Mon Sep 17 00:00:00 2001 From: rhdong Date: Thu, 3 Oct 2024 16:50:16 -0700 Subject: [PATCH 02/15] revert: update_dataset on strided matrix --- cpp/include/cuvs/neighbors/cagra.hpp | 7 +++++-- cpp/src/neighbors/detail/cagra/cagra_search.cuh | 1 + 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/cpp/include/cuvs/neighbors/cagra.hpp b/cpp/include/cuvs/neighbors/cagra.hpp index 5b7a5ab0f..83d9eec12 100644 --- a/cpp/include/cuvs/neighbors/cagra.hpp +++ b/cpp/include/cuvs/neighbors/cagra.hpp @@ -432,8 +432,11 @@ struct index : cuvs::neighbors::index { void update_dataset(raft::resources const& res, raft::device_matrix_view dataset) { - contiguous_dataset_ = - raft::make_device_matrix_view(dataset.data_handle(), dataset.extent(0), dataset.extent(1)); + contiguous_dataset_ = std::monostate{}; + if (dataset.stride(0) == dataset.extent(1) && dataset.stride(1) == 1) { + contiguous_dataset_ = + raft::make_device_matrix_view(dataset.data_handle(), dataset.extent(0), dataset.extent(1)); + } dataset_ = make_aligned_dataset(res, dataset, 16); } diff --git a/cpp/src/neighbors/detail/cagra/cagra_search.cuh b/cpp/src/neighbors/detail/cagra/cagra_search.cuh index 5a1b764d0..ba0d82831 100644 --- a/cpp/src/neighbors/detail/cagra/cagra_search.cuh +++ b/cpp/src/neighbors/detail/cagra/cagra_search.cuh @@ -177,6 +177,7 @@ void search_main(raft::resources const& res, std::get_if>(&dataset_view); if (brute_force_dataset) { + RAFT_LOG_DEBUG("CAGRA is switching to brute force with sparsity:%d", sparsity); auto brute_force_idx = cuvs::neighbors::brute_force::build(res, *brute_force_dataset, index.metric()); cuvs::neighbors::brute_force::search( From 5378827847d89d68d549278813d9982e6d264b0c Mon Sep 17 00:00:00 2001 From: rhdong Date: Tue, 29 Oct 2024 15:02:41 -0700 Subject: [PATCH 03/15] Support strided matrix on queries & respond to the review comments --- cpp/include/cuvs/neighbors/cagra.hpp | 7 +- .../neighbors/detail/cagra/cagra_search.cuh | 158 ++++++++++++------ cpp/test/neighbors/ann_cagra.cuh | 21 ++- .../neighbors/ann_cagra/test_half_uint32_t.cu | 7 + 4 files changed, 135 insertions(+), 58 deletions(-) diff --git a/cpp/include/cuvs/neighbors/cagra.hpp b/cpp/include/cuvs/neighbors/cagra.hpp index 83d9eec12..fbc9d669f 100644 --- a/cpp/include/cuvs/neighbors/cagra.hpp +++ b/cpp/include/cuvs/neighbors/cagra.hpp @@ -432,11 +432,8 @@ struct index : cuvs::neighbors::index { void update_dataset(raft::resources const& res, raft::device_matrix_view dataset) { - contiguous_dataset_ = std::monostate{}; - if (dataset.stride(0) == dataset.extent(1) && dataset.stride(1) == 1) { - contiguous_dataset_ = - raft::make_device_matrix_view(dataset.data_handle(), dataset.extent(0), dataset.extent(1)); - } + contiguous_dataset_ = + raft::make_device_matrix_view(dataset.data_handle(), dataset.extent(0), dataset.stride(0)); dataset_ = make_aligned_dataset(res, dataset, 16); } diff --git a/cpp/src/neighbors/detail/cagra/cagra_search.cuh b/cpp/src/neighbors/detail/cagra/cagra_search.cuh index d0df0613e..7c31d54aa 100644 --- a/cpp/src/neighbors/detail/cagra/cagra_search.cuh +++ b/cpp/src/neighbors/detail/cagra/cagra_search.cuh @@ -21,6 +21,7 @@ #include "sample_filter_utils.cuh" #include "search_plan.cuh" #include "search_single_cta_inst.cuh" +#include "utils.hpp" #include #include @@ -110,6 +111,109 @@ void search_main_core(raft::resources const& res, } } +/** + * @brief Performs ANN search using brute force when filter sparsity exceeds a specified threshold. + * + * This function switches to a brute force search approach to improve recall rate when the + * `sample_filter` function filters out a high proportion of samples, resulting in a sparsity level + * (proportion of unfiltered samples) exceeding the specified `threshold_to_bf`. + * + * @tparam T data element type + * @tparam IdxT type of database vector indices + * @tparam internal_IdxT during search we map IdxT to internal_IdxT, this way we do not need + * separate kernels for int/uint. + * + * @param[in] handle + * @param[in] params configure the search + * @param[in] idx ivf-pq constructed index + * @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()] + * @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset + * [n_queries, k] + * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, + * k] + * @param[in] sample_filter a device filter function that greenlights samples for a given query + * @param[in] threshold_to_bf A sparsity threshold; brute force is used when sparsity exceeds this + * threshold, in the range [0, 1] + * + * @return true If the brute force search was applied successfully. + * @return false If the brute force search was not applied. + */ +template +bool search_using_brute_force( + raft::resources const& res, + search_params& params, + const index& index, + raft::device_matrix_view& queries, + raft::device_matrix_view& neighbors, + raft::device_matrix_view& distances, + CagraSampleFilterT& sample_filter, + double threshold_to_bf = 0.9) +{ + bool is_applied = false; + auto n_queries = queries.extent(0); + auto n_dataset = index.size(); + + auto bitset_filter_view = sample_filter.bitset_view_; + auto dataset_view = index.contiguous_dataset(); + auto sparsity = bitset_filter_view.sparsity(res); + + // TODO: Support host dataset in `brute_force::build` + if (sparsity >= threshold_to_bf && + std::holds_alternative>( + dataset_view)) { + using bitmap_view_t = cuvs::core::bitmap_view; + + auto stream = raft::resource::get_cuda_stream(res); + auto bitmap_n_elements = bitmap_view_t::eval_n_elements(bitset_filter_view.size() * n_queries); + + rmm::device_uvector raw_bitmap(bitmap_n_elements, stream); + rmm::device_uvector raw_neighbors(neighbors.size(), stream); + + bitset_filter_view.repeat(res, n_queries, raw_bitmap.data()); + + auto brute_force_filter = bitmap_view_t(raw_bitmap.data(), n_queries, n_dataset); + + auto brute_force_neighbors = raft::make_device_matrix_view( + raw_neighbors.data(), neighbors.extent(0), neighbors.extent(1)); + auto brute_force_dataset = + std::get_if>(&dataset_view); + + if (brute_force_dataset) { + RAFT_LOG_DEBUG("CAGRA is switching to brute force with sparsity:%d", sparsity); + auto brute_force_idx = + cuvs::neighbors::brute_force::build(res, *brute_force_dataset, index.metric()); + + auto brute_force_queries = queries; + auto padding_queries = raft::make_device_matrix(res, 0, 0); + + // Happens when the original dataset is a strided matrix. + if (brute_force_dataset->extent(1) != queries.extent(1)) { + cuvs::neighbors::cagra::detail::copy_with_padding(res, padding_queries, queries); + brute_force_queries = raft::make_device_matrix_view( + padding_queries.data_handle(), padding_queries.extent(0), padding_queries.extent(1)); + } + cuvs::neighbors::brute_force::search( + res, + brute_force_idx, + brute_force_queries, + brute_force_neighbors, + distances, + cuvs::neighbors::filtering::bitmap_filter(brute_force_filter)); + raft::linalg::unaryOp(neighbors.data_handle(), + brute_force_neighbors.data_handle(), + neighbors.size(), + raft::cast_op(), + raft::resource::get_cuda_stream(res)); + is_applied = true; + } + } + return is_applied; +} + /** * @brief Search ANN using the constructed index. * @@ -128,6 +232,7 @@ void search_main_core(raft::resources const& res, * [n_queries, k] * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, * k] + * @param[in] sample_filter a device filter function that greenlights samples for a given query */ template && (std::is_same_v || std::is_same_v)) { - auto n_queries = queries.extent(0); - auto n_dataset = index.size(); - - auto bitset_filter_view = sample_filter.bitset_view_; - auto dataset_view = index.contiguous_dataset(); - - auto sparsity = bitset_filter_view.sparsity(res); - constexpr double threshold_to_bf = 0.9; - - // TODO: Support host dataset in `brute_force::build` - if (sparsity >= threshold_to_bf && - std::holds_alternative>( - dataset_view)) { - using bitmap_view_t = cuvs::core::bitmap_view; - - auto stream = raft::resource::get_cuda_stream(res); - auto bitmap_n_elements = - bitmap_view_t::eval_n_elements(bitset_filter_view.size() * n_queries); - - rmm::device_uvector raw_bitmap(bitmap_n_elements, stream); - rmm::device_uvector raw_neighbors(neighbors.size(), stream); - - bitset_filter_view.repeat(res, n_queries, raw_bitmap.data()); - - auto brute_force_filter = bitmap_view_t(raw_bitmap.data(), n_queries, n_dataset); - - auto brute_force_neighbors = raft::make_device_matrix_view( - raw_neighbors.data(), neighbors.extent(0), neighbors.extent(1)); - auto brute_force_dataset = - std::get_if>(&dataset_view); - - if (brute_force_dataset) { - RAFT_LOG_DEBUG("CAGRA is switching to brute force with sparsity:%d", sparsity); - auto brute_force_idx = - cuvs::neighbors::brute_force::build(res, *brute_force_dataset, index.metric()); - cuvs::neighbors::brute_force::search( - res, - brute_force_idx, - queries, - brute_force_neighbors, - distances, - cuvs::neighbors::filtering::bitmap_filter(brute_force_filter)); - raft::linalg::unaryOp(neighbors.data_handle(), - brute_force_neighbors.data_handle(), - neighbors.size(), - raft::cast_op(), - raft::resource::get_cuda_stream(res)); - return; - } - } + bool bf_search_done = + search_using_brute_force(res, params, index, queries, neighbors, distances, sample_filter); + if (bf_search_done) return; } auto stream = raft::resource::get_cuda_stream(res); diff --git a/cpp/test/neighbors/ann_cagra.cuh b/cpp/test/neighbors/ann_cagra.cuh index 512e7a60d..7c7e57e36 100644 --- a/cpp/test/neighbors/ann_cagra.cuh +++ b/cpp/test/neighbors/ann_cagra.cuh @@ -23,6 +23,9 @@ #include #include + +#include "../../../../src/neighbors/detail/cagra/utils.hpp" + #include #include #include @@ -780,6 +783,18 @@ class AnnCagraFilterTest : public ::testing::TestWithParam { if (!ps.include_serialized_dataset) { index.update_dataset(handle_, database_view); } + auto dataset_padding = raft::make_device_matrix(handle_, 0, 0); + if ((sizeof(DataT) * ps.dim % 16) != 0) { + cuvs::neighbors::cagra::detail::copy_with_padding( + handle_, dataset_padding, database_view); + auto database_view = raft::make_device_strided_matrix_view( + dataset_padding.data_handle(), + dataset_padding.extent(0), + ps.dim, + dataset_padding.extent(1)); + index.update_dataset(handle_, database_view); + } + auto search_queries_view = raft::make_device_matrix_view( search_queries.data(), ps.n_queries, ps.dim); auto indices_out_view = @@ -1036,9 +1051,9 @@ inline std::vector generate_bf_inputs() std::vector inputs_for_brute_force; auto inputs_original = raft::util::itertools::product( {100}, - {10000, 100000}, - {1, 8, 17}, - {1, 16, 256}, // k + {1000}, + {1, 7, 8, 17}, + {1, 16}, // k {graph_build_algo::IVF_PQ, graph_build_algo::NN_DESCENT}, {search_algo::SINGLE_CTA, search_algo::MULTI_CTA, search_algo::MULTI_KERNEL}, {0, 1, 10, 100}, diff --git a/cpp/test/neighbors/ann_cagra/test_half_uint32_t.cu b/cpp/test/neighbors/ann_cagra/test_half_uint32_t.cu index f03de69d2..521e12a03 100644 --- a/cpp/test/neighbors/ann_cagra/test_half_uint32_t.cu +++ b/cpp/test/neighbors/ann_cagra/test_half_uint32_t.cu @@ -23,6 +23,13 @@ namespace cuvs::neighbors::cagra { typedef AnnCagraTest AnnCagraTestF16_U32; TEST_P(AnnCagraTestF16_U32, AnnCagra) { this->testCagra(); } +typedef AnnCagraFilterTest AnnCagraFilterTestF16_U32; +TEST_P(AnnCagraFilterTestF16_U32, AnnCagra) { this->testCagra(); } + INSTANTIATE_TEST_CASE_P(AnnCagraTest, AnnCagraTestF16_U32, ::testing::ValuesIn(inputs)); +INSTANTIATE_TEST_CASE_P(AnnCagraFilterTest, AnnCagraFilterTestF16_U32, ::testing::ValuesIn(inputs)); +INSTANTIATE_TEST_CASE_P(AnnCagraFilterToBruteForceTest, + AnnCagraFilterTestF16_U32, + ::testing::ValuesIn(inputs_brute_force)); } // namespace cuvs::neighbors::cagra From 757c222df570c67ceeb84734fe30e932081a04d9 Mon Sep 17 00:00:00 2001 From: rhdong Date: Tue, 29 Oct 2024 15:36:07 -0700 Subject: [PATCH 04/15] fix a style issue --- cpp/test/neighbors/ann_cagra.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/test/neighbors/ann_cagra.cuh b/cpp/test/neighbors/ann_cagra.cuh index 7c7e57e36..c856a4d6c 100644 --- a/cpp/test/neighbors/ann_cagra.cuh +++ b/cpp/test/neighbors/ann_cagra.cuh @@ -24,7 +24,7 @@ #include #include -#include "../../../../src/neighbors/detail/cagra/utils.hpp" +#include <../../../../src/neighbors/detail/cagra/utils.hpp> #include #include From caab88b214e983959a92af4448d76fd095f8e690 Mon Sep 17 00:00:00 2001 From: rhdong Date: Wed, 30 Oct 2024 09:42:02 -0700 Subject: [PATCH 05/15] fix: don't invoke 'copy_with_padding' from `src/neighbors/detail` --- cpp/test/neighbors/ann_cagra.cuh | 39 ++++++++++++++++++++++++++++---- 1 file changed, 34 insertions(+), 5 deletions(-) diff --git a/cpp/test/neighbors/ann_cagra.cuh b/cpp/test/neighbors/ann_cagra.cuh index c856a4d6c..9d21e6728 100644 --- a/cpp/test/neighbors/ann_cagra.cuh +++ b/cpp/test/neighbors/ann_cagra.cuh @@ -23,9 +23,6 @@ #include #include - -#include <../../../../src/neighbors/detail/cagra/utils.hpp> - #include #include #include @@ -98,6 +95,39 @@ void RandomSuffle(raft::host_matrix_view index) } } +template +void copy_with_padding( + raft::resources const& res, + raft::device_matrix& dst, + raft::mdspan, raft::row_major, data_accessor> src, + rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource()) +{ + size_t padded_dim = raft::round_up_safe(src.extent(1) * sizeof(T), 16) / sizeof(T); + + if ((dst.extent(0) != src.extent(0)) || (static_cast(dst.extent(1)) != padded_dim)) { + // clear existing memory before allocating to prevent OOM errors on large datasets + if (dst.size()) { dst = raft::make_device_matrix(res, 0, 0); } + dst = + raft::make_device_mdarray(res, mr, raft::make_extents(src.extent(0), padded_dim)); + } + if (dst.extent(1) == src.extent(1)) { + raft::copy( + dst.data_handle(), src.data_handle(), src.size(), raft::resource::get_cuda_stream(res)); + } else { + // copy with padding + RAFT_CUDA_TRY(cudaMemsetAsync( + dst.data_handle(), 0, dst.size() * sizeof(T), raft::resource::get_cuda_stream(res))); + RAFT_CUDA_TRY(cudaMemcpy2DAsync(dst.data_handle(), + sizeof(T) * dst.extent(1), + src.data_handle(), + sizeof(T) * src.extent(1), + sizeof(T) * src.extent(1), + src.extent(0), + cudaMemcpyDefault, + raft::resource::get_cuda_stream(res))); + } +} + template testing::AssertionResult CheckOrder(raft::host_matrix_view index_test, raft::host_matrix_view dataset) @@ -785,8 +815,7 @@ class AnnCagraFilterTest : public ::testing::TestWithParam { auto dataset_padding = raft::make_device_matrix(handle_, 0, 0); if ((sizeof(DataT) * ps.dim % 16) != 0) { - cuvs::neighbors::cagra::detail::copy_with_padding( - handle_, dataset_padding, database_view); + copy_with_padding(handle_, dataset_padding, database_view); auto database_view = raft::make_device_strided_matrix_view( dataset_padding.data_handle(), dataset_padding.extent(0), From f4c19224d2db366bd377f8249ad72c58c43b5800 Mon Sep 17 00:00:00 2001 From: rhdong Date: Thu, 31 Oct 2024 14:45:34 -0700 Subject: [PATCH 06/15] optimize by review comments --- cpp/include/cuvs/neighbors/cagra.hpp | 27 +++---------------- .../neighbors/detail/cagra/cagra_search.cuh | 26 +++++++++++++----- 2 files changed, 23 insertions(+), 30 deletions(-) diff --git a/cpp/include/cuvs/neighbors/cagra.hpp b/cpp/include/cuvs/neighbors/cagra.hpp index fbc9d669f..e48050756 100644 --- a/cpp/include/cuvs/neighbors/cagra.hpp +++ b/cpp/include/cuvs/neighbors/cagra.hpp @@ -403,13 +403,6 @@ struct index : cuvs::neighbors::index { RAFT_EXPECTS(dataset.extent(0) == knn_graph.extent(0), "Dataset and knn_graph must have equal number of rows"); update_graph(res, knn_graph); - if constexpr (raft::is_device_mdspan_v) { - contiguous_dataset_ = - raft::make_device_matrix_view(dataset.data_handle(), dataset.extent(0), dataset.extent(1)); - } else { - contiguous_dataset_ = - raft::make_host_matrix_view(dataset.data_handle(), dataset.extent(0), dataset.extent(1)); - } raft::resource::sync_stream(res); } @@ -424,16 +417,13 @@ struct index : cuvs::neighbors::index { void update_dataset(raft::resources const& res, raft::device_matrix_view dataset) { - contiguous_dataset_ = dataset; - dataset_ = make_aligned_dataset(res, dataset, 16); + dataset_ = make_aligned_dataset(res, dataset, 16); } /** Set the dataset reference explicitly to a device matrix view with padding. */ void update_dataset(raft::resources const& res, raft::device_matrix_view dataset) { - contiguous_dataset_ = - raft::make_device_matrix_view(dataset.data_handle(), dataset.extent(0), dataset.stride(0)); dataset_ = make_aligned_dataset(res, dataset, 16); } @@ -446,8 +436,7 @@ struct index : cuvs::neighbors::index { void update_dataset(raft::resources const& res, raft::host_matrix_view dataset) { - contiguous_dataset_ = dataset; - dataset_ = make_aligned_dataset(res, dataset, 16); + dataset_ = make_aligned_dataset(res, dataset, 16); } /** @@ -458,16 +447,14 @@ struct index : cuvs::neighbors::index { auto update_dataset(raft::resources const& res, DatasetT&& dataset) -> std::enable_if_t, DatasetT>> { - contiguous_dataset_ = std::monostate{}; - dataset_ = std::make_unique(std::move(dataset)); + dataset_ = std::make_unique(std::move(dataset)); } template auto update_dataset(raft::resources const& res, std::unique_ptr&& dataset) -> std::enable_if_t, DatasetT>> { - contiguous_dataset_ = std::monostate{}; - dataset_ = std::move(dataset); + dataset_ = std::move(dataset); } /** @@ -505,17 +492,11 @@ struct index : cuvs::neighbors::index { graph_view_ = graph_.view(); } - auto contiguous_dataset() const { return contiguous_dataset_; } - private: cuvs::distance::DistanceType metric_; raft::device_matrix graph_; raft::device_matrix_view graph_view_; std::unique_ptr> dataset_; - std::variant, - raft::host_matrix_view> - contiguous_dataset_ = std::monostate{}; }; /** * @} diff --git a/cpp/src/neighbors/detail/cagra/cagra_search.cuh b/cpp/src/neighbors/detail/cagra/cagra_search.cuh index 7c31d54aa..2917377f6 100644 --- a/cpp/src/neighbors/detail/cagra/cagra_search.cuh +++ b/cpp/src/neighbors/detail/cagra/cagra_search.cuh @@ -145,11 +145,11 @@ template bool search_using_brute_force( raft::resources const& res, - search_params& params, + const search_params& params, const index& index, - raft::device_matrix_view& queries, - raft::device_matrix_view& neighbors, - raft::device_matrix_view& distances, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances, CagraSampleFilterT& sample_filter, double threshold_to_bf = 0.9) { @@ -158,8 +158,19 @@ bool search_using_brute_force( auto n_dataset = index.size(); auto bitset_filter_view = sample_filter.bitset_view_; - auto dataset_view = index.contiguous_dataset(); - auto sparsity = bitset_filter_view.sparsity(res); + // auto dataset_view = index.contiguous_dataset(); + auto dataset_view = [&index]() + -> std::variant, std::monostate> { + using ds_idx_type = decltype(index.data().n_rows()); + if (auto* strided_dset = dynamic_cast*>(&index.data()); + strided_dset != nullptr) { + return raft::make_device_matrix_view( + strided_dset->view().data_handle(), strided_dset->n_rows(), strided_dset->stride()); + } else { + return std::monostate{}; + } + }(); + auto sparsity = bitset_filter_view.sparsity(res); // TODO: Support host dataset in `brute_force::build` if (sparsity >= threshold_to_bf && @@ -192,7 +203,8 @@ bool search_using_brute_force( // Happens when the original dataset is a strided matrix. if (brute_force_dataset->extent(1) != queries.extent(1)) { - cuvs::neighbors::cagra::detail::copy_with_padding(res, padding_queries, queries); + cuvs::neighbors::cagra::detail::copy_with_padding( + res, padding_queries, queries, raft::resource::get_workspace_resource(res)); brute_force_queries = raft::make_device_matrix_view( padding_queries.data_handle(), padding_queries.extent(0), padding_queries.extent(1)); } From a73ba1f813b60e671f7ea9e5640c12d278160cf2 Mon Sep 17 00:00:00 2001 From: rhdong Date: Wed, 13 Nov 2024 16:15:03 -0800 Subject: [PATCH 07/15] move calling down to branch & replace copy_with_padding --- .../neighbors/detail/cagra/cagra_search.cuh | 110 +++++++++--------- 1 file changed, 53 insertions(+), 57 deletions(-) diff --git a/cpp/src/neighbors/detail/cagra/cagra_search.cuh b/cpp/src/neighbors/detail/cagra/cagra_search.cuh index 7194401e0..8081eacc4 100644 --- a/cpp/src/neighbors/detail/cagra/cagra_search.cuh +++ b/cpp/src/neighbors/detail/cagra/cagra_search.cuh @@ -125,7 +125,8 @@ void search_main_core(raft::resources const& res, * * @param[in] handle * @param[in] params configure the search - * @param[in] idx ivf-pq constructed index + * @param[in] strided_dataset CAGRA strided dataset + * @param[in] metric distance type * @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()] * @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset * [n_queries, k] @@ -146,7 +147,8 @@ template & index, + const strided_dataset& strided_dataset, + cuvs::distance::DistanceType metric, raft::device_matrix_view queries, raft::device_matrix_view neighbors, raft::device_matrix_view distances, @@ -155,27 +157,14 @@ bool search_using_brute_force( { bool is_applied = false; auto n_queries = queries.extent(0); - auto n_dataset = index.size(); + auto n_dataset = strided_dataset.n_rows(); auto bitset_filter_view = sample_filter.bitset_view_; - // auto dataset_view = index.contiguous_dataset(); - auto dataset_view = [&index]() - -> std::variant, std::monostate> { - using ds_idx_type = decltype(index.data().n_rows()); - if (auto* strided_dset = dynamic_cast*>(&index.data()); - strided_dset != nullptr) { - return raft::make_device_matrix_view( - strided_dset->view().data_handle(), strided_dset->n_rows(), strided_dset->stride()); - } else { - return std::monostate{}; - } - }(); - auto sparsity = bitset_filter_view.sparsity(res); + auto sparsity = bitset_filter_view.sparsity(res); // TODO: Support host dataset in `brute_force::build` - if (sparsity >= threshold_to_bf && - std::holds_alternative>( - dataset_view)) { + if (sparsity >= threshold_to_bf) { + RAFT_LOG_DEBUG("CAGRA is switching to brute force with sparsity:%d", sparsity); using bitmap_view_t = cuvs::core::bitmap_view; auto stream = raft::resource::get_cuda_stream(res); @@ -190,38 +179,45 @@ bool search_using_brute_force( auto brute_force_neighbors = raft::make_device_matrix_view( raw_neighbors.data(), neighbors.extent(0), neighbors.extent(1)); - auto brute_force_dataset = - std::get_if>(&dataset_view); - - if (brute_force_dataset) { - RAFT_LOG_DEBUG("CAGRA is switching to brute force with sparsity:%d", sparsity); - auto brute_force_idx = - cuvs::neighbors::brute_force::build(res, *brute_force_dataset, index.metric()); - - auto brute_force_queries = queries; - auto padding_queries = raft::make_device_matrix(res, 0, 0); - - // Happens when the original dataset is a strided matrix. - if (brute_force_dataset->extent(1) != queries.extent(1)) { - cuvs::neighbors::cagra::detail::copy_with_padding( - res, padding_queries, queries, raft::resource::get_workspace_resource(res)); - brute_force_queries = raft::make_device_matrix_view( - padding_queries.data_handle(), padding_queries.extent(0), padding_queries.extent(1)); - } - cuvs::neighbors::brute_force::search( + auto brute_force_dataset = raft::make_device_matrix_view( + strided_dataset.view().data_handle(), strided_dataset.n_rows(), strided_dataset.stride()); + + auto brute_force_idx = cuvs::neighbors::brute_force::build(res, brute_force_dataset, metric); + + auto brute_force_queries = queries; + auto padding_queries = raft::make_device_matrix(res, 0, 0); + + // Happens when the original dataset is a strided matrix. + if (brute_force_dataset.extent(1) != queries.extent(1)) { + padding_queries = raft::make_device_mdarray( res, - brute_force_idx, - brute_force_queries, - brute_force_neighbors, - distances, - cuvs::neighbors::filtering::bitmap_filter(brute_force_filter)); - raft::linalg::unaryOp(neighbors.data_handle(), - brute_force_neighbors.data_handle(), - neighbors.size(), - raft::cast_op(), - raft::resource::get_cuda_stream(res)); - is_applied = true; + raft::resource::get_workspace_resource(res), + raft::make_extents(n_queries, brute_force_dataset.extent(1))); + // Copy the queries and fill the padded elements with zeros + raft::linalg::map_offset( + res, + padding_queries.view(), + [queries, stride = brute_force_dataset.extent(1)] __device__(int64_t i) { + auto row_ix = i / stride; + auto el_ix = i % stride; + return el_ix < queries.extent(1) ? queries(row_ix, el_ix) : T{0}; + }); + brute_force_queries = raft::make_device_matrix_view( + padding_queries.data_handle(), padding_queries.extent(0), padding_queries.extent(1)); } + cuvs::neighbors::brute_force::search( + res, + brute_force_idx, + brute_force_queries, + brute_force_neighbors, + distances, + cuvs::neighbors::filtering::bitmap_filter(brute_force_filter)); + raft::linalg::unaryOp(neighbors.data_handle(), + brute_force_neighbors.data_handle(), + neighbors.size(), + raft::cast_op(), + raft::resource::get_cuda_stream(res)); + is_applied = true; } return is_applied; } @@ -259,14 +255,6 @@ void search_main(raft::resources const& res, raft::device_matrix_view distances, CagraSampleFilterT sample_filter = CagraSampleFilterT()) { - if constexpr (!std::is_same_v && - (std::is_same_v || std::is_same_v)) { - bool bf_search_done = - search_using_brute_force(res, params, index, queries, neighbors, distances, sample_filter); - if (bf_search_done) return; - } - auto stream = raft::resource::get_cuda_stream(res); const auto& graph = index.graph(); auto graph_internal = raft::make_device_matrix_view( @@ -277,6 +265,14 @@ void search_main(raft::resources const& res, // Dispatch search parameters based on the dataset kind. if (auto* strided_dset = dynamic_cast*>(&index.data()); strided_dset != nullptr) { + if constexpr (!std::is_same_v && + (std::is_same_v || std::is_same_v)) { + bool bf_search_done = search_using_brute_force( + res, params, *strided_dset, index.metric(), queries, neighbors, distances, sample_filter); + if (bf_search_done) return; + } + // Search using a plain (strided) row-major dataset auto desc = dataset_descriptor_init_with_cache( res, params, *strided_dset, index.metric()); From 00361270dfcc2877ecf1650a83c4533efecf0cb9 Mon Sep 17 00:00:00 2001 From: rhdong Date: Thu, 14 Nov 2024 19:31:10 -0800 Subject: [PATCH 08/15] fix: RAFT_LOG_DEBUG %f for double & other optimization --- .../neighbors/detail/cagra/cagra_search.cuh | 114 +++++++++--------- 1 file changed, 56 insertions(+), 58 deletions(-) diff --git a/cpp/src/neighbors/detail/cagra/cagra_search.cuh b/cpp/src/neighbors/detail/cagra/cagra_search.cuh index 8081eacc4..1f2ef295c 100644 --- a/cpp/src/neighbors/detail/cagra/cagra_search.cuh +++ b/cpp/src/neighbors/detail/cagra/cagra_search.cuh @@ -155,71 +155,69 @@ bool search_using_brute_force( CagraSampleFilterT& sample_filter, double threshold_to_bf = 0.9) { - bool is_applied = false; - auto n_queries = queries.extent(0); - auto n_dataset = strided_dataset.n_rows(); + auto n_queries = queries.extent(0); + auto n_dataset = strided_dataset.n_rows(); auto bitset_filter_view = sample_filter.bitset_view_; auto sparsity = bitset_filter_view.sparsity(res); + if (sparsity < threshold_to_bf) { return false; } + // TODO: Support host dataset in `brute_force::build` - if (sparsity >= threshold_to_bf) { - RAFT_LOG_DEBUG("CAGRA is switching to brute force with sparsity:%d", sparsity); - using bitmap_view_t = cuvs::core::bitmap_view; - - auto stream = raft::resource::get_cuda_stream(res); - auto bitmap_n_elements = bitmap_view_t::eval_n_elements(bitset_filter_view.size() * n_queries); - - rmm::device_uvector raw_bitmap(bitmap_n_elements, stream); - rmm::device_uvector raw_neighbors(neighbors.size(), stream); - - bitset_filter_view.repeat(res, n_queries, raw_bitmap.data()); - - auto brute_force_filter = bitmap_view_t(raw_bitmap.data(), n_queries, n_dataset); - - auto brute_force_neighbors = raft::make_device_matrix_view( - raw_neighbors.data(), neighbors.extent(0), neighbors.extent(1)); - auto brute_force_dataset = raft::make_device_matrix_view( - strided_dataset.view().data_handle(), strided_dataset.n_rows(), strided_dataset.stride()); - - auto brute_force_idx = cuvs::neighbors::brute_force::build(res, brute_force_dataset, metric); - - auto brute_force_queries = queries; - auto padding_queries = raft::make_device_matrix(res, 0, 0); - - // Happens when the original dataset is a strided matrix. - if (brute_force_dataset.extent(1) != queries.extent(1)) { - padding_queries = raft::make_device_mdarray( - res, - raft::resource::get_workspace_resource(res), - raft::make_extents(n_queries, brute_force_dataset.extent(1))); - // Copy the queries and fill the padded elements with zeros - raft::linalg::map_offset( - res, - padding_queries.view(), - [queries, stride = brute_force_dataset.extent(1)] __device__(int64_t i) { - auto row_ix = i / stride; - auto el_ix = i % stride; - return el_ix < queries.extent(1) ? queries(row_ix, el_ix) : T{0}; - }); - brute_force_queries = raft::make_device_matrix_view( - padding_queries.data_handle(), padding_queries.extent(0), padding_queries.extent(1)); - } - cuvs::neighbors::brute_force::search( + RAFT_LOG_DEBUG("CAGRA is switching to brute force with sparsity:%f", sparsity); + using bitmap_view_t = cuvs::core::bitmap_view; + + auto stream = raft::resource::get_cuda_stream(res); + auto bitmap_n_elements = bitmap_view_t::eval_n_elements(bitset_filter_view.size() * n_queries); + + rmm::device_uvector raw_bitmap(bitmap_n_elements, stream); + rmm::device_uvector raw_neighbors(neighbors.size(), stream); + + bitset_filter_view.repeat(res, n_queries, raw_bitmap.data()); + + auto brute_force_filter = bitmap_view_t(raw_bitmap.data(), n_queries, n_dataset); + + auto brute_force_neighbors = raft::make_device_matrix_view( + raw_neighbors.data(), neighbors.extent(0), neighbors.extent(1)); + auto brute_force_dataset = raft::make_device_matrix_view( + strided_dataset.view().data_handle(), strided_dataset.n_rows(), strided_dataset.stride()); + + auto brute_force_idx = cuvs::neighbors::brute_force::build(res, brute_force_dataset, metric); + + auto brute_force_queries = queries; + auto padding_queries = raft::make_device_matrix(res, 0, 0); + + // Happens when the original dataset is a strided matrix. + if (brute_force_dataset.extent(1) != queries.extent(1)) { + padding_queries = raft::make_device_mdarray( + res, + raft::resource::get_workspace_resource(res), + raft::make_extents(n_queries, brute_force_dataset.extent(1))); + // Copy the queries and fill the padded elements with zeros + raft::linalg::map_offset( res, - brute_force_idx, - brute_force_queries, - brute_force_neighbors, - distances, - cuvs::neighbors::filtering::bitmap_filter(brute_force_filter)); - raft::linalg::unaryOp(neighbors.data_handle(), - brute_force_neighbors.data_handle(), - neighbors.size(), - raft::cast_op(), - raft::resource::get_cuda_stream(res)); - is_applied = true; + padding_queries.view(), + [queries, stride = brute_force_dataset.extent(1)] __device__(int64_t i) { + auto row_ix = i / stride; + auto el_ix = i % stride; + return el_ix < queries.extent(1) ? queries(row_ix, el_ix) : T{0}; + }); + brute_force_queries = raft::make_device_matrix_view( + padding_queries.data_handle(), padding_queries.extent(0), padding_queries.extent(1)); } - return is_applied; + cuvs::neighbors::brute_force::search( + res, + brute_force_idx, + brute_force_queries, + brute_force_neighbors, + distances, + cuvs::neighbors::filtering::bitmap_filter(brute_force_filter)); + raft::linalg::unaryOp(neighbors.data_handle(), + brute_force_neighbors.data_handle(), + neighbors.size(), + raft::cast_op(), + raft::resource::get_cuda_stream(res)); + return true; } /** From b5dcc02c06dc4fb5b18ed0ac73986630fc9a6391 Mon Sep 17 00:00:00 2001 From: rhdong Date: Mon, 18 Nov 2024 06:42:08 -0800 Subject: [PATCH 09/15] benchmark: support pre-filter on CAGRA --- cpp/bench/ann/src/cuvs/cuvs_cagra_wrapper.h | 67 +++++++++++++++++++-- 1 file changed, 63 insertions(+), 4 deletions(-) diff --git a/cpp/bench/ann/src/cuvs/cuvs_cagra_wrapper.h b/cpp/bench/ann/src/cuvs/cuvs_cagra_wrapper.h index b2ba35eee..271d0bcbd 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_cagra_wrapper.h +++ b/cpp/bench/ann/src/cuvs/cuvs_cagra_wrapper.h @@ -43,6 +43,7 @@ #include #include #include +#include #include #include #include @@ -52,10 +53,13 @@ namespace cuvs::bench { enum class AllocatorType { kHostPinned, kHostHugePage, kDevice }; enum class CagraBuildAlgo { kAuto, kIvfPq, kNnDescent }; +constexpr double sparsity = 0.0f; + template class cuvs_cagra : public algo, public algo_gpu { public: using search_param_base = typename algo::search_param; + // TODO: Move to arguments struct search_param : public search_param_base { cuvs::neighbors::cagra::search_params p; @@ -91,6 +95,40 @@ class cuvs_cagra : public algo, public algo_gpu { } }; + int64_t create_sparse_bitset(int64_t total, float sparsity, std::vector& bitset) const + { + int64_t num_ones = static_cast((total * 1.0f) * (1.0f - sparsity)); + int64_t res = num_ones; + + for (auto& item : bitset) { + item = static_cast(0); + } + + if (sparsity == 0.0) { + for (auto& item : bitset) { + item = static_cast(0xffffffff); + } + return total; + } + + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution dis(0, total - 1); + + while (num_ones > 0) { + int64_t index = dis(gen); + + uint32_t& element = bitset[index / (8 * sizeof(uint32_t))]; + int64_t bit_position = index % (8 * sizeof(uint32_t)); + + if (((element >> bit_position) & 1) == 0) { + element |= (static_cast(1) << bit_position); + num_ones--; + } + } + return res; + } + cuvs_cagra(Metric metric, int dim, const build_param& param, int concurrent_searches = 1) : algo(metric, dim), index_params_(param), @@ -102,8 +140,9 @@ class cuvs_cagra : public algo, public algo_gpu { std::move(raft::make_device_matrix(handle_, 0, 0)))), input_dataset_v_( std::make_shared>( - nullptr, 0, 0)) - + nullptr, 0, 0)), + bitset_filter_(std::make_shared>( + std::move(cuvs::core::bitset(handle_, 0, false)))) { index_params_.cagra_params.metric = parse_metric_type(metric); index_params_.ivf_pq_build_params->metric = parse_metric_type(metric); @@ -171,6 +210,9 @@ class cuvs_cagra : public algo, public algo_gpu { std::shared_ptr> dataset_; std::shared_ptr> input_dataset_v_; + // std::shared_ptr> bitset_filter_; + std::shared_ptr> bitset_filter_; + inline rmm::device_async_resource_ref get_mr(AllocatorType mem_type) { switch (mem_type) { @@ -256,6 +298,16 @@ void cuvs_cagra::set_search_param(const search_param_base& param) need_dataset_update_ = false; } + + { // create bitset filter in advance. + auto stream_ = raft::resource::get_cuda_stream(handle_); + size_t filter_n_elements = size_t((input_dataset_v_->extent(0) + 31) / 32); + std::cout << "input_dataset_v_->extent(0): " << input_dataset_v_->extent(0) << std::endl; + bitset_filter_->resize(handle_, input_dataset_v_->extent(0), false); + std::vector bitset_cpu(filter_n_elements); + create_sparse_bitset(input_dataset_v_->extent(0), sparsity, bitset_cpu); + raft::copy(bitset_filter_->data(), bitset_cpu.data(), filter_n_elements, stream_); + } } template @@ -328,8 +380,15 @@ void cuvs_cagra::search_base(const T* queries, raft::make_device_matrix_view(neighbors_idx_t, batch_size, k); auto distances_view = raft::make_device_matrix_view(distances, batch_size, k); - cuvs::neighbors::cagra::search( - handle_, search_params_, *index_, queries_view, neighbors_view, distances_view); + if constexpr ((std::is_same_v || std::is_same_v)&&sparsity >= 0.0f) { + auto filter = cuvs::neighbors::filtering::bitset_filter(bitset_filter_->view()); + cuvs::neighbors::cagra::search( + handle_, search_params_, *index_, queries_view, neighbors_view, distances_view, filter); + + } else { + cuvs::neighbors::cagra::search( + handle_, search_params_, *index_, queries_view, neighbors_view, distances_view); + } if constexpr (sizeof(IdxT) != sizeof(algo_base::index_type)) { if (raft::get_device_for_address(neighbors) < 0 && From 5c9c5de8b49b79a18579dc6af67777a9cf887b17 Mon Sep 17 00:00:00 2001 From: rhdong Date: Mon, 18 Nov 2024 13:46:24 -0800 Subject: [PATCH 10/15] adjust the kernel selection condition to be 0.9f --- cpp/bench/ann/src/cuvs/cuvs_cagra_wrapper.h | 1 - cpp/src/neighbors/detail/knn_brute_force.cuh | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/cpp/bench/ann/src/cuvs/cuvs_cagra_wrapper.h b/cpp/bench/ann/src/cuvs/cuvs_cagra_wrapper.h index 271d0bcbd..30ca6b722 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_cagra_wrapper.h +++ b/cpp/bench/ann/src/cuvs/cuvs_cagra_wrapper.h @@ -302,7 +302,6 @@ void cuvs_cagra::set_search_param(const search_param_base& param) { // create bitset filter in advance. auto stream_ = raft::resource::get_cuda_stream(handle_); size_t filter_n_elements = size_t((input_dataset_v_->extent(0) + 31) / 32); - std::cout << "input_dataset_v_->extent(0): " << input_dataset_v_->extent(0) << std::endl; bitset_filter_->resize(handle_, input_dataset_v_->extent(0), false); std::vector bitset_cpu(filter_n_elements); create_sparse_bitset(input_dataset_v_->extent(0), sparsity, bitset_cpu); diff --git a/cpp/src/neighbors/detail/knn_brute_force.cuh b/cpp/src/neighbors/detail/knn_brute_force.cuh index e5eeecbc9..24fe0651a 100644 --- a/cpp/src/neighbors/detail/knn_brute_force.cuh +++ b/cpp/src/neighbors/detail/knn_brute_force.cuh @@ -620,9 +620,9 @@ void brute_force_search_filtered( raft::copy(&nnz_h, nnz.data(), 1, stream); raft::resource::sync_stream(res, stream); - float sparsity = (1.0f * nnz_h / (1.0f * n_queries * n_dataset)); + float sparsity = (1.0f - (1.0f * nnz_h) / (1.0f * n_queries * n_dataset)); - if (sparsity > 0.01f) { + if (sparsity < 0.9f) { raft::resources stream_pool_handle(res); raft::resource::set_cuda_stream(stream_pool_handle, stream); auto idx_norm = idx.has_norms() ? const_cast(idx.norms().data_handle()) : nullptr; From d190b9db096f24b04e78a8da0bf14754d2bacfe9 Mon Sep 17 00:00:00 2001 From: rhdong Date: Mon, 18 Nov 2024 16:10:35 -0800 Subject: [PATCH 11/15] expose the threshold-to-bf to callers & test cases --- cpp/include/cuvs/neighbors/cagra.hpp | 20 ++++++++++++---- cpp/src/neighbors/cagra.cuh | 23 ++++++++++++++----- cpp/src/neighbors/cagra_search_float.cu | 23 ++++++++++--------- cpp/src/neighbors/cagra_search_half.cu | 23 ++++++++++--------- cpp/src/neighbors/cagra_search_int8.cu | 23 ++++++++++--------- cpp/src/neighbors/cagra_search_uint8.cu | 23 ++++++++++--------- .../neighbors/detail/cagra/cagra_search.cuh | 17 +++++++++++--- cpp/test/neighbors/ann_cagra.cuh | 11 ++++++--- 8 files changed, 103 insertions(+), 60 deletions(-) diff --git a/cpp/include/cuvs/neighbors/cagra.hpp b/cpp/include/cuvs/neighbors/cagra.hpp index e48050756..8ff9664f6 100644 --- a/cpp/include/cuvs/neighbors/cagra.hpp +++ b/cpp/include/cuvs/neighbors/cagra.hpp @@ -1057,6 +1057,8 @@ void extend( * k] * @param[in] sample_filter an optional device filter function object that greenlights samples * for a given query. (none_sample_filter for no filtering) + * @param[in] threshold_to_bf A sparsity threshold; brute force is used when sparsity exceeds this + * threshold, in the range [0, 1] */ void search(raft::resources const& res, @@ -1066,7 +1068,8 @@ void search(raft::resources const& res, raft::device_matrix_view neighbors, raft::device_matrix_view distances, const cuvs::neighbors::filtering::base_filter& sample_filter = - cuvs::neighbors::filtering::none_sample_filter{}); + cuvs::neighbors::filtering::none_sample_filter{}, + double threshold_to_bf = 0.9f); /** * @brief Search ANN using the constructed index. @@ -1083,6 +1086,8 @@ void search(raft::resources const& res, * k] * @param[in] sample_filter an optional device filter function object that greenlights samples * for a given query. (none_sample_filter for no filtering) + * @param[in] threshold_to_bf A sparsity threshold; brute force is used when sparsity exceeds this + * threshold, in the range [0, 1] */ void search(raft::resources const& res, cuvs::neighbors::cagra::search_params const& params, @@ -1091,7 +1096,8 @@ void search(raft::resources const& res, raft::device_matrix_view neighbors, raft::device_matrix_view distances, const cuvs::neighbors::filtering::base_filter& sample_filter = - cuvs::neighbors::filtering::none_sample_filter{}); + cuvs::neighbors::filtering::none_sample_filter{}, + double threshold_to_bf = 0.9f); /** * @brief Search ANN using the constructed index. @@ -1108,6 +1114,8 @@ void search(raft::resources const& res, * k] * @param[in] sample_filter an optional device filter function object that greenlights samples * for a given query. (none_sample_filter for no filtering) + * @param[in] threshold_to_bf A sparsity threshold; brute force is used when sparsity exceeds this + * threshold, in the range [0, 1] */ void search(raft::resources const& res, cuvs::neighbors::cagra::search_params const& params, @@ -1116,7 +1124,8 @@ void search(raft::resources const& res, raft::device_matrix_view neighbors, raft::device_matrix_view distances, const cuvs::neighbors::filtering::base_filter& sample_filter = - cuvs::neighbors::filtering::none_sample_filter{}); + cuvs::neighbors::filtering::none_sample_filter{}, + double threshold_to_bf = 0.9f); /** * @brief Search ANN using the constructed index. @@ -1133,6 +1142,8 @@ void search(raft::resources const& res, * k] * @param[in] sample_filter an optional device filter function object that greenlights samples * for a given query. (none_sample_filter for no filtering) + * @param[in] threshold_to_bf A sparsity threshold; brute force is used when sparsity exceeds this + * threshold, in the range [0, 1] */ void search(raft::resources const& res, cuvs::neighbors::cagra::search_params const& params, @@ -1141,7 +1152,8 @@ void search(raft::resources const& res, raft::device_matrix_view neighbors, raft::device_matrix_view distances, const cuvs::neighbors::filtering::base_filter& sample_filter = - cuvs::neighbors::filtering::none_sample_filter{}); + cuvs::neighbors::filtering::none_sample_filter{}, + double threshold_to_bf = 0.9f); /** * @} diff --git a/cpp/src/neighbors/cagra.cuh b/cpp/src/neighbors/cagra.cuh index dacfd6f63..f0bad1431 100644 --- a/cpp/src/neighbors/cagra.cuh +++ b/cpp/src/neighbors/cagra.cuh @@ -293,6 +293,9 @@ index build( * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, * k] * @param[in] sample_filter a device filter function that greenlights samples for a given query + * @param[in] threshold_to_bf A sparsity threshold; brute force is used when sparsity exceeds this + * threshold, in the range [0, 1] + * */ template void search_with_filtering(raft::resources const& res, @@ -301,7 +304,8 @@ void search_with_filtering(raft::resources const& res, raft::device_matrix_view queries, raft::device_matrix_view neighbors, raft::device_matrix_view distances, - CagraSampleFilterT sample_filter = CagraSampleFilterT()) + CagraSampleFilterT sample_filter = CagraSampleFilterT(), + double threshold_to_bf = 0.9) { RAFT_EXPECTS( queries.extent(0) == neighbors.extent(0) && queries.extent(0) == distances.extent(0), @@ -322,8 +326,14 @@ void search_with_filtering(raft::resources const& res, auto distances_internal = raft::make_device_matrix_view( distances.data_handle(), distances.extent(0), distances.extent(1)); - return cagra::detail::search_main( - res, params, idx, queries_internal, neighbors_internal, distances_internal, sample_filter); + return cagra::detail::search_main(res, + params, + idx, + queries_internal, + neighbors_internal, + distances_internal, + sample_filter, + threshold_to_bf); } template @@ -333,14 +343,15 @@ void search(raft::resources const& res, raft::device_matrix_view queries, raft::device_matrix_view neighbors, raft::device_matrix_view distances, - const cuvs::neighbors::filtering::base_filter& sample_filter_ref) + const cuvs::neighbors::filtering::base_filter& sample_filter_ref, + double threshold_to_bf = 0.9) { try { using none_filter_type = cuvs::neighbors::filtering::none_sample_filter; auto& sample_filter = dynamic_cast(sample_filter_ref); auto sample_filter_copy = sample_filter; return search_with_filtering( - res, params, idx, queries, neighbors, distances, sample_filter_copy); + res, params, idx, queries, neighbors, distances, sample_filter_copy, threshold_to_bf); return; } catch (const std::bad_cast&) { } @@ -351,7 +362,7 @@ void search(raft::resources const& res, sample_filter_ref); auto sample_filter_copy = sample_filter; return search_with_filtering( - res, params, idx, queries, neighbors, distances, sample_filter_copy); + res, params, idx, queries, neighbors, distances, sample_filter_copy, threshold_to_bf); } catch (const std::bad_cast&) { RAFT_FAIL("Unsupported sample filter type"); } diff --git a/cpp/src/neighbors/cagra_search_float.cu b/cpp/src/neighbors/cagra_search_float.cu index 3aca84f74..d1d790121 100644 --- a/cpp/src/neighbors/cagra_search_float.cu +++ b/cpp/src/neighbors/cagra_search_float.cu @@ -19,17 +19,18 @@ namespace cuvs::neighbors::cagra { -#define CUVS_INST_CAGRA_SEARCH(T, IdxT) \ - void search(raft::resources const& handle, \ - cuvs::neighbors::cagra::search_params const& params, \ - const cuvs::neighbors::cagra::index& index, \ - raft::device_matrix_view queries, \ - raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances, \ - const cuvs::neighbors::filtering::base_filter& sample_filter) \ - { \ - cuvs::neighbors::cagra::search( \ - handle, params, index, queries, neighbors, distances, sample_filter); \ +#define CUVS_INST_CAGRA_SEARCH(T, IdxT) \ + void search(raft::resources const& handle, \ + cuvs::neighbors::cagra::search_params const& params, \ + const cuvs::neighbors::cagra::index& index, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances, \ + const cuvs::neighbors::filtering::base_filter& sample_filter, \ + double threshold_to_bf) \ + { \ + cuvs::neighbors::cagra::search( \ + handle, params, index, queries, neighbors, distances, sample_filter, threshold_to_bf); \ } CUVS_INST_CAGRA_SEARCH(float, uint32_t); diff --git a/cpp/src/neighbors/cagra_search_half.cu b/cpp/src/neighbors/cagra_search_half.cu index 02be12731..5112e25dd 100644 --- a/cpp/src/neighbors/cagra_search_half.cu +++ b/cpp/src/neighbors/cagra_search_half.cu @@ -19,17 +19,18 @@ namespace cuvs::neighbors::cagra { -#define CUVS_INST_CAGRA_SEARCH(T, IdxT) \ - void search(raft::resources const& handle, \ - cuvs::neighbors::cagra::search_params const& params, \ - const cuvs::neighbors::cagra::index& index, \ - raft::device_matrix_view queries, \ - raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances, \ - const cuvs::neighbors::filtering::base_filter& sample_filter) \ - { \ - cuvs::neighbors::cagra::search( \ - handle, params, index, queries, neighbors, distances, sample_filter); \ +#define CUVS_INST_CAGRA_SEARCH(T, IdxT) \ + void search(raft::resources const& handle, \ + cuvs::neighbors::cagra::search_params const& params, \ + const cuvs::neighbors::cagra::index& index, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances, \ + const cuvs::neighbors::filtering::base_filter& sample_filter, \ + double threshold_to_bf) \ + { \ + cuvs::neighbors::cagra::search( \ + handle, params, index, queries, neighbors, distances, sample_filter, threshold_to_bf); \ } CUVS_INST_CAGRA_SEARCH(half, uint32_t); diff --git a/cpp/src/neighbors/cagra_search_int8.cu b/cpp/src/neighbors/cagra_search_int8.cu index 3442ef55f..a8bfaa7a7 100644 --- a/cpp/src/neighbors/cagra_search_int8.cu +++ b/cpp/src/neighbors/cagra_search_int8.cu @@ -18,17 +18,18 @@ #include namespace cuvs::neighbors::cagra { -#define CUVS_INST_CAGRA_SEARCH(T, IdxT) \ - void search(raft::resources const& handle, \ - cuvs::neighbors::cagra::search_params const& params, \ - const cuvs::neighbors::cagra::index& index, \ - raft::device_matrix_view queries, \ - raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances, \ - const cuvs::neighbors::filtering::base_filter& sample_filter) \ - { \ - cuvs::neighbors::cagra::search( \ - handle, params, index, queries, neighbors, distances, sample_filter); \ +#define CUVS_INST_CAGRA_SEARCH(T, IdxT) \ + void search(raft::resources const& handle, \ + cuvs::neighbors::cagra::search_params const& params, \ + const cuvs::neighbors::cagra::index& index, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances, \ + const cuvs::neighbors::filtering::base_filter& sample_filter, \ + double threshold_to_bf) \ + { \ + cuvs::neighbors::cagra::search( \ + handle, params, index, queries, neighbors, distances, sample_filter, threshold_to_bf); \ } CUVS_INST_CAGRA_SEARCH(int8_t, uint32_t); diff --git a/cpp/src/neighbors/cagra_search_uint8.cu b/cpp/src/neighbors/cagra_search_uint8.cu index 08fe1861b..411ff9c79 100644 --- a/cpp/src/neighbors/cagra_search_uint8.cu +++ b/cpp/src/neighbors/cagra_search_uint8.cu @@ -19,17 +19,18 @@ namespace cuvs::neighbors::cagra { -#define CUVS_INST_CAGRA_SEARCH(T, IdxT) \ - void search(raft::resources const& handle, \ - cuvs::neighbors::cagra::search_params const& params, \ - const cuvs::neighbors::cagra::index& index, \ - raft::device_matrix_view queries, \ - raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances, \ - const cuvs::neighbors::filtering::base_filter& sample_filter) \ - { \ - cuvs::neighbors::cagra::search( \ - handle, params, index, queries, neighbors, distances, sample_filter); \ +#define CUVS_INST_CAGRA_SEARCH(T, IdxT) \ + void search(raft::resources const& handle, \ + cuvs::neighbors::cagra::search_params const& params, \ + const cuvs::neighbors::cagra::index& index, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances, \ + const cuvs::neighbors::filtering::base_filter& sample_filter, \ + double threshold_to_bf) \ + { \ + cuvs::neighbors::cagra::search( \ + handle, params, index, queries, neighbors, distances, sample_filter, threshold_to_bf); \ } CUVS_INST_CAGRA_SEARCH(uint8_t, uint32_t); diff --git a/cpp/src/neighbors/detail/cagra/cagra_search.cuh b/cpp/src/neighbors/detail/cagra/cagra_search.cuh index 1f2ef295c..555f91e75 100644 --- a/cpp/src/neighbors/detail/cagra/cagra_search.cuh +++ b/cpp/src/neighbors/detail/cagra/cagra_search.cuh @@ -239,6 +239,9 @@ bool search_using_brute_force( * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, * k] * @param[in] sample_filter a device filter function that greenlights samples for a given query + * @param[in] threshold_to_bf A sparsity threshold; brute force is used when sparsity exceeds this + * threshold, in the range [0, 1] + * */ template queries, raft::device_matrix_view neighbors, raft::device_matrix_view distances, - CagraSampleFilterT sample_filter = CagraSampleFilterT()) + CagraSampleFilterT sample_filter = CagraSampleFilterT(), + double threshold_to_bf = 0.9) { auto stream = raft::resource::get_cuda_stream(res); const auto& graph = index.graph(); @@ -266,8 +270,15 @@ void search_main(raft::resources const& res, if constexpr (!std::is_same_v && (std::is_same_v || std::is_same_v)) { - bool bf_search_done = search_using_brute_force( - res, params, *strided_dset, index.metric(), queries, neighbors, distances, sample_filter); + bool bf_search_done = search_using_brute_force(res, + params, + *strided_dset, + index.metric(), + queries, + neighbors, + distances, + sample_filter, + threshold_to_bf); if (bf_search_done) return; } diff --git a/cpp/test/neighbors/ann_cagra.cuh b/cpp/test/neighbors/ann_cagra.cuh index 1ba03e5f8..5f56f24a6 100644 --- a/cpp/test/neighbors/ann_cagra.cuh +++ b/cpp/test/neighbors/ann_cagra.cuh @@ -309,6 +309,7 @@ struct AnnCagraInputs { bool include_serialized_dataset; // std::optional double min_recall; // = std::nullopt; + double threshold_to_bf = 0.9; uint32_t filter_offset = 300; std::optional ivf_pq_search_refine_ratio = std::nullopt; std::optional compression = std::nullopt; @@ -847,7 +848,8 @@ class AnnCagraFilterTest : public ::testing::TestWithParam { search_queries_view, indices_out_view, dists_out_view, - bitset_filter_obj); + bitset_filter_obj, + ps.threshold_to_bf); raft::update_host(distances_Cagra.data(), distances_dev.data(), queries_size, stream_); raft::update_host(indices_Cagra.data(), indices_dev.data(), queries_size, stream_); raft::resource::sync_stream(handle_); @@ -1047,6 +1049,7 @@ inline std::vector generate_inputs() {false, true}, {false}, {0.99}, + {0.9}, {uint32_t(300)}, {1.0f, 2.0f, 3.0f}); inputs.insert(inputs.end(), inputs2.begin(), inputs2.end()); @@ -1092,9 +1095,11 @@ inline std::vector generate_bf_inputs() {cuvs::distance::DistanceType::L2Expanded}, {false}, {true}, - {1.0}); + {1.0}, + {0.1, 0.4, 0.91}); for (auto input : inputs_original) { - input.filter_offset = 0.90 * input.n_rows; + input.filter_offset = 0.5 * input.n_rows; + input.min_recall = input.threshold_to_bf <= 0.5 ? 1.0 : 0.6; inputs_for_brute_force.push_back(input); } From 9aa1bb10566428077fb37737341eb6ae67ac45bc Mon Sep 17 00:00:00 2001 From: rhdong Date: Tue, 19 Nov 2024 13:10:15 -0800 Subject: [PATCH 12/15] move the threshold-to-bf into search_params --- cpp/include/cuvs/neighbors/cagra.hpp | 24 ++++++----------- cpp/src/neighbors/cagra.cuh | 23 +++++----------- cpp/src/neighbors/cagra_search_float.cu | 23 ++++++++-------- cpp/src/neighbors/cagra_search_half.cu | 23 ++++++++-------- cpp/src/neighbors/cagra_search_int8.cu | 23 ++++++++-------- cpp/src/neighbors/cagra_search_uint8.cu | 23 ++++++++-------- .../neighbors/detail/cagra/cagra_search.cuh | 26 +++++-------------- cpp/test/neighbors/ann_cagra.cuh | 6 ++--- 8 files changed, 67 insertions(+), 104 deletions(-) diff --git a/cpp/include/cuvs/neighbors/cagra.hpp b/cpp/include/cuvs/neighbors/cagra.hpp index 8ff9664f6..2150f4214 100644 --- a/cpp/include/cuvs/neighbors/cagra.hpp +++ b/cpp/include/cuvs/neighbors/cagra.hpp @@ -229,6 +229,10 @@ struct search_params : cuvs::neighbors::search_params { * impact on the throughput. */ float persistent_device_usage = 1.0; + + /** A sparsity threshold; brute force is used when sparsity exceeds this threshold, in the range + * [0, 1] */ + double threshold_to_bf = 0.9; }; /** @@ -1057,8 +1061,6 @@ void extend( * k] * @param[in] sample_filter an optional device filter function object that greenlights samples * for a given query. (none_sample_filter for no filtering) - * @param[in] threshold_to_bf A sparsity threshold; brute force is used when sparsity exceeds this - * threshold, in the range [0, 1] */ void search(raft::resources const& res, @@ -1068,8 +1070,7 @@ void search(raft::resources const& res, raft::device_matrix_view neighbors, raft::device_matrix_view distances, const cuvs::neighbors::filtering::base_filter& sample_filter = - cuvs::neighbors::filtering::none_sample_filter{}, - double threshold_to_bf = 0.9f); + cuvs::neighbors::filtering::none_sample_filter{}); /** * @brief Search ANN using the constructed index. @@ -1086,8 +1087,6 @@ void search(raft::resources const& res, * k] * @param[in] sample_filter an optional device filter function object that greenlights samples * for a given query. (none_sample_filter for no filtering) - * @param[in] threshold_to_bf A sparsity threshold; brute force is used when sparsity exceeds this - * threshold, in the range [0, 1] */ void search(raft::resources const& res, cuvs::neighbors::cagra::search_params const& params, @@ -1096,8 +1095,7 @@ void search(raft::resources const& res, raft::device_matrix_view neighbors, raft::device_matrix_view distances, const cuvs::neighbors::filtering::base_filter& sample_filter = - cuvs::neighbors::filtering::none_sample_filter{}, - double threshold_to_bf = 0.9f); + cuvs::neighbors::filtering::none_sample_filter{}); /** * @brief Search ANN using the constructed index. @@ -1114,8 +1112,6 @@ void search(raft::resources const& res, * k] * @param[in] sample_filter an optional device filter function object that greenlights samples * for a given query. (none_sample_filter for no filtering) - * @param[in] threshold_to_bf A sparsity threshold; brute force is used when sparsity exceeds this - * threshold, in the range [0, 1] */ void search(raft::resources const& res, cuvs::neighbors::cagra::search_params const& params, @@ -1124,8 +1120,7 @@ void search(raft::resources const& res, raft::device_matrix_view neighbors, raft::device_matrix_view distances, const cuvs::neighbors::filtering::base_filter& sample_filter = - cuvs::neighbors::filtering::none_sample_filter{}, - double threshold_to_bf = 0.9f); + cuvs::neighbors::filtering::none_sample_filter{}); /** * @brief Search ANN using the constructed index. @@ -1142,8 +1137,6 @@ void search(raft::resources const& res, * k] * @param[in] sample_filter an optional device filter function object that greenlights samples * for a given query. (none_sample_filter for no filtering) - * @param[in] threshold_to_bf A sparsity threshold; brute force is used when sparsity exceeds this - * threshold, in the range [0, 1] */ void search(raft::resources const& res, cuvs::neighbors::cagra::search_params const& params, @@ -1152,8 +1145,7 @@ void search(raft::resources const& res, raft::device_matrix_view neighbors, raft::device_matrix_view distances, const cuvs::neighbors::filtering::base_filter& sample_filter = - cuvs::neighbors::filtering::none_sample_filter{}, - double threshold_to_bf = 0.9f); + cuvs::neighbors::filtering::none_sample_filter{}); /** * @} diff --git a/cpp/src/neighbors/cagra.cuh b/cpp/src/neighbors/cagra.cuh index f0bad1431..dacfd6f63 100644 --- a/cpp/src/neighbors/cagra.cuh +++ b/cpp/src/neighbors/cagra.cuh @@ -293,9 +293,6 @@ index build( * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, * k] * @param[in] sample_filter a device filter function that greenlights samples for a given query - * @param[in] threshold_to_bf A sparsity threshold; brute force is used when sparsity exceeds this - * threshold, in the range [0, 1] - * */ template void search_with_filtering(raft::resources const& res, @@ -304,8 +301,7 @@ void search_with_filtering(raft::resources const& res, raft::device_matrix_view queries, raft::device_matrix_view neighbors, raft::device_matrix_view distances, - CagraSampleFilterT sample_filter = CagraSampleFilterT(), - double threshold_to_bf = 0.9) + CagraSampleFilterT sample_filter = CagraSampleFilterT()) { RAFT_EXPECTS( queries.extent(0) == neighbors.extent(0) && queries.extent(0) == distances.extent(0), @@ -326,14 +322,8 @@ void search_with_filtering(raft::resources const& res, auto distances_internal = raft::make_device_matrix_view( distances.data_handle(), distances.extent(0), distances.extent(1)); - return cagra::detail::search_main(res, - params, - idx, - queries_internal, - neighbors_internal, - distances_internal, - sample_filter, - threshold_to_bf); + return cagra::detail::search_main( + res, params, idx, queries_internal, neighbors_internal, distances_internal, sample_filter); } template @@ -343,15 +333,14 @@ void search(raft::resources const& res, raft::device_matrix_view queries, raft::device_matrix_view neighbors, raft::device_matrix_view distances, - const cuvs::neighbors::filtering::base_filter& sample_filter_ref, - double threshold_to_bf = 0.9) + const cuvs::neighbors::filtering::base_filter& sample_filter_ref) { try { using none_filter_type = cuvs::neighbors::filtering::none_sample_filter; auto& sample_filter = dynamic_cast(sample_filter_ref); auto sample_filter_copy = sample_filter; return search_with_filtering( - res, params, idx, queries, neighbors, distances, sample_filter_copy, threshold_to_bf); + res, params, idx, queries, neighbors, distances, sample_filter_copy); return; } catch (const std::bad_cast&) { } @@ -362,7 +351,7 @@ void search(raft::resources const& res, sample_filter_ref); auto sample_filter_copy = sample_filter; return search_with_filtering( - res, params, idx, queries, neighbors, distances, sample_filter_copy, threshold_to_bf); + res, params, idx, queries, neighbors, distances, sample_filter_copy); } catch (const std::bad_cast&) { RAFT_FAIL("Unsupported sample filter type"); } diff --git a/cpp/src/neighbors/cagra_search_float.cu b/cpp/src/neighbors/cagra_search_float.cu index d1d790121..3aca84f74 100644 --- a/cpp/src/neighbors/cagra_search_float.cu +++ b/cpp/src/neighbors/cagra_search_float.cu @@ -19,18 +19,17 @@ namespace cuvs::neighbors::cagra { -#define CUVS_INST_CAGRA_SEARCH(T, IdxT) \ - void search(raft::resources const& handle, \ - cuvs::neighbors::cagra::search_params const& params, \ - const cuvs::neighbors::cagra::index& index, \ - raft::device_matrix_view queries, \ - raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances, \ - const cuvs::neighbors::filtering::base_filter& sample_filter, \ - double threshold_to_bf) \ - { \ - cuvs::neighbors::cagra::search( \ - handle, params, index, queries, neighbors, distances, sample_filter, threshold_to_bf); \ +#define CUVS_INST_CAGRA_SEARCH(T, IdxT) \ + void search(raft::resources const& handle, \ + cuvs::neighbors::cagra::search_params const& params, \ + const cuvs::neighbors::cagra::index& index, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances, \ + const cuvs::neighbors::filtering::base_filter& sample_filter) \ + { \ + cuvs::neighbors::cagra::search( \ + handle, params, index, queries, neighbors, distances, sample_filter); \ } CUVS_INST_CAGRA_SEARCH(float, uint32_t); diff --git a/cpp/src/neighbors/cagra_search_half.cu b/cpp/src/neighbors/cagra_search_half.cu index 5112e25dd..02be12731 100644 --- a/cpp/src/neighbors/cagra_search_half.cu +++ b/cpp/src/neighbors/cagra_search_half.cu @@ -19,18 +19,17 @@ namespace cuvs::neighbors::cagra { -#define CUVS_INST_CAGRA_SEARCH(T, IdxT) \ - void search(raft::resources const& handle, \ - cuvs::neighbors::cagra::search_params const& params, \ - const cuvs::neighbors::cagra::index& index, \ - raft::device_matrix_view queries, \ - raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances, \ - const cuvs::neighbors::filtering::base_filter& sample_filter, \ - double threshold_to_bf) \ - { \ - cuvs::neighbors::cagra::search( \ - handle, params, index, queries, neighbors, distances, sample_filter, threshold_to_bf); \ +#define CUVS_INST_CAGRA_SEARCH(T, IdxT) \ + void search(raft::resources const& handle, \ + cuvs::neighbors::cagra::search_params const& params, \ + const cuvs::neighbors::cagra::index& index, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances, \ + const cuvs::neighbors::filtering::base_filter& sample_filter) \ + { \ + cuvs::neighbors::cagra::search( \ + handle, params, index, queries, neighbors, distances, sample_filter); \ } CUVS_INST_CAGRA_SEARCH(half, uint32_t); diff --git a/cpp/src/neighbors/cagra_search_int8.cu b/cpp/src/neighbors/cagra_search_int8.cu index a8bfaa7a7..3442ef55f 100644 --- a/cpp/src/neighbors/cagra_search_int8.cu +++ b/cpp/src/neighbors/cagra_search_int8.cu @@ -18,18 +18,17 @@ #include namespace cuvs::neighbors::cagra { -#define CUVS_INST_CAGRA_SEARCH(T, IdxT) \ - void search(raft::resources const& handle, \ - cuvs::neighbors::cagra::search_params const& params, \ - const cuvs::neighbors::cagra::index& index, \ - raft::device_matrix_view queries, \ - raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances, \ - const cuvs::neighbors::filtering::base_filter& sample_filter, \ - double threshold_to_bf) \ - { \ - cuvs::neighbors::cagra::search( \ - handle, params, index, queries, neighbors, distances, sample_filter, threshold_to_bf); \ +#define CUVS_INST_CAGRA_SEARCH(T, IdxT) \ + void search(raft::resources const& handle, \ + cuvs::neighbors::cagra::search_params const& params, \ + const cuvs::neighbors::cagra::index& index, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances, \ + const cuvs::neighbors::filtering::base_filter& sample_filter) \ + { \ + cuvs::neighbors::cagra::search( \ + handle, params, index, queries, neighbors, distances, sample_filter); \ } CUVS_INST_CAGRA_SEARCH(int8_t, uint32_t); diff --git a/cpp/src/neighbors/cagra_search_uint8.cu b/cpp/src/neighbors/cagra_search_uint8.cu index 411ff9c79..08fe1861b 100644 --- a/cpp/src/neighbors/cagra_search_uint8.cu +++ b/cpp/src/neighbors/cagra_search_uint8.cu @@ -19,18 +19,17 @@ namespace cuvs::neighbors::cagra { -#define CUVS_INST_CAGRA_SEARCH(T, IdxT) \ - void search(raft::resources const& handle, \ - cuvs::neighbors::cagra::search_params const& params, \ - const cuvs::neighbors::cagra::index& index, \ - raft::device_matrix_view queries, \ - raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances, \ - const cuvs::neighbors::filtering::base_filter& sample_filter, \ - double threshold_to_bf) \ - { \ - cuvs::neighbors::cagra::search( \ - handle, params, index, queries, neighbors, distances, sample_filter, threshold_to_bf); \ +#define CUVS_INST_CAGRA_SEARCH(T, IdxT) \ + void search(raft::resources const& handle, \ + cuvs::neighbors::cagra::search_params const& params, \ + const cuvs::neighbors::cagra::index& index, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances, \ + const cuvs::neighbors::filtering::base_filter& sample_filter) \ + { \ + cuvs::neighbors::cagra::search( \ + handle, params, index, queries, neighbors, distances, sample_filter); \ } CUVS_INST_CAGRA_SEARCH(uint8_t, uint32_t); diff --git a/cpp/src/neighbors/detail/cagra/cagra_search.cuh b/cpp/src/neighbors/detail/cagra/cagra_search.cuh index 555f91e75..ab8ee12cc 100644 --- a/cpp/src/neighbors/detail/cagra/cagra_search.cuh +++ b/cpp/src/neighbors/detail/cagra/cagra_search.cuh @@ -116,7 +116,7 @@ void search_main_core(raft::resources const& res, * * This function switches to a brute force search approach to improve recall rate when the * `sample_filter` function filters out a high proportion of samples, resulting in a sparsity level - * (proportion of unfiltered samples) exceeding the specified `threshold_to_bf`. + * (proportion of unfiltered samples) exceeding the specified `params.threshold_to_bf`. * * @tparam T data element type * @tparam IdxT type of database vector indices @@ -133,8 +133,6 @@ void search_main_core(raft::resources const& res, * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, * k] * @param[in] sample_filter a device filter function that greenlights samples for a given query - * @param[in] threshold_to_bf A sparsity threshold; brute force is used when sparsity exceeds this - * threshold, in the range [0, 1] * * @return true If the brute force search was applied successfully. * @return false If the brute force search was not applied. @@ -152,8 +150,7 @@ bool search_using_brute_force( raft::device_matrix_view queries, raft::device_matrix_view neighbors, raft::device_matrix_view distances, - CagraSampleFilterT& sample_filter, - double threshold_to_bf = 0.9) + CagraSampleFilterT& sample_filter) { auto n_queries = queries.extent(0); auto n_dataset = strided_dataset.n_rows(); @@ -161,7 +158,7 @@ bool search_using_brute_force( auto bitset_filter_view = sample_filter.bitset_view_; auto sparsity = bitset_filter_view.sparsity(res); - if (sparsity < threshold_to_bf) { return false; } + if (sparsity < params.threshold_to_bf) { return false; } // TODO: Support host dataset in `brute_force::build` RAFT_LOG_DEBUG("CAGRA is switching to brute force with sparsity:%f", sparsity); @@ -239,9 +236,6 @@ bool search_using_brute_force( * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, * k] * @param[in] sample_filter a device filter function that greenlights samples for a given query - * @param[in] threshold_to_bf A sparsity threshold; brute force is used when sparsity exceeds this - * threshold, in the range [0, 1] - * */ template queries, raft::device_matrix_view neighbors, raft::device_matrix_view distances, - CagraSampleFilterT sample_filter = CagraSampleFilterT(), - double threshold_to_bf = 0.9) + CagraSampleFilterT sample_filter = CagraSampleFilterT()) { auto stream = raft::resource::get_cuda_stream(res); const auto& graph = index.graph(); @@ -270,15 +263,8 @@ void search_main(raft::resources const& res, if constexpr (!std::is_same_v && (std::is_same_v || std::is_same_v)) { - bool bf_search_done = search_using_brute_force(res, - params, - *strided_dset, - index.metric(), - queries, - neighbors, - distances, - sample_filter, - threshold_to_bf); + bool bf_search_done = search_using_brute_force( + res, params, *strided_dset, index.metric(), queries, neighbors, distances, sample_filter); if (bf_search_done) return; } diff --git a/cpp/test/neighbors/ann_cagra.cuh b/cpp/test/neighbors/ann_cagra.cuh index 5f56f24a6..fa3e8e855 100644 --- a/cpp/test/neighbors/ann_cagra.cuh +++ b/cpp/test/neighbors/ann_cagra.cuh @@ -792,6 +792,7 @@ class AnnCagraFilterTest : public ::testing::TestWithParam { search_params.algo = ps.algo; search_params.max_queries = ps.max_queries; search_params.team_size = ps.team_size; + search_params.team_size = ps.threshold_to_bf; // TODO: setting search_params.itopk_size here breaks the filter tests, but is required for // k>1024 skip these tests until fixed @@ -848,8 +849,7 @@ class AnnCagraFilterTest : public ::testing::TestWithParam { search_queries_view, indices_out_view, dists_out_view, - bitset_filter_obj, - ps.threshold_to_bf); + bitset_filter_obj); raft::update_host(distances_Cagra.data(), distances_dev.data(), queries_size, stream_); raft::update_host(indices_Cagra.data(), indices_dev.data(), queries_size, stream_); raft::resource::sync_stream(handle_); @@ -1096,7 +1096,7 @@ inline std::vector generate_bf_inputs() {false}, {true}, {1.0}, - {0.1, 0.4, 0.91}); + {0.1, 0.4, 0.8}); for (auto input : inputs_original) { input.filter_offset = 0.5 * input.n_rows; input.min_recall = input.threshold_to_bf <= 0.5 ? 1.0 : 0.6; From e29d74dde4b562af3cd5d4cbf43d211129706a09 Mon Sep 17 00:00:00 2001 From: rhdong Date: Tue, 19 Nov 2024 22:28:15 -0800 Subject: [PATCH 13/15] skip the test on half when cusparse version is unsupported. --- cpp/test/neighbors/ann_cagra.cuh | 8 ++++---- cpp/test/neighbors/brute_force_prefiltered.cu | 18 ++++++++++++++++++ 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/cpp/test/neighbors/ann_cagra.cuh b/cpp/test/neighbors/ann_cagra.cuh index fa3e8e855..07f2e81c8 100644 --- a/cpp/test/neighbors/ann_cagra.cuh +++ b/cpp/test/neighbors/ann_cagra.cuh @@ -789,10 +789,10 @@ class AnnCagraFilterTest : public ::testing::TestWithParam { index_params.compression = ps.compression; cagra::search_params search_params; - search_params.algo = ps.algo; - search_params.max_queries = ps.max_queries; - search_params.team_size = ps.team_size; - search_params.team_size = ps.threshold_to_bf; + search_params.algo = ps.algo; + search_params.max_queries = ps.max_queries; + search_params.team_size = ps.team_size; + search_params.threshold_to_bf = ps.threshold_to_bf; // TODO: setting search_params.itopk_size here breaks the filter tests, but is required for // k>1024 skip these tests until fixed diff --git a/cpp/test/neighbors/brute_force_prefiltered.cu b/cpp/test/neighbors/brute_force_prefiltered.cu index 12b1c529e..7257dcc9e 100644 --- a/cpp/test/neighbors/brute_force_prefiltered.cu +++ b/cpp/test/neighbors/brute_force_prefiltered.cu @@ -76,6 +76,21 @@ struct CompareApproxWithInf { T eps; }; +bool isCuSparseVersionGreaterThan_12_0_1() +{ + int version; + cusparseHandle_t handle; + cusparseCreate(&handle); + cusparseGetVersion(handle, &version); + + int major = version / 1000; + int minor = (version % 1000) / 100; + int patch = version % 100; + + cusparseDestroy(handle); + + return (major > 12) || (major == 12 && minor > 0) || (major == 12 && minor == 0 && patch >= 2); +} template RAFT_KERNEL normalize_kernel( OutT* theta, const InT* in_vals, size_t max_scale, size_t r_scale, size_t c_scale) @@ -352,6 +367,9 @@ class PrefilteredBruteForceTest void SetUp() override { + if (std::is_same_v && !isCuSparseVersionGreaterThan_12_0_1()) { + GTEST_SKIP() << "Skipping all tests for half-float as cuSparse doesn't support it."; + } index_t element = raft::ceildiv(params.n_queries * params.n_dataset, index_t(sizeof(bitmap_t) * 8)); std::vector filter_h(element); From 4d0fc8e160bc52357db63c1b9b56fc7e78cd8c65 Mon Sep 17 00:00:00 2001 From: rhdong Date: Wed, 20 Nov 2024 02:35:47 -0800 Subject: [PATCH 14/15] Revert "benchmark: support pre-filter on CAGRA" This reverts commit b5dcc02c06dc4fb5b18ed0ac73986630fc9a6391. --- cpp/bench/ann/src/cuvs/cuvs_cagra_wrapper.h | 66 ++------------------- 1 file changed, 4 insertions(+), 62 deletions(-) diff --git a/cpp/bench/ann/src/cuvs/cuvs_cagra_wrapper.h b/cpp/bench/ann/src/cuvs/cuvs_cagra_wrapper.h index 30ca6b722..b2ba35eee 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_cagra_wrapper.h +++ b/cpp/bench/ann/src/cuvs/cuvs_cagra_wrapper.h @@ -43,7 +43,6 @@ #include #include #include -#include #include #include #include @@ -53,13 +52,10 @@ namespace cuvs::bench { enum class AllocatorType { kHostPinned, kHostHugePage, kDevice }; enum class CagraBuildAlgo { kAuto, kIvfPq, kNnDescent }; -constexpr double sparsity = 0.0f; - template class cuvs_cagra : public algo, public algo_gpu { public: using search_param_base = typename algo::search_param; - // TODO: Move to arguments struct search_param : public search_param_base { cuvs::neighbors::cagra::search_params p; @@ -95,40 +91,6 @@ class cuvs_cagra : public algo, public algo_gpu { } }; - int64_t create_sparse_bitset(int64_t total, float sparsity, std::vector& bitset) const - { - int64_t num_ones = static_cast((total * 1.0f) * (1.0f - sparsity)); - int64_t res = num_ones; - - for (auto& item : bitset) { - item = static_cast(0); - } - - if (sparsity == 0.0) { - for (auto& item : bitset) { - item = static_cast(0xffffffff); - } - return total; - } - - std::random_device rd; - std::mt19937 gen(rd()); - std::uniform_int_distribution dis(0, total - 1); - - while (num_ones > 0) { - int64_t index = dis(gen); - - uint32_t& element = bitset[index / (8 * sizeof(uint32_t))]; - int64_t bit_position = index % (8 * sizeof(uint32_t)); - - if (((element >> bit_position) & 1) == 0) { - element |= (static_cast(1) << bit_position); - num_ones--; - } - } - return res; - } - cuvs_cagra(Metric metric, int dim, const build_param& param, int concurrent_searches = 1) : algo(metric, dim), index_params_(param), @@ -140,9 +102,8 @@ class cuvs_cagra : public algo, public algo_gpu { std::move(raft::make_device_matrix(handle_, 0, 0)))), input_dataset_v_( std::make_shared>( - nullptr, 0, 0)), - bitset_filter_(std::make_shared>( - std::move(cuvs::core::bitset(handle_, 0, false)))) + nullptr, 0, 0)) + { index_params_.cagra_params.metric = parse_metric_type(metric); index_params_.ivf_pq_build_params->metric = parse_metric_type(metric); @@ -210,9 +171,6 @@ class cuvs_cagra : public algo, public algo_gpu { std::shared_ptr> dataset_; std::shared_ptr> input_dataset_v_; - // std::shared_ptr> bitset_filter_; - std::shared_ptr> bitset_filter_; - inline rmm::device_async_resource_ref get_mr(AllocatorType mem_type) { switch (mem_type) { @@ -298,15 +256,6 @@ void cuvs_cagra::set_search_param(const search_param_base& param) need_dataset_update_ = false; } - - { // create bitset filter in advance. - auto stream_ = raft::resource::get_cuda_stream(handle_); - size_t filter_n_elements = size_t((input_dataset_v_->extent(0) + 31) / 32); - bitset_filter_->resize(handle_, input_dataset_v_->extent(0), false); - std::vector bitset_cpu(filter_n_elements); - create_sparse_bitset(input_dataset_v_->extent(0), sparsity, bitset_cpu); - raft::copy(bitset_filter_->data(), bitset_cpu.data(), filter_n_elements, stream_); - } } template @@ -379,15 +328,8 @@ void cuvs_cagra::search_base(const T* queries, raft::make_device_matrix_view(neighbors_idx_t, batch_size, k); auto distances_view = raft::make_device_matrix_view(distances, batch_size, k); - if constexpr ((std::is_same_v || std::is_same_v)&&sparsity >= 0.0f) { - auto filter = cuvs::neighbors::filtering::bitset_filter(bitset_filter_->view()); - cuvs::neighbors::cagra::search( - handle_, search_params_, *index_, queries_view, neighbors_view, distances_view, filter); - - } else { - cuvs::neighbors::cagra::search( - handle_, search_params_, *index_, queries_view, neighbors_view, distances_view); - } + cuvs::neighbors::cagra::search( + handle_, search_params_, *index_, queries_view, neighbors_view, distances_view); if constexpr (sizeof(IdxT) != sizeof(algo_base::index_type)) { if (raft::get_device_for_address(neighbors) < 0 && From 1bcba66a5a5b6a31b9c647f25347c7bfd64b4246 Mon Sep 17 00:00:00 2001 From: rhdong Date: Wed, 20 Nov 2024 13:29:37 -0800 Subject: [PATCH 15/15] if (params.threshold_to_bf >= 1.0) { return false; } --- cpp/src/neighbors/detail/cagra/cagra_search.cuh | 2 ++ 1 file changed, 2 insertions(+) diff --git a/cpp/src/neighbors/detail/cagra/cagra_search.cuh b/cpp/src/neighbors/detail/cagra/cagra_search.cuh index ab8ee12cc..ba183fb42 100644 --- a/cpp/src/neighbors/detail/cagra/cagra_search.cuh +++ b/cpp/src/neighbors/detail/cagra/cagra_search.cuh @@ -152,6 +152,8 @@ bool search_using_brute_force( raft::device_matrix_view distances, CagraSampleFilterT& sample_filter) { + if (params.threshold_to_bf >= 1.0) { return false; }; + auto n_queries = queries.extent(0); auto n_dataset = strided_dataset.n_rows();