-
Notifications
You must be signed in to change notification settings - Fork 73
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add filtering for CAGRA to C API #452
base: branch-25.02
Are you sure you want to change the base?
Changes from all commits
59bb8ca
f37c58e
4fd78ca
e36d556
fad44c2
0d0184a
e3d33a0
d0dc961
3b17895
4e584cc
4e2202e
b0e8122
afbcf74
f6b8068
d7725db
9a64945
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -28,6 +28,7 @@ | |
#include <cuvs/core/interop.hpp> | ||
#include <cuvs/neighbors/cagra.h> | ||
#include <cuvs/neighbors/cagra.hpp> | ||
#include <cuvs/neighbors/common.h> | ||
|
||
#include <fstream> | ||
|
||
|
@@ -92,7 +93,8 @@ void _search(cuvsResources_t res, | |
cuvsCagraIndex index, | ||
DLManagedTensor* queries_tensor, | ||
DLManagedTensor* neighbors_tensor, | ||
DLManagedTensor* distances_tensor) | ||
DLManagedTensor* distances_tensor, | ||
cuvsFilter filter) | ||
{ | ||
auto res_ptr = reinterpret_cast<raft::resources*>(res); | ||
auto index_ptr = reinterpret_cast<cuvs::neighbors::cagra::index<T, uint32_t>*>(index.addr); | ||
|
@@ -118,8 +120,26 @@ void _search(cuvsResources_t res, | |
auto queries_mds = cuvs::core::from_dlpack<queries_mdspan_type>(queries_tensor); | ||
auto neighbors_mds = cuvs::core::from_dlpack<neighbors_mdspan_type>(neighbors_tensor); | ||
auto distances_mds = cuvs::core::from_dlpack<distances_mdspan_type>(distances_tensor); | ||
cuvs::neighbors::cagra::search( | ||
*res_ptr, search_params, *index_ptr, queries_mds, neighbors_mds, distances_mds); | ||
if (filter.type == NO_FILTER) { | ||
cuvs::neighbors::cagra::search( | ||
*res_ptr, search_params, *index_ptr, queries_mds, neighbors_mds, distances_mds); | ||
} else if (filter.type == BITSET) { | ||
using filter_mdspan_type = raft::device_vector_view<std::uint32_t, int64_t, raft::row_major>; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think the build fails because of the
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmm actually I missed this. The function you are using here is creating a bitset from a list of indices and I don't think it is the workflow that we expect. |
||
auto removed_indices_tensor = reinterpret_cast<DLManagedTensor*>(filter.addr); | ||
auto removed_indices = cuvs::core::from_dlpack<filter_mdspan_type>(removed_indices_tensor); | ||
cuvs::core::bitset_view<std::uint32_t, int64_t> removed_indices_bitset( | ||
removed_indices, index_ptr->dataset().extent(0)); | ||
auto bitset_filter_obj = cuvs::neighbors::filtering::bitset_filter(removed_indices_bitset); | ||
cuvs::neighbors::cagra::search(*res_ptr, | ||
search_params, | ||
*index_ptr, | ||
queries_mds, | ||
neighbors_mds, | ||
distances_mds, | ||
bitset_filter_obj); | ||
} else { | ||
RAFT_FAIL("Unsupported filter type: BITMAP"); | ||
} | ||
} | ||
|
||
template <typename T> | ||
|
@@ -214,7 +234,8 @@ extern "C" cuvsError_t cuvsCagraSearch(cuvsResources_t res, | |
cuvsCagraIndex_t index_c_ptr, | ||
DLManagedTensor* queries_tensor, | ||
DLManagedTensor* neighbors_tensor, | ||
DLManagedTensor* distances_tensor) | ||
DLManagedTensor* distances_tensor, | ||
cuvsFilter filter) | ||
{ | ||
return cuvs::core::translate_exceptions([=] { | ||
auto queries = queries_tensor->dl_tensor; | ||
|
@@ -237,11 +258,14 @@ extern "C" cuvsError_t cuvsCagraSearch(cuvsResources_t res, | |
RAFT_EXPECTS(queries.dtype.code == index.dtype.code, "type mismatch between index and queries"); | ||
|
||
if (queries.dtype.code == kDLFloat && queries.dtype.bits == 32) { | ||
_search<float>(res, *params, index, queries_tensor, neighbors_tensor, distances_tensor); | ||
_search<float>( | ||
res, *params, index, queries_tensor, neighbors_tensor, distances_tensor, filter); | ||
} else if (queries.dtype.code == kDLInt && queries.dtype.bits == 8) { | ||
_search<int8_t>(res, *params, index, queries_tensor, neighbors_tensor, distances_tensor); | ||
_search<int8_t>( | ||
res, *params, index, queries_tensor, neighbors_tensor, distances_tensor, filter); | ||
} else if (queries.dtype.code == kDLUInt && queries.dtype.bits == 8) { | ||
_search<uint8_t>(res, *params, index, queries_tensor, neighbors_tensor, distances_tensor); | ||
_search<uint8_t>( | ||
res, *params, index, queries_tensor, neighbors_tensor, distances_tensor, filter); | ||
} else { | ||
RAFT_FAIL("Unsupported queries DLtensor dtype: %d and bits: %d", | ||
queries.dtype.code, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -48,6 +48,7 @@ from libc.stdint cimport ( | |
) | ||
|
||
from cuvs.common.exceptions import check_cuvs | ||
from cuvs.neighbors.filters import no_filter | ||
|
||
|
||
cdef class CompressionParams: | ||
|
@@ -484,7 +485,8 @@ def search(SearchParams search_params, | |
k, | ||
neighbors=None, | ||
distances=None, | ||
resources=None): | ||
resources=None, | ||
filter=None): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add this parameter to the python documentation |
||
""" | ||
Find the k nearest neighbors for each query. | ||
|
||
|
@@ -503,6 +505,9 @@ def search(SearchParams search_params, | |
distances : Optional CUDA array interface compliant matrix shape | ||
(n_queries, k) If supplied, the distances to the | ||
neighbors will be written here in-place. (default None) | ||
filter: Optional cuvs.neighbors.cuvsFilter can be used to filter | ||
neighbors based on a given bitset. | ||
(default None) | ||
{resources_docstring} | ||
|
||
Examples | ||
|
@@ -557,6 +562,9 @@ def search(SearchParams search_params, | |
_check_input_array(distances_cai, [np.dtype('float32')], | ||
exp_rows=n_queries, exp_cols=k) | ||
|
||
if filter is None: | ||
filter = no_filter() | ||
|
||
cdef cuvsCagraSearchParams* params = &search_params.params | ||
cdef cydlpack.DLManagedTensor* queries_dlpack = \ | ||
cydlpack.dlpack_c(queries_cai) | ||
|
@@ -573,7 +581,8 @@ def search(SearchParams search_params, | |
index.index, | ||
queries_dlpack, | ||
neighbors_dlpack, | ||
distances_dlpack | ||
distances_dlpack, | ||
filter.prefilter | ||
)) | ||
|
||
return (distances, neighbors) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This API change needs to be propagated to:
cuvs/example/c
)There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done👍