Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add filtering for CAGRA to C API #452

Merged
merged 25 commits into from
Jan 25, 2025
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions cpp/include/cuvs/neighbors/cagra.h
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,22 @@ cuvsError_t cuvsCagraSearch(cuvsResources_t res,
DLManagedTensor* queries,
DLManagedTensor* neighbors,
DLManagedTensor* distances);
/**
 * @brief Search a CAGRA index for the nearest neighbors of `queries`, applying a filter.
 *
 * Parameters are listed in the same order as the function signature.
 *
 * @param[in] res cuvsResources_t opaque C handle
 * @param[in] params cuvsCagraSearchParams_t used to search CAGRA index
 * @param[in] index cuvsCagraIndex which has been returned by `cuvsCagraBuild`
 * @param[in] queries DLManagedTensor* queries dataset to search
 * @param[out] neighbors DLManagedTensor* output `k` neighbors for queries
 * @param[out] distances DLManagedTensor* output `k` distances for queries
 * @param[in] filter DLManagedTensor* filter; the implementation requires int64_t elements in
 *            device-compatible memory and treats them as indices removed from the search
 *            (NOTE(review): confirm this is the intended public contract)
 * @return cuvsError_t
 */
cuvsError_t cuvsCagraFilteredSearch(cuvsResources_t res,
cuvsCagraSearchParams_t params,
cuvsCagraIndex_t index,
DLManagedTensor* queries,
DLManagedTensor* neighbors,
DLManagedTensor* distances,
DLManagedTensor* filter);
ajit283 marked this conversation as resolved.
Show resolved Hide resolved

/**
* @}
Expand Down
75 changes: 72 additions & 3 deletions cpp/src/neighbors/cagra_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,8 @@ void _search(cuvsResources_t res,
cuvsCagraIndex index,
DLManagedTensor* queries_tensor,
DLManagedTensor* neighbors_tensor,
DLManagedTensor* distances_tensor)
DLManagedTensor* distances_tensor,
std::optional<DLManagedTensor*> removed_indices_tensor = std::nullopt)
{
auto res_ptr = reinterpret_cast<raft::resources*>(res);
auto index_ptr = reinterpret_cast<cuvs::neighbors::cagra::index<T, uint32_t>*>(index.addr);
Expand All @@ -117,8 +118,26 @@ void _search(cuvsResources_t res,
auto queries_mds = cuvs::core::from_dlpack<queries_mdspan_type>(queries_tensor);
auto neighbors_mds = cuvs::core::from_dlpack<neighbors_mdspan_type>(neighbors_tensor);
auto distances_mds = cuvs::core::from_dlpack<distances_mdspan_type>(distances_tensor);
cuvs::neighbors::cagra::search(
*res_ptr, search_params, *index_ptr, queries_mds, neighbors_mds, distances_mds);

if (removed_indices_tensor.has_value()) {
using filter_mdspan_type = raft::device_vector_view<int64_t, int64_t, raft::row_major>;
auto removed_indices =
cuvs::core::from_dlpack<filter_mdspan_type>(removed_indices_tensor.value());
cuvs::core::bitset<std::uint32_t, int64_t> removed_indices_bitset(
*res_ptr, removed_indices, index_ptr->dataset().extent(0));
auto bitset_filter_obj =
cuvs::neighbors::filtering::bitset_filter(removed_indices_bitset.view());
cuvs::neighbors::cagra::search(*res_ptr,
search_params,
*index_ptr,
queries_mds,
neighbors_mds,
distances_mds,
bitset_filter_obj);
} else {
cuvs::neighbors::cagra::search(
*res_ptr, search_params, *index_ptr, queries_mds, neighbors_mds, distances_mds);
}
}

template <typename T>
Expand Down Expand Up @@ -249,6 +268,56 @@ extern "C" cuvsError_t cuvsCagraSearch(cuvsResources_t res,
});
}

/**
 * C entry point for a filtered CAGRA search.
 *
 * Validates the memory placement and dtypes of the DLPack tensors, checks that the
 * queries dtype code matches the index, and dispatches to the `_search<T>` instantiation
 * selected by the queries dtype (float32, int8 or uint8), forwarding `filter_tensor`.
 * Any exception raised inside the lambda is converted to a `cuvsError_t` by
 * `cuvs::core::translate_exceptions`.
 *
 * @param[in]  res              cuvsResources_t opaque C handle
 * @param[in]  params           search parameters
 * @param[in]  index_c_ptr      CAGRA index to search
 * @param[in]  queries_tensor   queries; device-compatible memory, dtype code must match the index
 * @param[out] neighbors_tensor output neighbors; device-compatible memory, uint32
 * @param[out] distances_tensor output distances; device-compatible memory, float32
 * @param[in]  filter_tensor    filter; device-compatible memory, int64
 * @return cuvsError_t produced by `translate_exceptions`
 */
