diff --git a/cpp/include/cuvs/neighbors/brute_force.hpp b/cpp/include/cuvs/neighbors/brute_force.hpp index 99581469f..fc25092c4 100644 --- a/cpp/include/cuvs/neighbors/brute_force.hpp +++ b/cpp/include/cuvs/neighbors/brute_force.hpp @@ -350,9 +350,17 @@ auto build(raft::resources const& handle, * @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset * [n_queries, k] * @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k] - * @param[in] sample_filter An optional device bitmap filter function with a `row-major` layout and - * the shape of [n_queries, index->size()], which means the filter will use the first - * `index->size()` bits to indicate whether queries[0] should compute the distance with dataset. + * @param[in] sample_filter An optional device filter that restricts which dataset elements should + * be considered for each query. + * + * - Supports two types of filters: + * 1. **Bitmap Filter**: A per-query filter with a logical shape of `[n_queries, index->size()]`, + * where each bit indicates whether a specific dataset element should be considered for a + * particular query (1 for inclusion, 0 for exclusion). + * 2. **Bitset Filter**: A shared filter where each bit corresponds to a dataset element. + * All queries share the same filter, with a logical shape of `[1, index->size()]`. + * + * - The default value is `none_sample_filter`, which applies no filtering. 
*/ void search(raft::resources const& handle, const cuvs::neighbors::brute_force::search_params& params, @@ -397,8 +405,17 @@ void search(raft::resources const& handle, * @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset * [n_queries, k] * @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k] - * @param[in] sample_filter a optional device bitmap filter function that greenlights samples for a - * given + * @param[in] sample_filter An optional device filter that restricts which dataset elements should + * be considered for each query. + * + * - Supports two types of filters: + * 1. **Bitmap Filter**: A per-query filter with a logical shape of `[n_queries, index->size()]`, + * where each bit indicates whether a specific dataset element should be considered for a + * particular query (1 for inclusion, 0 for exclusion). + * 2. **Bitset Filter**: A shared filter where each bit corresponds to a dataset element. + * All queries share the same filter, with a logical shape of `[1, index->size()]`. + * + * - The default value is `none_sample_filter`, which applies no filtering. */ void search(raft::resources const& handle, const cuvs::neighbors::brute_force::search_params& params, @@ -428,8 +445,17 @@ void search(raft::resources const& handle, * @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset * [n_queries, k] * @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k] - * @param[in] sample_filter an optional device bitmap filter function that greenlights samples for a - * given query + * @param[in] sample_filter An optional device filter that restricts which dataset elements should + * be considered for each query. + * + * - Supports two types of filters: + * 1. 
**Bitmap Filter**: A per-query filter with a logical shape of `[n_queries, index->size()]`, + * where each bit indicates whether a specific dataset element should be considered for a + * particular query (1 for inclusion, 0 for exclusion). + * 2. **Bitset Filter**: A shared filter where each bit corresponds to a dataset element. + * All queries share the same filter, with a logical shape of `[1, index->size()]`. + * + * - The default value is `none_sample_filter`, which applies no filtering. */ void search(raft::resources const& handle, const cuvs::neighbors::brute_force::search_params& params, @@ -459,8 +485,17 @@ void search(raft::resources const& handle, * @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset * [n_queries, k] * @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k] - * @param[in] sample_filter an optional device bitmap filter function that greenlights samples for a - * given query + * @param[in] sample_filter An optional device filter that restricts which dataset elements should + * be considered for each query. + * + * - Supports two types of filters: + * 1. **Bitmap Filter**: A per-query filter with a logical shape of `[n_queries, index->size()]`, + * where each bit indicates whether a specific dataset element should be considered for a + * particular query (1 for inclusion, 0 for exclusion). + * 2. **Bitset Filter**: A shared filter where each bit corresponds to a dataset element. + * All queries share the same filter, with a logical shape of `[1, index->size()]`. + * + * - The default value is `none_sample_filter`, which applies no filtering. 
*/ void search(raft::resources const& handle, const cuvs::neighbors::brute_force::search_params& params, diff --git a/cpp/include/cuvs/neighbors/common.hpp b/cpp/include/cuvs/neighbors/common.hpp index bd9ea4834..5dc99a4e8 100644 --- a/cpp/include/cuvs/neighbors/common.hpp +++ b/cpp/include/cuvs/neighbors/common.hpp @@ -18,6 +18,7 @@ #include #include +#include #include #include #include @@ -456,8 +457,11 @@ inline constexpr bool is_vpq_dataset_v = is_vpq_dataset::value; namespace filtering { +enum class FilterType { None, Bitmap, Bitset }; + struct base_filter { - virtual ~base_filter() = default; + virtual ~base_filter() = default; + virtual FilterType get_filter_type() const = 0; }; /* A filter that filters nothing. This is the default behavior. */ @@ -475,6 +479,8 @@ struct none_sample_filter : public base_filter { const uint32_t query_ix, // the index of the current sample const uint32_t sample_ix) const; + + FilterType get_filter_type() const override { return FilterType::None; } }; /** @@ -513,15 +519,24 @@ struct ivf_to_sample_filter { */ template struct bitmap_filter : public base_filter { + using view_t = cuvs::core::bitmap_view; + // View of the bitset to use as a filter - const cuvs::core::bitmap_view bitmap_view_; + const view_t bitmap_view_; - bitmap_filter(const cuvs::core::bitmap_view bitmap_for_filtering); + bitmap_filter(const view_t bitmap_for_filtering); inline _RAFT_HOST_DEVICE bool operator()( // query index const uint32_t query_ix, // the index of the current sample const uint32_t sample_ix) const; + + FilterType get_filter_type() const override { return FilterType::Bitmap; } + + view_t view() const { return bitmap_view_; } + + template + void to_csr(raft::resources const& handle, csr_matrix_t& csr); }; /** @@ -532,15 +547,24 @@ struct bitmap_filter : public base_filter { */ template struct bitset_filter : public base_filter { + using view_t = cuvs::core::bitset_view; + // View of the bitset to use as a filter - const cuvs::core::bitset_view 
bitset_view_; + const view_t bitset_view_; - bitset_filter(const cuvs::core::bitset_view bitset_for_filtering); + bitset_filter(const view_t bitset_for_filtering); inline _RAFT_HOST_DEVICE bool operator()( // query index const uint32_t query_ix, // the index of the current sample const uint32_t sample_ix) const; + + FilterType get_filter_type() const override { return FilterType::Bitset; } + + view_t view() const { return bitset_view_; } + + template + void to_csr(raft::resources const& handle, csr_matrix_t& csr); }; /** diff --git a/cpp/src/neighbors/brute_force_c.cpp b/cpp/src/neighbors/brute_force_c.cpp index 1693ac930..98c74e285 100644 --- a/cpp/src/neighbors/brute_force_c.cpp +++ b/cpp/src/neighbors/brute_force_c.cpp @@ -67,8 +67,8 @@ void _search(cuvsResources_t res, using queries_mdspan_type = raft::device_matrix_view; using neighbors_mdspan_type = raft::device_matrix_view; using distances_mdspan_type = raft::device_matrix_view; - using prefilter_mds_type = raft::device_vector_view; - using prefilter_bmp_type = cuvs::core::bitmap_view; + using prefilter_mds_type = raft::device_vector_view; + using prefilter_bmp_type = cuvs::core::bitmap_view; auto queries_mds = cuvs::core::from_dlpack(queries_tensor); auto neighbors_mds = cuvs::core::from_dlpack(neighbors_tensor); @@ -85,14 +85,14 @@ void _search(cuvsResources_t res, distances_mds, cuvs::neighbors::filtering::none_sample_filter{}); } else if (prefilter.type == BITMAP) { - auto prefilter_ptr = reinterpret_cast(prefilter.addr); - auto prefilter_mds = cuvs::core::from_dlpack(prefilter_ptr); - auto prefilter_view = cuvs::neighbors::filtering::bitmap_filter( - prefilter_bmp_type((const uint32_t*)prefilter_mds.data_handle(), + auto prefilter_ptr = reinterpret_cast(prefilter.addr); + auto prefilter_mds = cuvs::core::from_dlpack(prefilter_ptr); + const auto prefilter = cuvs::neighbors::filtering::bitmap_filter( + prefilter_bmp_type((uint32_t*)prefilter_mds.data_handle(), queries_mds.extent(0), 
index_ptr->dataset().extent(0))); cuvs::neighbors::brute_force::search( - *res_ptr, params, *index_ptr, queries_mds, neighbors_mds, distances_mds, prefilter_view); + *res_ptr, params, *index_ptr, queries_mds, neighbors_mds, distances_mds, prefilter); } else { RAFT_FAIL("Unsupported prefilter type: BITSET"); } diff --git a/cpp/src/neighbors/detail/knn_brute_force.cuh b/cpp/src/neighbors/detail/knn_brute_force.cuh index f1976e002..f8eee22f1 100644 --- a/cpp/src/neighbors/detail/knn_brute_force.cuh +++ b/cpp/src/neighbors/detail/knn_brute_force.cuh @@ -56,9 +56,13 @@ #include #include +#include #include +#include namespace cuvs::neighbors::detail { + +using namespace cuvs::neighbors::filtering; /** * Calculates brute force knn, using a fixed memory budget * by tiling over both the rows and columns of pairwise_distances @@ -82,8 +86,9 @@ void tiled_brute_force_knn(const raft::resources& handle, size_t max_col_tile_size = 0, const DistanceT* precomputed_index_norms = nullptr, const DistanceT* precomputed_search_norms = nullptr, - const uint32_t* filter_bitmap = nullptr, - DistanceEpilogue distance_epilogue = raft::identity_op()) + const uint32_t* filter_bits = nullptr, + DistanceEpilogue distance_epilogue = raft::identity_op(), + FilterType filter_type = FilterType::Bitmap) { // Figure out the number of rows/cols to tile for size_t tile_rows = 0; @@ -245,21 +250,23 @@ void tiled_brute_force_knn(const raft::resources& handle, } } - if (filter_bitmap != nullptr) { - auto distances_ptr = temp_distances.data(); - auto count = thrust::make_counting_iterator(0); - DistanceT masked_distance = select_min ? std::numeric_limits::infinity() - : std::numeric_limits::lowest(); + auto distances_ptr = temp_distances.data(); + auto count = thrust::make_counting_iterator(0); + DistanceT masked_distance = select_min ? std::numeric_limits::infinity() + : std::numeric_limits::lowest(); + + if (filter_bits != nullptr) { + size_t n_cols = filter_type == FilterType::Bitmap ? 
n : 0; thrust::for_each(raft::resource::get_thrust_policy(handle), count, count + current_query_size * current_centroid_size, [=] __device__(IndexType idx) { IndexType row = i + (idx / current_centroid_size); IndexType col = j + (idx % current_centroid_size); - IndexType g_idx = row * n + col; + IndexType g_idx = row * n_cols + col; IndexType item_idx = (g_idx) >> 5; uint32_t bit_idx = (g_idx)&31; - uint32_t filter = filter_bitmap[item_idx]; + uint32_t filter = filter_bits[item_idx]; if ((filter & (uint32_t(1) << bit_idx)) == 0) { distances_ptr[idx] = masked_distance; } @@ -575,12 +582,12 @@ void brute_force_search( query_norms ? query_norms->data_handle() : nullptr); } -template +template void brute_force_search_filtered( raft::resources const& res, const cuvs::neighbors::brute_force::index& idx, raft::device_matrix_view queries, - cuvs::core::bitmap_view filter, + const base_filter* filter, raft::device_matrix_view neighbors, raft::device_matrix_view distances, std::optional> query_norms = std::nullopt) @@ -601,29 +608,40 @@ void brute_force_search_filtered( metric == cuvs::distance::DistanceType::CosineExpanded), "Index must has norms when using Euclidean, IP, and Cosine!"); - IdxT n_queries = queries.extent(0); - IdxT n_dataset = idx.dataset().extent(0); - IdxT dim = idx.dataset().extent(1); - IdxT k = neighbors.extent(1); + IdxT n_queries = queries.extent(0); + IdxT n_dataset = idx.dataset().extent(0); + IdxT dim = idx.dataset().extent(1); + IdxT k = neighbors.extent(1); + FilterType filter_type = filter->get_filter_type(); auto stream = raft::resource::get_cuda_stream(res); - // calc nnz - IdxT nnz_h = 0; - rmm::device_scalar nnz(0, stream); - auto nnz_view = raft::make_device_scalar_view(nnz.data()); - auto filter_view = - raft::make_device_vector_view(filter.data(), filter.n_elements()); - IdxT size_h = n_queries * n_dataset; - auto size_view = raft::make_host_scalar_view(&size_h); - - raft::popc(res, filter_view, size_view, nnz_view); - raft::copy(&nnz_h, 
nnz.data(), 1, stream); + std::optional, + const cuvs::core::bitset_view>> + filter_view; + + IdxT nnz_h = 0; + float sparsity = 0.0f; + + const BitsT* filter_data = nullptr; + + if (filter_type == FilterType::Bitmap) { + auto actual_filter = dynamic_cast*>(filter); + filter_view.emplace(actual_filter->view()); + nnz_h = actual_filter->view().count(res); + sparsity = 1.0 - nnz_h / (1.0 * n_queries * n_dataset); + } else if (filter_type == FilterType::Bitset) { + auto actual_filter = dynamic_cast*>(filter); + filter_view.emplace(actual_filter->view()); + nnz_h = n_queries * actual_filter->view().count(res); + sparsity = 1.0 - nnz_h / (1.0 * n_queries * n_dataset); + } else { + RAFT_FAIL("Unsupported sample filter type"); + } - raft::resource::sync_stream(res, stream); - float sparsity = (1.0f * nnz_h / (1.0f * n_queries * n_dataset)); + std::visit([&](const auto& actual_view) { filter_data = actual_view.data(); }, *filter_view); - if (sparsity > 0.01f) { + if (sparsity < 0.9f) { raft::resources stream_pool_handle(res); raft::resource::set_cuda_stream(stream_pool_handle, stream); auto idx_norm = idx.has_norms() ? 
const_cast(idx.norms().data_handle()) : nullptr; @@ -643,12 +661,12 @@ void brute_force_search_filtered( 0, idx_norm, nullptr, - filter.data()); + filter_data, + raft::identity_op(), + filter_type); } else { auto csr = raft::make_device_csr_matrix(res, n_queries, n_dataset, nnz_h); - - // fill csr - raft::sparse::convert::bitmap_to_csr(res, filter, csr); + std::visit([&](const auto& actual_view) { actual_view.to_csr(res, csr); }, *filter_view); // create filter csr view auto compressed_csr_view = csr.structure_view(); @@ -664,7 +682,11 @@ void brute_force_search_filtered( auto csr_view = raft::make_device_csr_matrix_view( csr.get_elements().data(), compressed_csr_view); - raft::sparse::linalg::masked_matmul(res, queries, dataset_view, filter, csr_view); + std::visit( + [&](const auto& actual_view) { + raft::sparse::linalg::masked_matmul(res, queries, dataset_view, actual_view, csr_view); + }, + *filter_view); // post process std::optional> query_norms_; @@ -725,29 +747,32 @@ void search(raft::resources const& res, raft::device_matrix_view queries, raft::device_matrix_view neighbors, raft::device_matrix_view distances, - const cuvs::neighbors::filtering::base_filter& sample_filter_ref) + const base_filter& sample_filter_ref) { try { - auto& sample_filter = - dynamic_cast(sample_filter_ref); + auto& sample_filter = dynamic_cast(sample_filter_ref); return brute_force_search(res, idx, queries, neighbors, distances); } catch (const std::bad_cast&) { } + if constexpr (std::is_same_v) { + RAFT_FAIL("filtered search isn't available with col_major queries yet"); + } else { + try { + auto& sample_filter = + dynamic_cast&>(sample_filter_ref); + return brute_force_search_filtered( + res, idx, queries, &sample_filter, neighbors, distances); + } catch (const std::bad_cast&) { + } - try { - auto& sample_filter = - dynamic_cast&>( - sample_filter_ref); - if constexpr (std::is_same_v) { - RAFT_FAIL("filtered search isn't available with col_major queries yet"); - } else { - 
cuvs::core::bitmap_view sample_filter_view = - sample_filter.bitmap_view_; + try { + auto& sample_filter = + dynamic_cast&>(sample_filter_ref); return brute_force_search_filtered( - res, idx, queries, sample_filter_view, neighbors, distances); + res, idx, queries, &sample_filter, neighbors, distances); + } catch (const std::bad_cast&) { + RAFT_FAIL("Unsupported sample filter type"); } - } catch (const std::bad_cast&) { - RAFT_FAIL("Unsupported sample filter type"); } } diff --git a/cpp/src/neighbors/sample_filter.cuh b/cpp/src/neighbors/sample_filter.cuh index 258116ed3..b0c61f924 100644 --- a/cpp/src/neighbors/sample_filter.cuh +++ b/cpp/src/neighbors/sample_filter.cuh @@ -20,6 +20,7 @@ #include #include #include +#include #include #include @@ -108,6 +109,13 @@ inline _RAFT_HOST_DEVICE bool bitset_filter::operator()( return bitset_view_.test(sample_ix); } +template +template +void bitset_filter::to_csr(raft::resources const& handle, csr_matrix_t& csr) +{ + raft::sparse::convert::bitset_to_csr(handle, bitset_view_, csr); +} + template bitmap_filter::bitmap_filter( const cuvs::core::bitmap_view bitmap_for_filtering) @@ -124,4 +132,12 @@ inline _RAFT_HOST_DEVICE bool bitmap_filter::operator()( { return bitmap_view_.test(query_ix, sample_ix); } + +template +template +void bitmap_filter::to_csr(raft::resources const& handle, csr_matrix_t& csr) +{ + raft::sparse::convert::bitmap_to_csr(handle, bitmap_view_, csr); +} + } // namespace cuvs::neighbors::filtering diff --git a/cpp/test/neighbors/brute_force_prefiltered.cu b/cpp/test/neighbors/brute_force_prefiltered.cu index 12b1c529e..bf7dce7ee 100644 --- a/cpp/test/neighbors/brute_force_prefiltered.cu +++ b/cpp/test/neighbors/brute_force_prefiltered.cu @@ -28,6 +28,7 @@ #include #include +#include #include #include @@ -146,11 +147,27 @@ void set_bitmap(const index_t* src, RAFT_CUDA_TRY(cudaGetLastError()); } +bool isCuSparseVersionGreaterThan_12_0_1() +{ + int version; + cusparseHandle_t handle; + cusparseCreate(&handle); 
+ cusparseGetVersion(handle, &version); + + int major = version / 1000; + int minor = (version % 1000) / 100; + int patch = version % 100; + + cusparseDestroy(handle); + + return (major > 12) || (major == 12 && minor > 0) || (major == 12 && minor == 0 && patch >= 2); +} + template -class PrefilteredBruteForceTest +class PrefilteredBruteForceOnBitmapTest : public ::testing::TestWithParam> { public: - PrefilteredBruteForceTest() + PrefilteredBruteForceOnBitmapTest() : stream(raft::resource::get_cuda_stream(handle)), params(::testing::TestWithParam>::GetParam()), filter_d(0, stream), @@ -352,6 +369,9 @@ class PrefilteredBruteForceTest void SetUp() override { + if (std::is_same_v && !isCuSparseVersionGreaterThan_12_0_1()) { + GTEST_SKIP() << "Skipping all tests for half-float as cuSparse doesn't support it."; + } index_t element = raft::ceildiv(params.n_queries * params.n_dataset, index_t(sizeof(bitmap_t) * 8)); std::vector filter_h(element); @@ -476,8 +496,6 @@ class PrefilteredBruteForceTest out_val_expected_d.resize(params.n_queries * params.top_k, stream); out_idx_expected_d.resize(params.n_queries * params.top_k, stream); - // dump_vector(out_val_h.data(), out_val_h.size(), "out_val_h"); - raft::update_device(out_val_expected_d.data(), out_val_h.data(), out_val_h.size(), stream); raft::update_device(out_idx_expected_d.data(), out_idx_h.data(), out_idx_h.size(), stream); @@ -494,8 +512,8 @@ class PrefilteredBruteForceTest auto dataset = brute_force::build(handle, dataset_raw, params.metric); - auto filter = cuvs::core::bitmap_view( - (const bitmap_t*)filter_d.data(), params.n_queries, params.n_dataset); + auto filter = cuvs::core::bitmap_view( + (bitmap_t*)filter_d.data(), params.n_queries, params.n_dataset); auto out_val = raft::make_device_matrix_view( out_val_d.data(), params.n_queries, params.top_k); @@ -544,11 +562,451 @@ class PrefilteredBruteForceTest rmm::device_uvector out_idx_expected_d; }; -using PrefilteredBruteForceTest_float_int64 = 
PrefilteredBruteForceTest; -TEST_P(PrefilteredBruteForceTest_float_int64, Result) { Run(); } +template +class PrefilteredBruteForceOnBitsetTest + : public ::testing::TestWithParam> { + public: + PrefilteredBruteForceOnBitsetTest() + : stream(raft::resource::get_cuda_stream(handle)), + params(::testing::TestWithParam>::GetParam()), + filter_d(0, stream), + dataset_d(0, stream), + queries_d(0, stream), + out_val_d(0, stream), + out_val_expected_d(0, stream), + out_idx_d(0, stream), + out_idx_expected_d(0, stream) + { + } + + protected: + void repeat_cpu_bitset(std::vector& input, + size_t input_bits, + size_t repeat, + std::vector& output) + { + const size_t output_bits = input_bits * repeat; + const size_t output_units = (output_bits + sizeof(bitset_t) * 8 - 1) / (sizeof(bitset_t) * 8); + + std::memset(output.data(), 0, output_units * sizeof(bitset_t)); + + size_t output_bit_index = 0; + + for (size_t r = 0; r < repeat; ++r) { + for (size_t i = 0; i < input_bits; ++i) { + size_t input_unit_index = i / (sizeof(bitset_t) * 8); + size_t input_bit_offset = i % (sizeof(bitset_t) * 8); + bool bit = (input[input_unit_index] >> input_bit_offset) & 1; + + size_t output_unit_index = output_bit_index / (sizeof(bitset_t) * 8); + size_t output_bit_offset = output_bit_index % (sizeof(bitset_t) * 8); + + output[output_unit_index] |= (static_cast(bit) << output_bit_offset); + + ++output_bit_index; + } + } + } + + index_t create_sparse_matrix_with_rmat(index_t m, + index_t n, + float sparsity, + rmm::device_uvector& filter_d) + { + index_t r_scale = (index_t)std::log2(m); + index_t c_scale = (index_t)std::log2(n); + index_t n_edges = (index_t)(m * n * 1.0f * sparsity); + index_t max_scale = std::max(r_scale, c_scale); + + rmm::device_uvector out_src{(unsigned long)n_edges, stream}; + rmm::device_uvector out_dst{(unsigned long)n_edges, stream}; + rmm::device_uvector theta{(unsigned long)(4 * max_scale), stream}; + + raft::random::RngState state{2024ULL, 
raft::random::GeneratorType::GenPC}; + + raft::random::uniform(handle, state, theta.data(), theta.size(), 0.0f, 1.0f); + normalize( + theta.data(), theta.data(), max_scale, r_scale, c_scale, r_scale != c_scale, true, stream); + raft::random::rmat_rectangular_gen((index_t*)nullptr, + out_src.data(), + out_dst.data(), + theta.data(), + r_scale, + c_scale, + n_edges, + stream, + state); + + index_t nnz_h = 0; + { + auto src = out_src.data(); + auto dst = out_dst.data(); + auto bitset = filter_d.data(); + rmm::device_scalar nnz(0, stream); + auto nnz_view = raft::make_device_scalar_view(nnz.data()); + auto filter_view = + raft::make_device_vector_view(filter_d.data(), filter_d.size()); + index_t size_h = m * n; + auto size_view = raft::make_host_scalar_view(&size_h); + + set_bitmap(src, dst, bitset, n_edges, n, stream); + + raft::popc(handle, filter_view, size_view, nnz_view); + raft::copy(&nnz_h, nnz.data(), 1, stream); + + raft::resource::sync_stream(handle, stream); + } + + return nnz_h; + } + + void cpu_convert_to_csr(std::vector& bitset, + index_t rows, + index_t cols, + std::vector& indices, + std::vector& indptr) + { + index_t offset_indptr = 0; + index_t offset_values = 0; + indptr[offset_indptr++] = 0; + + index_t index = 0; + bitset_t element = 0; + index_t bit_position = 0; + + for (index_t i = 0; i < rows; ++i) { + for (index_t j = 0; j < cols; ++j) { + index = i * cols + j; + element = bitset[index / (8 * sizeof(bitset_t))]; + bit_position = index % (8 * sizeof(bitset_t)); + + if (((element >> bit_position) & 1)) { + indices[offset_values] = static_cast(j); + offset_values++; + } + } + indptr[offset_indptr++] = static_cast(offset_values); + } + } + + void cpu_sddmm(const std::vector& A, + const std::vector& B, + std::vector& vals, + const std::vector& cols, + const std::vector& row_ptrs, + bool is_row_major_A, + bool is_row_major_B, + dist_t alpha = 1.0, + dist_t beta = 0.0) + { + if (params.n_queries * params.dim != static_cast(A.size()) || + params.dim * 
params.n_dataset != static_cast(B.size())) { + std::cerr << "Matrix dimensions and vector size do not match!" << std::endl; + return; + } -using PrefilteredBruteForceTest_half_int64 = PrefilteredBruteForceTest; -TEST_P(PrefilteredBruteForceTest_half_int64, Result) { Run(); } + bool trans_a = is_row_major_A; + bool trans_b = is_row_major_B; + + for (index_t i = 0; i < params.n_queries; ++i) { + for (index_t j = row_ptrs[i]; j < row_ptrs[i + 1]; ++j) { + dist_t sum = 0; + dist_t norms_A = 0; + dist_t norms_B = 0; + + for (index_t l = 0; l < params.dim; ++l) { + index_t a_index = trans_a ? i * params.dim + l : l * params.n_queries + i; + index_t b_index = trans_b ? l * params.n_dataset + cols[j] : cols[j] * params.dim + l; + dist_t A_v; + dist_t B_v; + if constexpr (sizeof(value_t) == 2) { + A_v = __half2float(__float2half(A[a_index])); + B_v = __half2float(__float2half(B[b_index])); + } else { + A_v = A[a_index]; + B_v = B[b_index]; + } + + sum += A_v * B_v; + + norms_A += A_v * A_v; + norms_B += B_v * B_v; + } + vals[j] = alpha * sum + beta * vals[j]; + if (params.metric == cuvs::distance::DistanceType::L2Expanded) { + vals[j] = dist_t(-2.0) * vals[j] + norms_A + norms_B; + } else if (params.metric == cuvs::distance::DistanceType::L2SqrtExpanded) { + vals[j] = std::sqrt(dist_t(-2.0) * vals[j] + norms_A + norms_B); + } else if (params.metric == cuvs::distance::DistanceType::CosineExpanded) { + vals[j] = dist_t(1.0) - vals[j] / std::sqrt(norms_A * norms_B); + } + } + } + } + + void cpu_select_k(const std::vector& indptr_h, + const std::vector& indices_h, + const std::vector& values_h, + std::optional>& in_idx_h, + index_t n_queries, + index_t n_dataset, + index_t top_k, + std::vector& out_values_h, + std::vector& out_indices_h, + bool select_min = true) + { + auto comp = [select_min](const std::pair& a, + const std::pair& b) { + return select_min ? 
a.first < b.first : a.first >= b.first; + }; + + for (index_t row = 0; row < n_queries; ++row) { + std::priority_queue, + std::vector>, + decltype(comp)> + pq(comp); + for (index_t idx = indptr_h[row]; idx < indptr_h[row + 1]; ++idx) { + pq.push({values_h[idx], (in_idx_h.has_value()) ? (*in_idx_h)[idx] : indices_h[idx]}); + if (pq.size() > size_t(top_k)) { pq.pop(); } + } + + std::vector> row_pairs; + while (!pq.empty()) { + row_pairs.push_back(pq.top()); + pq.pop(); + } + + if (select_min) { + std::sort(row_pairs.begin(), row_pairs.end(), [](const auto& a, const auto& b) { + return a.first <= b.first; + }); + } else { + std::sort(row_pairs.begin(), row_pairs.end(), [](const auto& a, const auto& b) { + return a.first >= b.first; + }); + } + for (index_t col = 0; col < top_k; col++) { + if (col < index_t(row_pairs.size())) { + out_values_h[row * top_k + col] = row_pairs[col].first; + out_indices_h[row * top_k + col] = row_pairs[col].second; + } + } + } + } + + void SetUp() override + { + if (std::is_same_v && !isCuSparseVersionGreaterThan_12_0_1()) { + GTEST_SKIP() << "Skipping all tests for half-float as cuSparse doesn't support it."; + } + index_t element = raft::ceildiv(1 * params.n_dataset, index_t(sizeof(bitset_t) * 8)); + std::vector filter_h(element); + std::vector filter_repeat_h(element * params.n_queries); + + filter_d.resize(element, stream); + + nnz = create_sparse_matrix_with_rmat(1, params.n_dataset, params.sparsity, filter_d); + raft::update_host(filter_h.data(), filter_d.data(), filter_d.size(), stream); + raft::resource::sync_stream(handle, stream); + + repeat_cpu_bitset( + filter_h, size_t(params.n_dataset), size_t(params.n_queries), filter_repeat_h); + nnz *= params.n_queries; + + index_t dataset_size = params.n_dataset * params.dim; + index_t queries_size = params.n_queries * params.dim; + + std::vector dataset_h(dataset_size); + std::vector queries_h(queries_size); + + dataset_d.resize(dataset_size, stream); + queries_d.resize(queries_size, 
stream); + + auto blobs_in_val = + raft::make_device_matrix(handle, 1, dataset_size + queries_size); + auto labels = raft::make_device_vector(handle, 1); + + if constexpr (!std::is_same_v) { + raft::random::make_blobs(blobs_in_val.data_handle(), + labels.data_handle(), + 1, + dataset_size + queries_size, + 1, + stream, + false, + nullptr, + nullptr, + value_t(1.0), + false, + value_t(-1.0f), + value_t(1.0f), + uint64_t(2024)); + } else { + raft::random::make_blobs(blobs_in_val.data_handle(), + labels.data_handle(), + 1, + dataset_size + queries_size, + 1, + stream, + false, + nullptr, + nullptr, + dist_t(1.0), + false, + dist_t(-1.0f), + dist_t(1.0f), + uint64_t(2024)); + } + + raft::copy(dataset_h.data(), blobs_in_val.data_handle(), dataset_size, stream); + + if constexpr (std::is_same_v) { + thrust::device_ptr d_output_ptr = + thrust::device_pointer_cast(blobs_in_val.data_handle()); + thrust::device_ptr d_value_ptr = thrust::device_pointer_cast(dataset_d.data()); + thrust::transform(thrust::cuda::par.on(stream), + d_output_ptr, + d_output_ptr + dataset_size, + d_value_ptr, + float_to_half()); + } else { + raft::copy(dataset_d.data(), blobs_in_val.data_handle(), dataset_size, stream); + } + + raft::copy(queries_h.data(), blobs_in_val.data_handle() + dataset_size, queries_size, stream); + if constexpr (std::is_same_v) { + thrust::device_ptr d_output_ptr = + thrust::device_pointer_cast(blobs_in_val.data_handle() + dataset_size); + thrust::device_ptr d_value_ptr = thrust::device_pointer_cast(queries_d.data()); + thrust::transform(thrust::cuda::par.on(stream), + d_output_ptr, + d_output_ptr + queries_size, + d_value_ptr, + float_to_half()); + } else { + raft::copy(queries_d.data(), blobs_in_val.data_handle() + dataset_size, queries_size, stream); + } + + raft::resource::sync_stream(handle); + + std::vector values_h(nnz); + std::vector indices_h(nnz); + std::vector indptr_h(params.n_queries + 1); + + cpu_convert_to_csr(filter_repeat_h, params.n_queries, 
params.n_dataset, indices_h, indptr_h); + + cpu_sddmm(queries_h, dataset_h, values_h, indices_h, indptr_h, true, false); + + bool select_min = cuvs::distance::is_min_close(params.metric); + + std::vector out_val_h( + params.n_queries * params.top_k, + select_min ? std::numeric_limits::infinity() : std::numeric_limits::lowest()); + std::vector out_idx_h(params.n_queries * params.top_k, static_cast(0)); + + out_val_d.resize(params.n_queries * params.top_k, stream); + out_idx_d.resize(params.n_queries * params.top_k, stream); + + raft::update_device(out_val_d.data(), out_val_h.data(), out_val_h.size(), stream); + raft::update_device(out_idx_d.data(), out_idx_h.data(), out_idx_h.size(), stream); + + raft::resource::sync_stream(handle); + + std::optional> optional_indices_h = std::nullopt; + cpu_select_k(indptr_h, + indices_h, + values_h, + optional_indices_h, + params.n_queries, + params.n_dataset, + params.top_k, + out_val_h, + out_idx_h, + select_min); + out_val_expected_d.resize(params.n_queries * params.top_k, stream); + out_idx_expected_d.resize(params.n_queries * params.top_k, stream); + + raft::update_device(out_val_expected_d.data(), out_val_h.data(), out_val_h.size(), stream); + raft::update_device(out_idx_expected_d.data(), out_idx_h.data(), out_idx_h.size(), stream); + + raft::resource::sync_stream(handle); + } + + void Run() + { + auto dataset_raw = raft::make_device_matrix_view( + (const value_t*)dataset_d.data(), params.n_dataset, params.dim); + + auto queries = raft::make_device_matrix_view( + (const value_t*)queries_d.data(), params.n_queries, params.dim); + + auto dataset = brute_force::build(handle, dataset_raw, params.metric); + + auto filter = + cuvs::core::bitset_view((bitset_t*)filter_d.data(), params.n_dataset); + + auto out_val = raft::make_device_matrix_view( + out_val_d.data(), params.n_queries, params.top_k); + auto out_idx = raft::make_device_matrix_view( + out_idx_d.data(), params.n_queries, params.top_k); + + brute_force::search(handle, + 
dataset, + queries, + out_idx, + out_val, + cuvs::neighbors::filtering::bitset_filter(filter)); + std::vector out_val_h(params.n_queries * params.top_k, + std::numeric_limits::infinity()); + + raft::update_host(out_val_h.data(), out_val_d.data(), out_val_h.size(), stream); + raft::resource::sync_stream(handle); + + ASSERT_TRUE(cuvs::neighbors::devArrMatchKnnPair(out_idx_expected_d.data(), + out_idx.data_handle(), + out_val_expected_d.data(), + out_val.data_handle(), + params.n_queries, + params.top_k, + 0.001f, + stream, + true)); + } + + protected: + raft::resources handle; + cudaStream_t stream; + + PrefilteredBruteForceInputs params; + + index_t nnz; + + rmm::device_uvector dataset_d; + rmm::device_uvector queries_d; + rmm::device_uvector filter_d; + + rmm::device_uvector out_val_d; + rmm::device_uvector out_val_expected_d; + + rmm::device_uvector out_idx_d; + rmm::device_uvector out_idx_expected_d; +}; + +using PrefilteredBruteForceTestOnBitmap_float_int64 = + PrefilteredBruteForceOnBitmapTest; +TEST_P(PrefilteredBruteForceTestOnBitmap_float_int64, Result) { Run(); } + +using PrefilteredBruteForceTestOnBitmap_half_int64 = + PrefilteredBruteForceOnBitmapTest; +TEST_P(PrefilteredBruteForceTestOnBitmap_half_int64, Result) { Run(); } + +using PrefilteredBruteForceTestOnBitset_float_int64 = + PrefilteredBruteForceOnBitsetTest; +TEST_P(PrefilteredBruteForceTestOnBitset_float_int64, Result) { Run(); } + +using PrefilteredBruteForceTestOnBitset_half_int64 = + PrefilteredBruteForceOnBitsetTest; +TEST_P(PrefilteredBruteForceTestOnBitset_half_int64, Result) { Run(); } template const std::vector> selectk_inputs = { @@ -570,7 +1028,7 @@ const std::vector> selectk_inputs = { {1024, 8192, 5, 0, 0.1, cuvs::distance::DistanceType::L2SqrtExpanded}, {1024, 8192, 8, 0, 0.1, cuvs::distance::DistanceType::CosineExpanded}, - {1024, 8192, 1, 1, 0.1, cuvs::distance::DistanceType::L2Expanded}, //-- + {1024, 8192, 1, 1, 0.1, cuvs::distance::DistanceType::L2Expanded}, {1024, 8192, 3, 1, 
0.1, cuvs::distance::DistanceType::InnerProduct}, {1024, 8192, 5, 1, 0.1, cuvs::distance::DistanceType::L2SqrtExpanded}, {1024, 8192, 8, 1, 0.1, cuvs::distance::DistanceType::CosineExpanded}, @@ -599,12 +1057,20 @@ const std::vector> selectk_inputs = { {1024, 8192, 5, 16, 0.5, cuvs::distance::DistanceType::CosineExpanded}, {1024, 8192, 8, 16, 0.2, cuvs::distance::DistanceType::CosineExpanded}}; -INSTANTIATE_TEST_CASE_P(PrefilteredBruteForceTest, - PrefilteredBruteForceTest_float_int64, +INSTANTIATE_TEST_CASE_P(PrefilteredBruteForceOnBitmapTest, + PrefilteredBruteForceTestOnBitmap_float_int64, + ::testing::ValuesIn(selectk_inputs)); + +INSTANTIATE_TEST_CASE_P(PrefilteredBruteForceOnBitmapTest, + PrefilteredBruteForceTestOnBitmap_half_int64, + ::testing::ValuesIn(selectk_inputs)); + +INSTANTIATE_TEST_CASE_P(PrefilteredBruteForceOnBitsetTest, + PrefilteredBruteForceTestOnBitset_float_int64, ::testing::ValuesIn(selectk_inputs)); -INSTANTIATE_TEST_CASE_P(PrefilteredBruteForceTest, - PrefilteredBruteForceTest_half_int64, +INSTANTIATE_TEST_CASE_P(PrefilteredBruteForceOnBitsetTest, + PrefilteredBruteForceTestOnBitset_half_int64, ::testing::ValuesIn(selectk_inputs)); } // namespace cuvs::neighbors::brute_force diff --git a/python/cuvs/cuvs/test/test_brute_force.py b/python/cuvs/cuvs/test/test_brute_force.py index 0b37ad885..a234794f9 100644 --- a/python/cuvs/cuvs/test/test_brute_force.py +++ b/python/cuvs/cuvs/test/test_brute_force.py @@ -134,7 +134,7 @@ def test_prefiltered_brute_force_knn( index = np.random.random_sample((n_index_rows, n_cols)).astype(dtype) queries = np.random.random_sample((n_query_rows, n_cols)).astype(dtype) bitmap = create_sparse_array( - (np.ceil(n_query_rows * n_index_rows / 32).astype(int)), sparsity + (np.ceil(n_query_rows * n_index_rows / 32).astype(np.uint32)), sparsity ) is_min = metric != "inner_product"