-
Notifications
You must be signed in to change notification settings - Fork 75
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add filtering for CAGRA to C API #452
Changes from 8 commits
59bb8ca
f37c58e
4fd78ca
e36d556
fad44c2
0d0184a
e3d33a0
d0dc961
3b17895
4e584cc
4e2202e
b0e8122
afbcf74
f6b8068
d7725db
9a64945
8ac5b7f
9d8ee2e
3809c44
4023d89
7e2b803
a307c57
3ad5dc6
7159689
8331656
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -28,6 +28,7 @@ | |
#include <cuvs/core/interop.hpp> | ||
#include <cuvs/neighbors/cagra.h> | ||
#include <cuvs/neighbors/cagra.hpp> | ||
#include <cuvs/neighbors/common.h> | ||
|
||
#include <fstream> | ||
|
||
|
@@ -91,7 +92,8 @@ void _search(cuvsResources_t res, | |
cuvsCagraIndex index, | ||
DLManagedTensor* queries_tensor, | ||
DLManagedTensor* neighbors_tensor, | ||
DLManagedTensor* distances_tensor) | ||
DLManagedTensor* distances_tensor, | ||
cuvsFilter filter) | ||
{ | ||
auto res_ptr = reinterpret_cast<raft::resources*>(res); | ||
auto index_ptr = reinterpret_cast<cuvs::neighbors::cagra::index<T, uint32_t>*>(index.addr); | ||
|
@@ -117,8 +119,27 @@ void _search(cuvsResources_t res, | |
auto queries_mds = cuvs::core::from_dlpack<queries_mdspan_type>(queries_tensor); | ||
auto neighbors_mds = cuvs::core::from_dlpack<neighbors_mdspan_type>(neighbors_tensor); | ||
auto distances_mds = cuvs::core::from_dlpack<distances_mdspan_type>(distances_tensor); | ||
cuvs::neighbors::cagra::search( | ||
*res_ptr, search_params, *index_ptr, queries_mds, neighbors_mds, distances_mds); | ||
if (filter.type == NO_FILTER) { | ||
cuvs::neighbors::cagra::search( | ||
*res_ptr, search_params, *index_ptr, queries_mds, neighbors_mds, distances_mds); | ||
} else if (filter.type == BITSET) { | ||
using filter_mdspan_type = raft::device_vector_view<std::uint32_t, int64_t, raft::row_major>; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think the build fails because of the
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmm actually I missed this. The function you are using here is creating a bitset from a list of indices and I don't think it is the workflow that we expect. |
||
auto removed_indices_tensor = reinterpret_cast<DLManagedTensor*>(filter.addr); | ||
auto removed_indices = cuvs::core::from_dlpack<filter_mdspan_type>(removed_indices_tensor); | ||
cuvs::core::bitset<std::uint32_t, int64_t> removed_indices_bitset( | ||
*res_ptr, removed_indices, index_ptr->dataset().extent(0)); | ||
auto bitset_filter_obj = | ||
cuvs::neighbors::filtering::bitset_filter(removed_indices_bitset.view()); | ||
cuvs::neighbors::cagra::search(*res_ptr, | ||
search_params, | ||
*index_ptr, | ||
queries_mds, | ||
neighbors_mds, | ||
distances_mds, | ||
bitset_filter_obj); | ||
} else { | ||
RAFT_FAIL("Unsupported prefilter type: BITMAP"); | ||
} | ||
} | ||
|
||
template <typename T> | ||
|
@@ -213,7 +234,8 @@ extern "C" cuvsError_t cuvsCagraSearch(cuvsResources_t res, | |
cuvsCagraIndex_t index_c_ptr, | ||
DLManagedTensor* queries_tensor, | ||
DLManagedTensor* neighbors_tensor, | ||
DLManagedTensor* distances_tensor) | ||
DLManagedTensor* distances_tensor, | ||
cuvsFilter filter) | ||
{ | ||
return cuvs::core::translate_exceptions([=] { | ||
auto queries = queries_tensor->dl_tensor; | ||
|
@@ -236,11 +258,14 @@ extern "C" cuvsError_t cuvsCagraSearch(cuvsResources_t res, | |
RAFT_EXPECTS(queries.dtype.code == index.dtype.code, "type mismatch between index and queries"); | ||
|
||
if (queries.dtype.code == kDLFloat && queries.dtype.bits == 32) { | ||
_search<float>(res, *params, index, queries_tensor, neighbors_tensor, distances_tensor); | ||
_search<float>( | ||
res, *params, index, queries_tensor, neighbors_tensor, distances_tensor, filter); | ||
} else if (queries.dtype.code == kDLInt && queries.dtype.bits == 8) { | ||
_search<int8_t>(res, *params, index, queries_tensor, neighbors_tensor, distances_tensor); | ||
_search<int8_t>( | ||
res, *params, index, queries_tensor, neighbors_tensor, distances_tensor, filter); | ||
} else if (queries.dtype.code == kDLUInt && queries.dtype.bits == 8) { | ||
_search<uint8_t>(res, *params, index, queries_tensor, neighbors_tensor, distances_tensor); | ||
_search<uint8_t>( | ||
res, *params, index, queries_tensor, neighbors_tensor, distances_tensor, filter); | ||
} else { | ||
RAFT_FAIL("Unsupported queries DLtensor dtype: %d and bits: %d", | ||
queries.dtype.code, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This API change needs to be propagated to:
cuvs/example/c
)There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done👍