Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add filtering for CAGRA to C API #452

Open
wants to merge 16 commits into
base: branch-25.02
Choose a base branch
from
6 changes: 5 additions & 1 deletion cpp/include/cuvs/neighbors/cagra.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#pragma once

#include <cuvs/core/c_api.h>
#include <cuvs/neighbors/common.h>
#include <dlpack/dlpack.h>
#include <stdbool.h>
#include <stdint.h>
Expand Down Expand Up @@ -385,13 +386,16 @@ cuvsError_t cuvsCagraBuild(cuvsResources_t res,
* @param[in] queries DLManagedTensor* queries dataset to search
* @param[out] neighbors DLManagedTensor* output `k` neighbors for queries
* @param[out] distances DLManagedTensor* output `k` distances for queries
* @param[in] filter cuvsFilter input prefilter that can be used
to filter queries and neighbors based on the given bitset.
*/
cuvsError_t cuvsCagraSearch(cuvsResources_t res,
cuvsCagraSearchParams_t params,
cuvsCagraIndex_t index,
DLManagedTensor* queries,
DLManagedTensor* neighbors,
DLManagedTensor* distances);
DLManagedTensor* distances,
cuvsFilter filter);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This API change needs to be propagated to:

  • the python package
  • the example C project (cuvs/example/c)
  • probably the rust package

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done👍