extern "C" cuvsError_t cuvsCagraFilteredSearch(cuvsResources_t res,
                                               cuvsCagraSearchParams_t params,
                                               cuvsCagraIndex_t index_c_ptr,
                                               DLManagedTensor* queries_tensor,
                                               DLManagedTensor* neighbors_tensor,
                                               DLManagedTensor* distances_tensor,
                                               DLManagedTensor* filter_tensor)
{
  return cuvs::core::translate_exceptions([=] {
    auto queries   = queries_tensor->dl_tensor;
    auto neighbors = neighbors_tensor->dl_tensor;
    auto distances = distances_tensor->dl_tensor;
    auto filter    = filter_tensor->dl_tensor;

    // All tensors are consumed on the device.
    RAFT_EXPECTS(cuvs::core::is_dlpack_device_compatible(queries),
                 "queries should have device compatible memory");
    RAFT_EXPECTS(cuvs::core::is_dlpack_device_compatible(neighbors),
                 "neighbors should have device compatible memory");
    RAFT_EXPECTS(cuvs::core::is_dlpack_device_compatible(distances),
                 "distances should have device compatible memory");
    RAFT_EXPECTS(cuvs::core::is_dlpack_device_compatible(filter),
                 "filter should have device compatible memory");

    RAFT_EXPECTS(neighbors.dtype.code == kDLUInt && neighbors.dtype.bits == 32,
                 "neighbors should be of type uint32_t");
    // Check the bit width of `distances` itself; the previous code tested
    // `neighbors.dtype.bits` here, which let e.g. a float64 distances tensor through.
    RAFT_EXPECTS(distances.dtype.code == kDLFloat && distances.dtype.bits == 32,
                 "distances should be of type float32");
    RAFT_EXPECTS(filter.dtype.code == kDLInt && filter.dtype.bits == 64,
                 "filter should be of type int64_t");

    auto index = *index_c_ptr;
    RAFT_EXPECTS(queries.dtype.code == index.dtype.code, "type mismatch between index and queries");

    // Dispatch on the element type of the queries.
    if (queries.dtype.code == kDLFloat && queries.dtype.bits == 32) {
      _search<float>(
        res, *params, index, queries_tensor, neighbors_tensor, distances_tensor, filter_tensor);
    } else if (queries.dtype.code == kDLInt && queries.dtype.bits == 8) {
      _search<int8_t>(
        res, *params, index, queries_tensor, neighbors_tensor, distances_tensor, filter_tensor);
    } else if (queries.dtype.code == kDLUInt && queries.dtype.bits == 8) {
      _search<uint8_t>(
        res, *params, index, queries_tensor, neighbors_tensor, distances_tensor, filter_tensor);
    } else {
      RAFT_FAIL("Unsupported queries DLtensor dtype: %d and bits: %d",
                queries.dtype.code,
                queries.dtype.bits);
    }
  });
}

