Skip to content

Commit

Permalink
fix the bits type to uint32_t
Browse files Browse the repository at this point in the history
  • Loading branch information
rhdong committed Jan 8, 2025
1 parent 4e30bd2 commit cbc5d38
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
6 changes: 3 additions & 3 deletions cpp/src/neighbors/brute_force_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ void _search(cuvsResources_t res,
using queries_mdspan_type = raft::device_matrix_view<T const, int64_t, raft::row_major>;
using neighbors_mdspan_type = raft::device_matrix_view<int64_t, int64_t, raft::row_major>;
using distances_mdspan_type = raft::device_matrix_view<float, int64_t, raft::row_major>;
using prefilter_mds_type = raft::device_vector_view<const uint32_t, int64_t>;
using prefilter_bmp_type = cuvs::core::bitmap_view<const uint32_t, int64_t>;
using prefilter_mds_type = raft::device_vector_view<uint32_t, int64_t>;
using prefilter_bmp_type = cuvs::core::bitmap_view<uint32_t, int64_t>;

auto queries_mds = cuvs::core::from_dlpack<queries_mdspan_type>(queries_tensor);
auto neighbors_mds = cuvs::core::from_dlpack<neighbors_mdspan_type>(neighbors_tensor);
Expand All @@ -88,7 +88,7 @@ void _search(cuvsResources_t res,
auto prefilter_ptr = reinterpret_cast<DLManagedTensor*>(prefilter.addr);
auto prefilter_mds = cuvs::core::from_dlpack<prefilter_mds_type>(prefilter_ptr);
auto prefilter_view = cuvs::neighbors::filtering::bitmap_filter(
prefilter_bmp_type((const uint32_t*)prefilter_mds.data_handle(),
prefilter_bmp_type((uint32_t*)prefilter_mds.data_handle(),
queries_mds.extent(0),
index_ptr->dataset().extent(0)));
cuvs::neighbors::brute_force::search(
Expand Down
2 changes: 1 addition & 1 deletion python/cuvs/cuvs/test/test_brute_force.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def test_prefiltered_brute_force_knn(
index = np.random.random_sample((n_index_rows, n_cols)).astype(dtype)
queries = np.random.random_sample((n_query_rows, n_cols)).astype(dtype)
bitmap = create_sparse_array(
(np.ceil(n_query_rows * n_index_rows / 32).astype(int)), sparsity
(np.ceil(n_query_rows * n_index_rows / 32).astype(np.uint32)), sparsity
)

is_min = metric != "inner_product"
Expand Down

0 comments on commit cbc5d38

Please sign in to comment.