/**
* @}
Expand Down
39 changes: 32 additions & 7 deletions cpp/src/neighbors/cagra_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include <cuvs/core/interop.hpp>
#include <cuvs/neighbors/cagra.h>
#include <cuvs/neighbors/cagra.hpp>
#include <cuvs/neighbors/common.h>

#include <fstream>

Expand Down Expand Up @@ -91,7 +92,8 @@ void _search(cuvsResources_t res,
cuvsCagraIndex index,
DLManagedTensor* queries_tensor,
DLManagedTensor* neighbors_tensor,
DLManagedTensor* distances_tensor)
DLManagedTensor* distances_tensor,
cuvsFilter filter)
{
auto res_ptr = reinterpret_cast<raft::resources*>(res);
auto index_ptr = reinterpret_cast<cuvs::neighbors::cagra::index<T, uint32_t>*>(index.addr);
Expand All @@ -117,8 +119,27 @@ void _search(cuvsResources_t res,
auto queries_mds = cuvs::core::from_dlpack<queries_mdspan_type>(queries_tensor);
auto neighbors_mds = cuvs::core::from_dlpack<neighbors_mdspan_type>(neighbors_tensor);
auto distances_mds = cuvs::core::from_dlpack<distances_mdspan_type>(distances_tensor);
cuvs::neighbors::cagra::search(
*res_ptr, search_params, *index_ptr, queries_mds, neighbors_mds, distances_mds);
if (filter.type == NO_FILTER) {
cuvs::neighbors::cagra::search(
*res_ptr, search_params, *index_ptr, queries_mds, neighbors_mds, distances_mds);
} else if (filter.type == BITSET) {
using filter_mdspan_type = raft::device_vector_view<int64_t, int64_t, raft::row_major>;
ajit283 marked this conversation as resolved.
Show resolved Hide resolved
auto removed_indices_tensor = reinterpret_cast<DLManagedTensor*>(filter.addr);
auto removed_indices = cuvs::core::from_dlpack<filter_mdspan_type>(removed_indices_tensor);
cuvs::core::bitset<std::uint32_t, int64_t> removed_indices_bitset(
*res_ptr, removed_indices, index_ptr->dataset().extent(0));
auto bitset_filter_obj =
cuvs::neighbors::filtering::bitset_filter(removed_indices_bitset.view());
cuvs::neighbors::cagra::search(*res_ptr,
search_params,
*index_ptr,
queries_mds,
neighbors_mds,
distances_mds,
bitset_filter_obj);
} else {
RAFT_FAIL("Unsupported prefilter type: BITMAP");
}
}

template <typename T>
Expand Down Expand Up @@ -213,7 +234,8 @@ extern "C" cuvsError_t cuvsCagraSearch(cuvsResources_t res,
cuvsCagraIndex_t index_c_ptr,
DLManagedTensor* queries_tensor,
DLManagedTensor* neighbors_tensor,
DLManagedTensor* distances_tensor)
DLManagedTensor* distances_tensor,
cuvsFilter filter)
{
return cuvs::core::translate_exceptions([=] {
auto queries = queries_tensor->dl_tensor;
Expand All @@ -236,11 +258,14 @@ extern "C" cuvsError_t cuvsCagraSearch(cuvsResources_t res,
RAFT_EXPECTS(queries.dtype.code == index.dtype.code, "type mismatch between index and queries");

if (queries.dtype.code == kDLFloat && queries.dtype.bits == 32) {
_search<float>(res, *params, index, queries_tensor, neighbors_tensor, distances_tensor);
_search<float>(
res, *params, index, queries_tensor, neighbors_tensor, distances_tensor, filter);
} else if (queries.dtype.code == kDLInt && queries.dtype.bits == 8) {
_search<int8_t>(res, *params, index, queries_tensor, neighbors_tensor, distances_tensor);
_search<int8_t>(
res, *params, index, queries_tensor, neighbors_tensor, distances_tensor, filter);
} else if (queries.dtype.code == kDLUInt && queries.dtype.bits == 8) {
_search<uint8_t>(res, *params, index, queries_tensor, neighbors_tensor, distances_tensor);
_search<uint8_t>(
res, *params, index, queries_tensor, neighbors_tensor, distances_tensor, filter);
} else {
RAFT_FAIL("Unsupported queries DLtensor dtype: %d and bits: %d",
queries.dtype.code,
Expand Down
123 changes: 122 additions & 1 deletion cpp/test/neighbors/ann_cagra_c.cu
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,14 @@ float queries[4][2] = {{0.48216683, 0.0428398},
{0.51260436, 0.2643005},
{0.05198065, 0.5789965}};

// Index list used by BuildSearchFiltered: copied to the device and handed to
// cuvsCagraSearch as the payload of a BITSET cuvsFilter.
// NOTE(review): the C wrapper names this tensor `removed_indices`, so these are
// presumably the dataset rows excluded from the results -- confirm.
int64_t filter[2] = {1, 2};
lowener marked this conversation as resolved.
Show resolved Hide resolved

// Expected neighbor ids and distances, one entry per query (4 queries, k = 1).
// NOTE(review): presumably checked by the unfiltered BuildSearch test; its
// assertions are not fully visible in this view -- confirm.
uint32_t neighbors_exp[4] = {3, 0, 3, 1};
float distances_exp[4] = {0.03878258, 0.12472608, 0.04776672, 0.15224178};

// Expected results compared against in BuildSearchFiltered. Only the last
// query differs from the unfiltered expectation (neighbor 1 -> 0).
uint32_t neighbors_exp_filtered[4] = {3, 0, 3, 0};
float distances_exp_filtered[4] = {0.03878258, 0.12472608, 0.04776672, 0.59063464};

TEST(CagraC, BuildSearch)
{
// create cuvsResources_t
Expand Down Expand Up @@ -109,10 +114,15 @@ TEST(CagraC, BuildSearch)
distances_tensor.dl_tensor.shape = distances_shape;
distances_tensor.dl_tensor.strides = nullptr;

cuvsFilter prefilter;
prefilter.type = NO_FILTER;
prefilter.addr = (uintptr_t)NULL;

// search index
cuvsCagraSearchParams_t search_params;
cuvsCagraSearchParamsCreate(&search_params);
cuvsCagraSearch(res, search_params, index, &queries_tensor, &neighbors_tensor, &distances_tensor);
cuvsCagraSearch(
res, search_params, index, &queries_tensor, &neighbors_tensor, &distances_tensor, prefilter);

// verify output
ASSERT_TRUE(
Expand All @@ -126,3 +136,114 @@ TEST(CagraC, BuildSearch)
cuvsCagraIndexDestroy(index);
cuvsResourcesDestroy(res);
}

// Builds a CAGRA index from the host-resident `dataset`, runs a k = 1 search
// for the four device-resident queries through the C API with a BITSET
// prefilter, and compares the device outputs against neighbors_exp_filtered /
// distances_exp_filtered.
TEST(CagraC, BuildSearchFiltered)
{
  // create cuvsResources_t
  cuvsResources_t res;
  cuvsResourcesCreate(&res);
  cudaStream_t stream;
  cuvsStreamGet(res, &stream);

  // create dataset DLTensor
  // Every DLManagedTensor below is value-initialised so that the fields this
  // test never assigns (device_id, byte_offset, manager_ctx, deleter) are
  // zero/null instead of indeterminate stack contents.
  DLManagedTensor dataset_tensor{};
  dataset_tensor.dl_tensor.data               = dataset;
  dataset_tensor.dl_tensor.device.device_type = kDLCPU;
  dataset_tensor.dl_tensor.ndim               = 2;
  dataset_tensor.dl_tensor.dtype.code         = kDLFloat;
  dataset_tensor.dl_tensor.dtype.bits         = 32;
  dataset_tensor.dl_tensor.dtype.lanes        = 1;
  int64_t dataset_shape[2]                    = {4, 2};
  dataset_tensor.dl_tensor.shape              = dataset_shape;
  dataset_tensor.dl_tensor.strides            = nullptr;

  // create index
  cuvsCagraIndex_t index;
  cuvsCagraIndexCreate(&index);

  // build index
  cuvsCagraIndexParams_t build_params;
  cuvsCagraIndexParamsCreate(&build_params);
  cuvsCagraBuild(res, build_params, &dataset_tensor, index);

  // create queries DLTensor
  rmm::device_uvector<float> queries_d(4 * 2, stream);
  raft::copy(queries_d.data(), &queries[0][0], 4 * 2, stream);

  DLManagedTensor queries_tensor{};
  queries_tensor.dl_tensor.data               = queries_d.data();
  queries_tensor.dl_tensor.device.device_type = kDLCUDA;
  queries_tensor.dl_tensor.ndim               = 2;
  queries_tensor.dl_tensor.dtype.code         = kDLFloat;
  queries_tensor.dl_tensor.dtype.bits         = 32;
  queries_tensor.dl_tensor.dtype.lanes        = 1;
  int64_t queries_shape[2]                    = {4, 2};
  queries_tensor.dl_tensor.shape              = queries_shape;
  queries_tensor.dl_tensor.strides            = nullptr;

  // create neighbors DLTensor
  rmm::device_uvector<uint32_t> neighbors_d(4, stream);

  DLManagedTensor neighbors_tensor{};
  neighbors_tensor.dl_tensor.data               = neighbors_d.data();
  neighbors_tensor.dl_tensor.device.device_type = kDLCUDA;
  neighbors_tensor.dl_tensor.ndim               = 2;
  neighbors_tensor.dl_tensor.dtype.code         = kDLUInt;
  neighbors_tensor.dl_tensor.dtype.bits         = 32;
  neighbors_tensor.dl_tensor.dtype.lanes        = 1;
  int64_t neighbors_shape[2]                    = {4, 1};
  neighbors_tensor.dl_tensor.shape              = neighbors_shape;
  neighbors_tensor.dl_tensor.strides            = nullptr;

  // create distances DLTensor
  rmm::device_uvector<float> distances_d(4, stream);

  DLManagedTensor distances_tensor{};
  distances_tensor.dl_tensor.data               = distances_d.data();
  distances_tensor.dl_tensor.device.device_type = kDLCUDA;
  distances_tensor.dl_tensor.ndim               = 2;
  distances_tensor.dl_tensor.dtype.code         = kDLFloat;
  distances_tensor.dl_tensor.dtype.bits         = 32;
  distances_tensor.dl_tensor.dtype.lanes        = 1;
  int64_t distances_shape[2]                    = {4, 1};
  distances_tensor.dl_tensor.shape              = distances_shape;
  distances_tensor.dl_tensor.strides            = nullptr;

  // create filter DLTensor from the file-scope `filter` index list
  rmm::device_uvector<int64_t> filter_d(2, stream);
  raft::copy(filter_d.data(), &filter[0], 2, stream);

  DLManagedTensor filter_tensor{};
  filter_tensor.dl_tensor.data               = filter_d.data();
  filter_tensor.dl_tensor.device.device_type = kDLCUDA;
  filter_tensor.dl_tensor.ndim               = 1;
  filter_tensor.dl_tensor.dtype.code         = kDLInt;
  filter_tensor.dl_tensor.dtype.bits         = 64;
  filter_tensor.dl_tensor.dtype.lanes        = 1;
  int64_t filter_shape[1]                    = {2};
  filter_tensor.dl_tensor.shape              = filter_shape;
  filter_tensor.dl_tensor.strides            = nullptr;

  // Named `prefilter` (as in BuildSearch) so it does not shadow the file-scope
  // `filter` array read just above. The filter carries the address of the
  // DLManagedTensor, which must stay alive until the search call returns.
  cuvsFilter prefilter;
  prefilter.type = BITSET;
  prefilter.addr = reinterpret_cast<uintptr_t>(&filter_tensor);

  // search index
  cuvsCagraSearchParams_t search_params;
  cuvsCagraSearchParamsCreate(&search_params);
  cuvsCagraSearch(
    res, search_params, index, &queries_tensor, &neighbors_tensor, &distances_tensor, prefilter);

  // verify output
  ASSERT_TRUE(cuvs::devArrMatchHost(
    neighbors_exp_filtered, neighbors_d.data(), 4, cuvs::Compare<uint32_t>()));
  ASSERT_TRUE(cuvs::devArrMatchHost(
    distances_exp_filtered, distances_d.data(), 4, cuvs::CompareApprox<float>(0.001f)));

  // de-allocate index and res
  cuvsCagraSearchParamsDestroy(search_params);
  cuvsCagraIndexParamsDestroy(build_params);
  cuvsCagraIndexDestroy(index);
  cuvsResourcesDestroy(res);
}