diff --git a/cpp/include/cuvs/neighbors/cagra.hpp b/cpp/include/cuvs/neighbors/cagra.hpp index e48050756..2150f4214 100644 --- a/cpp/include/cuvs/neighbors/cagra.hpp +++ b/cpp/include/cuvs/neighbors/cagra.hpp @@ -229,6 +229,10 @@ struct search_params : cuvs::neighbors::search_params { * impact on the throughput. */ float persistent_device_usage = 1.0; + + /** A sparsity threshold; brute force is used when sparsity exceeds this threshold, in the range + * [0, 1] */ + double threshold_to_bf = 0.9; }; /** diff --git a/cpp/src/neighbors/detail/cagra/cagra_search.cuh b/cpp/src/neighbors/detail/cagra/cagra_search.cuh index 5778d85a6..ba183fb42 100644 --- a/cpp/src/neighbors/detail/cagra/cagra_search.cuh +++ b/cpp/src/neighbors/detail/cagra/cagra_search.cuh @@ -21,14 +21,17 @@ #include "sample_filter_utils.cuh" #include "search_plan.cuh" #include "search_single_cta_inst.cuh" +#include "utils.hpp" #include #include #include #include +#include #include +#include #include // TODO: Fix these when ivf methods are moved over @@ -108,6 +111,114 @@ void search_main_core(raft::resources const& res, } } +/** + * @brief Performs ANN search using brute force when filter sparsity exceeds a specified threshold. + * + * This function switches to a brute force search approach to improve recall rate when the + * `sample_filter` function filters out a high proportion of samples, resulting in a sparsity level + * (proportion of filtered-out samples) exceeding the specified `params.threshold_to_bf`. + * + * @tparam T data element type + * @tparam IdxT type of database vector indices + * @tparam internal_IdxT during search we map IdxT to internal_IdxT, this way we do not need + * separate kernels for int/uint. 
+ * + * @param[in] res raft resources handle + * @param[in] params configure the search + * @param[in] strided_dataset CAGRA strided dataset + * @param[in] metric distance type + * @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()] + * @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset + * [n_queries, k] + * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, + * k] + * @param[in] sample_filter a device filter function that greenlights samples for a given query + * + * @return true If the brute force search was applied successfully. + * @return false If the brute force search was not applied. + */ +template +bool search_using_brute_force( + raft::resources const& res, + const search_params& params, + const strided_dataset& strided_dataset, + cuvs::distance::DistanceType metric, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances, + CagraSampleFilterT& sample_filter) +{ + if (params.threshold_to_bf >= 1.0) { return false; } // fallback disabled + + auto n_queries = queries.extent(0); + auto n_dataset = strided_dataset.n_rows(); + + auto bitset_filter_view = sample_filter.bitset_view_; + auto sparsity = bitset_filter_view.sparsity(res); + + if (sparsity < params.threshold_to_bf) { return false; } // not sparse enough + + // TODO: Support host dataset in `brute_force::build` + RAFT_LOG_DEBUG("CAGRA is switching to brute force with sparsity:%f", sparsity); + using bitmap_view_t = cuvs::core::bitmap_view; + + auto stream = raft::resource::get_cuda_stream(res); + auto bitmap_n_elements = bitmap_view_t::eval_n_elements(bitset_filter_view.size() * n_queries); + + rmm::device_uvector raw_bitmap(bitmap_n_elements, stream); + rmm::device_uvector raw_neighbors(neighbors.size(), stream); + + bitset_filter_view.repeat(res, n_queries, raw_bitmap.data()); + + auto brute_force_filter = bitmap_view_t(raw_bitmap.data(), n_queries, n_dataset); + + auto 
brute_force_neighbors = raft::make_device_matrix_view( + raw_neighbors.data(), neighbors.extent(0), neighbors.extent(1)); + auto brute_force_dataset = raft::make_device_matrix_view( + strided_dataset.view().data_handle(), strided_dataset.n_rows(), strided_dataset.stride()); + + auto brute_force_idx = cuvs::neighbors::brute_force::build(res, brute_force_dataset, metric); + + auto brute_force_queries = queries; + auto padding_queries = raft::make_device_matrix(res, 0, 0); + + // The dataset rows are padded (stride != dim), so the queries must be padded to match. + if (brute_force_dataset.extent(1) != queries.extent(1)) { + padding_queries = raft::make_device_mdarray( + res, + raft::resource::get_workspace_resource(res), + raft::make_extents(n_queries, brute_force_dataset.extent(1))); + // Copy the queries and fill the padded elements with zeros + raft::linalg::map_offset( + res, + padding_queries.view(), + [queries, stride = brute_force_dataset.extent(1)] __device__(int64_t i) { + auto row_ix = i / stride; + auto el_ix = i % stride; + return el_ix < queries.extent(1) ? queries(row_ix, el_ix) : T{0}; + }); + brute_force_queries = raft::make_device_matrix_view( + padding_queries.data_handle(), padding_queries.extent(0), padding_queries.extent(1)); + } + cuvs::neighbors::brute_force::search( + res, + brute_force_idx, + brute_force_queries, + brute_force_neighbors, + distances, + cuvs::neighbors::filtering::bitmap_filter(brute_force_filter)); + raft::linalg::unaryOp(neighbors.data_handle(), + brute_force_neighbors.data_handle(), + neighbors.size(), + raft::cast_op(), + stream); + return true; +} + /** * @brief Search ANN using the constructed index. 
* @@ -126,6 +237,7 @@ void search_main_core(raft::resources const& res, * [n_queries, k] * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, * k] + * @param[in] sample_filter a device filter function that greenlights samples for a given query */ template *>(&index.data()); strided_dset != nullptr) { + if constexpr (!std::is_same_v && + (std::is_same_v || std::is_same_v)) { + bool bf_search_done = search_using_brute_force( + res, params, *strided_dset, index.metric(), queries, neighbors, distances, sample_filter); + if (bf_search_done) return; + } + // Search using a plain (strided) row-major dataset auto desc = dataset_descriptor_init_with_cache( res, params, *strided_dset, index.metric()); diff --git a/cpp/src/neighbors/detail/knn_brute_force.cuh b/cpp/src/neighbors/detail/knn_brute_force.cuh index e5eeecbc9..24fe0651a 100644 --- a/cpp/src/neighbors/detail/knn_brute_force.cuh +++ b/cpp/src/neighbors/detail/knn_brute_force.cuh @@ -620,9 +620,9 @@ void brute_force_search_filtered( raft::copy(&nnz_h, nnz.data(), 1, stream); raft::resource::sync_stream(res, stream); - float sparsity = (1.0f * nnz_h / (1.0f * n_queries * n_dataset)); + float sparsity = (1.0f - (1.0f * nnz_h) / (1.0f * n_queries * n_dataset)); - if (sparsity > 0.01f) { + if (sparsity < 0.9f) { raft::resources stream_pool_handle(res); raft::resource::set_cuda_stream(stream_pool_handle, stream); auto idx_norm = idx.has_norms() ? 
const_cast(idx.norms().data_handle()) : nullptr; diff --git a/cpp/test/neighbors/ann_cagra.cuh b/cpp/test/neighbors/ann_cagra.cuh index 660246c67..07f2e81c8 100644 --- a/cpp/test/neighbors/ann_cagra.cuh +++ b/cpp/test/neighbors/ann_cagra.cuh @@ -51,12 +51,12 @@ namespace cuvs::neighbors::cagra { namespace { struct test_cagra_sample_filter { - static constexpr unsigned offset = 300; inline _RAFT_HOST_DEVICE auto operator()( // query index const uint32_t query_ix, // the index of the current sample inside the current inverted list - const uint32_t sample_ix) const + const uint32_t sample_ix, + const uint32_t offset) const { return sample_ix >= offset; } @@ -95,6 +95,39 @@ void RandomSuffle(raft::host_matrix_view index) } } +template +void copy_with_padding( + raft::resources const& res, + raft::device_matrix& dst, + raft::mdspan, raft::row_major, data_accessor> src, + rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource()) +{ + size_t padded_dim = raft::round_up_safe(src.extent(1) * sizeof(T), 16) / sizeof(T); // row width in elements once each row is rounded up to a 16-byte multiple + + if ((dst.extent(0) != src.extent(0)) || (static_cast(dst.extent(1)) != padded_dim)) { + // release dst's current buffer before allocating the new one to prevent OOM errors on large datasets + if (dst.size()) { dst = raft::make_device_matrix(res, 0, 0); } + dst = + raft::make_device_mdarray(res, mr, raft::make_extents(src.extent(0), padded_dim)); + } + if (dst.extent(1) == src.extent(1)) { + raft::copy( + dst.data_handle(), src.data_handle(), src.size(), raft::resource::get_cuda_stream(res)); + } else { + // zero-fill dst, then copy row by row so the trailing padding columns stay zero + RAFT_CUDA_TRY(cudaMemsetAsync( + dst.data_handle(), 0, dst.size() * sizeof(T), raft::resource::get_cuda_stream(res))); + RAFT_CUDA_TRY(cudaMemcpy2DAsync(dst.data_handle(), + sizeof(T) * dst.extent(1), + src.data_handle(), + sizeof(T) * src.extent(1), + sizeof(T) * src.extent(1), + src.extent(0), + cudaMemcpyDefault, + raft::resource::get_cuda_stream(res))); + } +} + template testing::AssertionResult CheckOrder(raft::host_matrix_view 
index_test, raft::host_matrix_view dataset) @@ -276,6 +309,8 @@ struct AnnCagraInputs { bool include_serialized_dataset; // std::optional double min_recall; // = std::nullopt; + double threshold_to_bf = 0.9; + uint32_t filter_offset = 300; std::optional ivf_pq_search_refine_ratio = std::nullopt; std::optional compression = std::nullopt; @@ -702,21 +737,20 @@ class AnnCagraFilterTest : public ::testing::TestWithParam { { rmm::device_uvector distances_naive_dev(queries_size, stream_); rmm::device_uvector indices_naive_dev(queries_size, stream_); - auto* database_filtered_ptr = database.data() + test_cagra_sample_filter::offset * ps.dim; - cuvs::neighbors::naive_knn( - handle_, - distances_naive_dev.data(), - indices_naive_dev.data(), - search_queries.data(), - database_filtered_ptr, - ps.n_queries, - ps.n_rows - test_cagra_sample_filter::offset, - ps.dim, - ps.k, - ps.metric); + auto* database_filtered_ptr = database.data() + ps.filter_offset * ps.dim; + cuvs::neighbors::naive_knn(handle_, + distances_naive_dev.data(), + indices_naive_dev.data(), + search_queries.data(), + database_filtered_ptr, + ps.n_queries, + ps.n_rows - ps.filter_offset, + ps.dim, + ps.k, + ps.metric); raft::linalg::addScalar(indices_naive_dev.data(), indices_naive_dev.data(), - IdxT(test_cagra_sample_filter::offset), + IdxT(ps.filter_offset), queries_size, stream_); raft::update_host(distances_naive.data(), distances_naive_dev.data(), queries_size, stream_); @@ -755,9 +789,10 @@ class AnnCagraFilterTest : public ::testing::TestWithParam { index_params.compression = ps.compression; cagra::search_params search_params; - search_params.algo = ps.algo; - search_params.max_queries = ps.max_queries; - search_params.team_size = ps.team_size; + search_params.algo = ps.algo; + search_params.max_queries = ps.max_queries; + search_params.team_size = ps.team_size; + search_params.threshold_to_bf = ps.threshold_to_bf; // TODO: setting search_params.itopk_size here breaks the filter tests, but is required for 
// k>1024 skip these tests until fixed @@ -780,6 +815,17 @@ class AnnCagraFilterTest : public ::testing::TestWithParam { if (!ps.include_serialized_dataset) { index.update_dataset(handle_, database_view); } + auto dataset_padding = raft::make_device_matrix(handle_, 0, 0); + if ((sizeof(DataT) * ps.dim % 16) != 0) { + copy_with_padding(handle_, dataset_padding, database_view); + auto database_view = raft::make_device_strided_matrix_view( + dataset_padding.data_handle(), + dataset_padding.extent(0), + ps.dim, + dataset_padding.extent(1)); + index.update_dataset(handle_, database_view); + } + auto search_queries_view = raft::make_device_matrix_view( search_queries.data(), ps.n_queries, ps.dim); auto indices_out_view = @@ -787,7 +833,7 @@ class AnnCagraFilterTest : public ::testing::TestWithParam { auto dists_out_view = raft::make_device_matrix_view( distances_dev.data(), ps.n_queries, ps.k); auto removed_indices = - raft::make_device_vector(handle_, test_cagra_sample_filter::offset); + raft::make_device_vector(handle_, ps.filter_offset); thrust::sequence( raft::resource::get_thrust_policy(handle_), thrust::device_pointer_cast(removed_indices.data_handle()), @@ -813,8 +859,9 @@ class AnnCagraFilterTest : public ::testing::TestWithParam { bool unacceptable_node = false; for (int q = 0; q < ps.n_queries; q++) { for (int i = 0; i < ps.k; i++) { - const auto n = indices_Cagra[q * ps.k + i]; - unacceptable_node = unacceptable_node | !test_cagra_sample_filter()(q, n); + const auto n = indices_Cagra[q * ps.k + i]; + unacceptable_node = + unacceptable_node | !test_cagra_sample_filter()(q, n, ps.filter_offset); } } EXPECT_FALSE(unacceptable_node); @@ -1002,6 +1049,8 @@ inline std::vector generate_inputs() {false, true}, {false}, {0.99}, + {0.9}, + {uint32_t(300)}, {1.0f, 2.0f, 3.0f}); inputs.insert(inputs.end(), inputs2.begin(), inputs2.end()); @@ -1028,6 +1077,36 @@ inline std::vector generate_inputs() return inputs; } -const std::vector inputs = generate_inputs(); +inline 
std::vector generate_bf_inputs() +{ + // Filter out half the rows (sparsity 0.5) and put threshold_to_bf on both sides of it. + std::vector inputs_for_brute_force; + auto inputs_original = raft::util::itertools::product( + {100}, + {1000}, + {1, 7, 8, 17}, + {1, 16}, // k + {graph_build_algo::IVF_PQ, graph_build_algo::NN_DESCENT}, + {search_algo::SINGLE_CTA, search_algo::MULTI_CTA, search_algo::MULTI_KERNEL}, + {0, 1, 10, 100}, + {0}, + {256}, + {1}, + {cuvs::distance::DistanceType::L2Expanded}, + {false}, + {true}, + {1.0}, + {0.1, 0.4, 0.8}); // threshold_to_bf + for (auto& input : inputs_original) { + input.filter_offset = uint32_t(input.n_rows / 2); + input.min_recall = input.threshold_to_bf <= 0.5 ? 1.0 : 0.6; // brute force is exact + inputs_for_brute_force.push_back(input); + } + + return inputs_for_brute_force; +} + +const std::vector inputs = generate_inputs(); +const std::vector inputs_brute_force = generate_bf_inputs(); } // namespace cuvs::neighbors::cagra diff --git a/cpp/test/neighbors/ann_cagra/test_float_uint32_t.cu b/cpp/test/neighbors/ann_cagra/test_float_uint32_t.cu index ca188d132..a98c31510 100644 --- a/cpp/test/neighbors/ann_cagra/test_float_uint32_t.cu +++ b/cpp/test/neighbors/ann_cagra/test_float_uint32_t.cu @@ -34,5 +34,8 @@ INSTANTIATE_TEST_CASE_P(AnnCagraAddNodesTest, AnnCagraAddNodesTestF_U32, ::testing::ValuesIn(inputs)); INSTANTIATE_TEST_CASE_P(AnnCagraFilterTest, AnnCagraFilterTestF_U32, ::testing::ValuesIn(inputs)); +INSTANTIATE_TEST_CASE_P(AnnCagraFilterToBruteForceTest, + AnnCagraFilterTestF_U32, + ::testing::ValuesIn(inputs_brute_force)); } // namespace cuvs::neighbors::cagra diff --git a/cpp/test/neighbors/ann_cagra/test_half_uint32_t.cu b/cpp/test/neighbors/ann_cagra/test_half_uint32_t.cu index f03de69d2..521e12a03 100644 --- a/cpp/test/neighbors/ann_cagra/test_half_uint32_t.cu +++ b/cpp/test/neighbors/ann_cagra/test_half_uint32_t.cu @@ -23,6 +23,13 @@ namespace cuvs::neighbors::cagra { typedef AnnCagraTest AnnCagraTestF16_U32; TEST_P(AnnCagraTestF16_U32, AnnCagra) { this->testCagra(); } +typedef 
AnnCagraFilterTest AnnCagraFilterTestF16_U32; +TEST_P(AnnCagraFilterTestF16_U32, AnnCagra) { this->testCagra(); } + INSTANTIATE_TEST_CASE_P(AnnCagraTest, AnnCagraTestF16_U32, ::testing::ValuesIn(inputs)); +INSTANTIATE_TEST_CASE_P(AnnCagraFilterTest, AnnCagraFilterTestF16_U32, ::testing::ValuesIn(inputs)); +INSTANTIATE_TEST_CASE_P(AnnCagraFilterToBruteForceTest, + AnnCagraFilterTestF16_U32, + ::testing::ValuesIn(inputs_brute_force)); } // namespace cuvs::neighbors::cagra diff --git a/cpp/test/neighbors/brute_force_prefiltered.cu b/cpp/test/neighbors/brute_force_prefiltered.cu index 12b1c529e..7257dcc9e 100644 --- a/cpp/test/neighbors/brute_force_prefiltered.cu +++ b/cpp/test/neighbors/brute_force_prefiltered.cu @@ -76,6 +76,21 @@ struct CompareApproxWithInf { T eps; }; +/** True when the cuSPARSE runtime is 12.0.2 or newer (i.e. > 12.0.1); false if the query fails. */ +bool isCuSparseVersionGreaterThan_12_0_1() +{ + int version = 0; + cusparseHandle_t handle = nullptr; + if (cusparseCreate(&handle) != CUSPARSE_STATUS_SUCCESS) { return false; } + const bool queried = cusparseGetVersion(handle, &version) == CUSPARSE_STATUS_SUCCESS; + cusparseDestroy(handle); + if (!queried) { return false; } + // Decode the packed integer version (thousands: major, hundreds: minor, remainder: patch). + const int major = version / 1000; + const int minor = (version % 1000) / 100; + const int patch = version % 100; + return (major > 12) || (major == 12 && (minor > 0 || patch >= 2)); +} template RAFT_KERNEL normalize_kernel( OutT* theta, const InT* in_vals, size_t max_scale, size_t r_scale, size_t c_scale) @@ -352,6 +367,9 @@ class PrefilteredBruteForceTest void SetUp() override { + if (std::is_same_v && !isCuSparseVersionGreaterThan_12_0_1()) { + GTEST_SKIP() << "Skipping all tests for half-float as cuSparse doesn't support it."; + } index_t element = raft::ceildiv(params.n_queries * params.n_dataset, index_t(sizeof(bitmap_t) * 8)); std::vector filter_h(element);