-
Notifications
You must be signed in to change notification settings - Fork 73
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add filtering for CAGRA to C API #452
base: branch-25.02
Are you sure you want to change the base?
Changes from 14 commits
59bb8ca
f37c58e
4fd78ca
e36d556
fad44c2
0d0184a
e3d33a0
d0dc961
3b17895
4e584cc
4e2202e
b0e8122
afbcf74
f6b8068
d7725db
9a64945
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -28,6 +28,7 @@ | |
#include <cuvs/core/interop.hpp> | ||
#include <cuvs/neighbors/cagra.h> | ||
#include <cuvs/neighbors/cagra.hpp> | ||
#include <cuvs/neighbors/common.h> | ||
|
||
#include <fstream> | ||
|
||
|
@@ -91,7 +92,8 @@ void _search(cuvsResources_t res, | |
cuvsCagraIndex index, | ||
DLManagedTensor* queries_tensor, | ||
DLManagedTensor* neighbors_tensor, | ||
DLManagedTensor* distances_tensor) | ||
DLManagedTensor* distances_tensor, | ||
cuvsFilter filter) | ||
{ | ||
auto res_ptr = reinterpret_cast<raft::resources*>(res); | ||
auto index_ptr = reinterpret_cast<cuvs::neighbors::cagra::index<T, uint32_t>*>(index.addr); | ||
|
@@ -117,8 +119,26 @@ void _search(cuvsResources_t res, | |
auto queries_mds = cuvs::core::from_dlpack<queries_mdspan_type>(queries_tensor); | ||
auto neighbors_mds = cuvs::core::from_dlpack<neighbors_mdspan_type>(neighbors_tensor); | ||
auto distances_mds = cuvs::core::from_dlpack<distances_mdspan_type>(distances_tensor); | ||
cuvs::neighbors::cagra::search( | ||
*res_ptr, search_params, *index_ptr, queries_mds, neighbors_mds, distances_mds); | ||
if (filter.type == NO_FILTER) { | ||
cuvs::neighbors::cagra::search( | ||
*res_ptr, search_params, *index_ptr, queries_mds, neighbors_mds, distances_mds); | ||
} else if (filter.type == BITSET) { | ||
using filter_mdspan_type = raft::device_vector_view<std::uint32_t, int64_t, raft::row_major>; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think the build fails because of the
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmm actually I missed this. The function you are using here is creating a bitset from a list of indices and I don't think it is the workflow that we expect. |
||
auto removed_indices_tensor = reinterpret_cast<DLManagedTensor*>(filter.addr); | ||
auto removed_indices = cuvs::core::from_dlpack<filter_mdspan_type>(removed_indices_tensor); | ||
cuvs::core::bitset_view<std::uint32_t, int64_t> removed_indices_bitset( | ||
removed_indices, index_ptr->dataset().extent(0)); | ||
auto bitset_filter_obj = cuvs::neighbors::filtering::bitset_filter(removed_indices_bitset); | ||
cuvs::neighbors::cagra::search(*res_ptr, | ||
search_params, | ||
*index_ptr, | ||
queries_mds, | ||
neighbors_mds, | ||
distances_mds, | ||
bitset_filter_obj); | ||
} else { | ||
RAFT_FAIL("Unsupported filter type: BITMAP"); | ||
} | ||
} | ||
|
||
template <typename T> | ||
|
@@ -213,7 +233,8 @@ extern "C" cuvsError_t cuvsCagraSearch(cuvsResources_t res, | |
cuvsCagraIndex_t index_c_ptr, | ||
DLManagedTensor* queries_tensor, | ||
DLManagedTensor* neighbors_tensor, | ||
DLManagedTensor* distances_tensor) | ||
DLManagedTensor* distances_tensor, | ||
cuvsFilter filter) | ||
{ | ||
return cuvs::core::translate_exceptions([=] { | ||
auto queries = queries_tensor->dl_tensor; | ||
|
@@ -236,11 +257,14 @@ extern "C" cuvsError_t cuvsCagraSearch(cuvsResources_t res, | |
RAFT_EXPECTS(queries.dtype.code == index.dtype.code, "type mismatch between index and queries"); | ||
|
||
if (queries.dtype.code == kDLFloat && queries.dtype.bits == 32) { | ||
_search<float>(res, *params, index, queries_tensor, neighbors_tensor, distances_tensor); | ||
_search<float>( | ||
res, *params, index, queries_tensor, neighbors_tensor, distances_tensor, filter); | ||
} else if (queries.dtype.code == kDLInt && queries.dtype.bits == 8) { | ||
_search<int8_t>(res, *params, index, queries_tensor, neighbors_tensor, distances_tensor); | ||
_search<int8_t>( | ||
res, *params, index, queries_tensor, neighbors_tensor, distances_tensor, filter); | ||
} else if (queries.dtype.code == kDLUInt && queries.dtype.bits == 8) { | ||
_search<uint8_t>(res, *params, index, queries_tensor, neighbors_tensor, distances_tensor); | ||
_search<uint8_t>( | ||
res, *params, index, queries_tensor, neighbors_tensor, distances_tensor, filter); | ||
} else { | ||
RAFT_FAIL("Unsupported queries DLtensor dtype: %d and bits: %d", | ||
queries.dtype.code, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -47,6 +47,8 @@ from libc.stdint cimport ( | |
|
||
from cuvs.common.exceptions import check_cuvs | ||
|
||
from cuvs.neighbors.filters import no_filter | ||
|
||
|
||
cdef class CompressionParams: | ||
""" | ||
|
@@ -480,7 +482,8 @@ def search(SearchParams search_params, | |
k, | ||
neighbors=None, | ||
distances=None, | ||
resources=None): | ||
resources=None, | ||
filter=None): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add this parameter to the python documentation |
||
""" | ||
Find the k nearest neighbors for each query. | ||
|
||
|
@@ -553,6 +556,9 @@ def search(SearchParams search_params, | |
_check_input_array(distances_cai, [np.dtype('float32')], | ||
exp_rows=n_queries, exp_cols=k) | ||
|
||
if filter is None: | ||
filter = no_filter() | ||
|
||
cdef cuvsCagraSearchParams* params = &search_params.params | ||
cdef cydlpack.DLManagedTensor* queries_dlpack = \ | ||
cydlpack.dlpack_c(queries_cai) | ||
|
@@ -569,7 +575,8 @@ def search(SearchParams search_params, | |
index.index, | ||
queries_dlpack, | ||
neighbors_dlpack, | ||
distances_dlpack | ||
distances_dlpack, | ||
filter.prefilter | ||
)) | ||
|
||
return (distances, neighbors) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -95,3 +95,52 @@ def from_bitmap(bitmap): | |
filter.addr = <uintptr_t> bitmap_dlpack | ||
|
||
return Prefilter(filter, parent=bitmap) | ||
|
||
def from_bitset(bitset): | ||
""" | ||
Create a pre-filter from an array with type of uint32. | ||
|
||
Parameters | ||
---------- | ||
bitmap : numpy.ndarray | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Update this docstring for bitset instead of bitmap |
||
An array with type of `uint32` where each bit in the array corresponds | ||
to if a sample and query pair is greenlit (not filtered) or filtered. | ||
The array is row-major, meaning the bits are ordered by rows first. | ||
Each bit in a `uint32` element represents a different sample-query | ||
pair. | ||
|
||
- Bit value of 1: The sample-query pair is greenlit (allowed). | ||
- Bit value of 0: The sample-query pair is filtered. | ||
|
||
Returns | ||
------- | ||
filter : cuvs.neighbors.filters.Prefilter | ||
An instance of `Prefilter` that can be used to filter neighbors | ||
based on the given bitmap. | ||
{resources_docstring} | ||
|
||
Examples | ||
-------- | ||
|
||
>>> import cupy as cp | ||
>>> import numpy as np | ||
>>> from cuvs.neighbors import filters | ||
>>> | ||
>>> n_samples = 50000 | ||
>>> n_queries = 1000 | ||
>>> | ||
>>> n_bitmap = np.ceil(n_samples * n_queries / 32).astype(int) | ||
>>> bitmap = cp.random.randint(1, 100, size=(n_bitmap,), dtype=cp.uint32) | ||
>>> prefilter = filters.from_bitmap(bitmap) | ||
""" | ||
bitset_cai = wrap_array(bitset) | ||
_check_input_array(bitset_cai, [np.dtype('uint32')]) | ||
|
||
cdef cydlpack.DLManagedTensor* bitset_dlpack = \ | ||
cydlpack.dlpack_c(bitset_cai) | ||
|
||
cdef cuvsFilter filter | ||
filter.type = BITSET | ||
filter.addr = <uintptr_t> bitset_dlpack | ||
|
||
return Prefilter(filter, parent=bitset) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This API change needs to be propagated to:
cuvs/example/c
)There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done👍