extern "C" cuvsError_t cuvsCagraIndexParamsCreate(cuvsCagraIndexParams_t* params)
{
return cuvs::core::translate_exceptions([=] {
Expand Down
116 changes: 116 additions & 0 deletions cpp/test/neighbors/ann_cagra_c.cu
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,14 @@ float queries[4][2] = {{0.48216683, 0.0428398},
{0.51260436, 0.2643005},
{0.05198065, 0.5789965}};

// Host-side filter values copied to the device and passed as the filter tensor in the
// BuildSearchFiltered test. The C wrapper's `_search` names this input `removed_indices`,
// so these are presumably dataset row indices excluded from results -- TODO confirm.
int64_t filter[2] = {1, 2};
lowener marked this conversation as resolved.
Show resolved Hide resolved

// Expected neighbor index and distance per query (4 values each); presumably the
// reference results for the unfiltered BuildSearch test -- verify against that test.
uint32_t neighbors_exp[4] = {3, 0, 3, 1};
float distances_exp[4] = {0.03878258, 0.12472608, 0.04776672, 0.15224178};

// Expected neighbor index and distance per query (4 values each) compared against the
// output of cuvsCagraFilteredSearch in the BuildSearchFiltered test.
uint32_t neighbors_exp_filtered[4] = {3, 0, 3, 0};
float distances_exp_filtered[4] = {0.03878258, 0.12472608, 0.04776672, 0.59063464};

TEST(CagraC, BuildSearch)
{
// create cuvsResources_t
Expand Down Expand Up @@ -126,3 +131,114 @@ TEST(CagraC, BuildSearch)
cuvsCagraIndexDestroy(index);
cuvsResourcesDestroy(res);
}

// Builds a CAGRA index over the 4x2 host `dataset` through the C API, runs
// cuvsCagraFilteredSearch for the 4 queries with one neighbor per query (output
// shape {4, 1}) and the `filter` values, then compares the device results with
// `neighbors_exp_filtered` / `distances_exp_filtered`.
TEST(CagraC, BuildSearchFiltered)
{
  // create cuvsResources_t
  cuvsResources_t res;
  cuvsResourcesCreate(&res);
  cudaStream_t stream;
  cuvsStreamGet(res, &stream);

  // create dataset DLTensor
  // Value-initialize every DLManagedTensor so fields that are not set explicitly below
  // (device_id, byte_offset, manager_ctx, deleter) are zero rather than indeterminate.
  DLManagedTensor dataset_tensor{};
  dataset_tensor.dl_tensor.data               = dataset;
  dataset_tensor.dl_tensor.device.device_type = kDLCPU;
  dataset_tensor.dl_tensor.ndim               = 2;
  dataset_tensor.dl_tensor.dtype.code         = kDLFloat;
  dataset_tensor.dl_tensor.dtype.bits         = 32;
  dataset_tensor.dl_tensor.dtype.lanes        = 1;
  int64_t dataset_shape[2]                    = {4, 2};
  dataset_tensor.dl_tensor.shape              = dataset_shape;
  dataset_tensor.dl_tensor.strides            = nullptr;

  // create index
  cuvsCagraIndex_t index;
  cuvsCagraIndexCreate(&index);

  // build index
  cuvsCagraIndexParams_t build_params;
  cuvsCagraIndexParamsCreate(&build_params);
  // Fail here if the build itself reports an error instead of at the result comparison.
  ASSERT_EQ(cuvsCagraBuild(res, build_params, &dataset_tensor, index), CUVS_SUCCESS);

  // create queries DLTensor
  rmm::device_uvector<float> queries_d(4 * 2, stream);
  raft::copy(queries_d.data(), &queries[0][0], 4 * 2, stream);

  DLManagedTensor queries_tensor{};
  queries_tensor.dl_tensor.data               = queries_d.data();
  queries_tensor.dl_tensor.device.device_type = kDLCUDA;
  queries_tensor.dl_tensor.ndim               = 2;
  queries_tensor.dl_tensor.dtype.code         = kDLFloat;
  queries_tensor.dl_tensor.dtype.bits         = 32;
  queries_tensor.dl_tensor.dtype.lanes        = 1;
  int64_t queries_shape[2]                    = {4, 2};
  queries_tensor.dl_tensor.shape              = queries_shape;
  queries_tensor.dl_tensor.strides            = nullptr;

  // create filter DLTensor
  rmm::device_uvector<int64_t> filter_d(2, stream);
  raft::copy(filter_d.data(), &filter[0], 2, stream);

  DLManagedTensor filter_tensor{};
  filter_tensor.dl_tensor.data               = filter_d.data();
  filter_tensor.dl_tensor.device.device_type = kDLCUDA;
  filter_tensor.dl_tensor.ndim               = 1;
  filter_tensor.dl_tensor.dtype.code         = kDLInt;
  filter_tensor.dl_tensor.dtype.bits         = 64;
  filter_tensor.dl_tensor.dtype.lanes        = 1;
  int64_t filter_shape[1]                    = {2};
  filter_tensor.dl_tensor.shape              = filter_shape;
  filter_tensor.dl_tensor.strides            = nullptr;

  // create neighbors DLTensor
  rmm::device_uvector<uint32_t> neighbors_d(4, stream);

  DLManagedTensor neighbors_tensor{};
  neighbors_tensor.dl_tensor.data               = neighbors_d.data();
  neighbors_tensor.dl_tensor.device.device_type = kDLCUDA;
  neighbors_tensor.dl_tensor.ndim               = 2;
  neighbors_tensor.dl_tensor.dtype.code         = kDLUInt;
  neighbors_tensor.dl_tensor.dtype.bits         = 32;
  neighbors_tensor.dl_tensor.dtype.lanes        = 1;
  int64_t neighbors_shape[2]                    = {4, 1};
  neighbors_tensor.dl_tensor.shape              = neighbors_shape;
  neighbors_tensor.dl_tensor.strides            = nullptr;

  // create distances DLTensor
  rmm::device_uvector<float> distances_d(4, stream);

  DLManagedTensor distances_tensor{};
  distances_tensor.dl_tensor.data               = distances_d.data();
  distances_tensor.dl_tensor.device.device_type = kDLCUDA;
  distances_tensor.dl_tensor.ndim               = 2;
  distances_tensor.dl_tensor.dtype.code         = kDLFloat;
  distances_tensor.dl_tensor.dtype.bits         = 32;
  distances_tensor.dl_tensor.dtype.lanes        = 1;
  int64_t distances_shape[2]                    = {4, 1};
  distances_tensor.dl_tensor.shape              = distances_shape;
  distances_tensor.dl_tensor.strides            = nullptr;

  // search index
  cuvsCagraSearchParams_t search_params;
  cuvsCagraSearchParamsCreate(&search_params);
  const auto search_status = cuvsCagraFilteredSearch(res,
                                                     search_params,
                                                     index,
                                                     &queries_tensor,
                                                     &neighbors_tensor,
                                                     &distances_tensor,
                                                     &filter_tensor);
  // The status was previously captured but never checked; a failed search would only
  // surface as a confusing mismatch against uninitialized output buffers.
  ASSERT_EQ(search_status, CUVS_SUCCESS);

  // verify output
  ASSERT_TRUE(cuvs::devArrMatchHost(
    neighbors_exp_filtered, neighbors_d.data(), 4, cuvs::Compare<uint32_t>()));
  ASSERT_TRUE(cuvs::devArrMatchHost(
    distances_exp_filtered, distances_d.data(), 4, cuvs::CompareApprox<float>(0.001f)));

  // de-allocate index and res
  cuvsCagraSearchParamsDestroy(search_params);
  cuvsCagraIndexParamsDestroy(build_params);
  cuvsCagraIndexDestroy(index);
  cuvsResourcesDestroy(res);
}
Loading