Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feat] CAGRA filtering with BFKNN when sparsity matching threshold #378

Closed
wants to merge 29 commits into from
Closed
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
0faf889
[Feat] CAGRA filtering with BFKNN when sparsity matching threshold
rhdong Oct 2, 2024
f3388f0
Merge branch 'branch-24.12' into rhdong/cagra-bf
rhdong Oct 3, 2024
f14be71
revert: update_dataset on strided matrix
rhdong Oct 3, 2024
062ca87
Merge branch 'branch-24.12' into rhdong/cagra-bf
rhdong Oct 4, 2024
a9fd8d8
Merge branch 'branch-24.12' into rhdong/cagra-bf
rhdong Oct 22, 2024
8e27b74
Merge branch 'branch-24.12' into rhdong/cagra-bf
rhdong Oct 28, 2024
5378827
Support strided matrix on queries & respond to the review comments
rhdong Oct 29, 2024
651387f
Merge branch 'branch-24.12' into rhdong/cagra-bf
rhdong Oct 29, 2024
757c222
fix a style issue
rhdong Oct 29, 2024
018879f
Merge remote-tracking branch 'rhdong/rhdong/cagra-bf' into rhdong/cag…
rhdong Oct 29, 2024
bddae7f
Merge branch 'branch-24.12' into rhdong/cagra-bf
rhdong Oct 30, 2024
caab88b
fix: don't invoke 'copy_with_padding' from `src/neighbors/detail`
rhdong Oct 30, 2024
bac646d
Merge branch 'branch-24.12' into rhdong/cagra-bf
rhdong Oct 31, 2024
f4c1922
optimize by review comments
rhdong Oct 31, 2024
0dc10a2
Merge branch 'branch-24.12' into rhdong/cagra-bf
rhdong Nov 13, 2024
a73ba1f
move calling down to branch & replace copy_with_padding
rhdong Nov 14, 2024
2552d8d
Merge branch 'branch-24.12' into rhdong/cagra-bf
rhdong Nov 14, 2024
ef734d4
Merge branch 'branch-24.12' into rhdong/cagra-bf
rhdong Nov 14, 2024
0036127
fix: RAFT_LOG_DEBUG %f for double & other optimization
rhdong Nov 15, 2024
2876506
Merge branch 'branch-24.12' into rhdong/cagra-bf
rhdong Nov 19, 2024
b5dcc02
benchmark: support pre-filter on CAGRA
rhdong Nov 18, 2024
5c9c5de
adjust the kernel selection condition to be 0.9f
rhdong Nov 18, 2024
d190b9d
expose the threshold-to-bf to callers & test cases
rhdong Nov 19, 2024
9aa1bb1
move the threshold-to-bf into search_params
rhdong Nov 19, 2024
a0fba17
Merge branch 'branch-24.12' into rhdong/cagra-bf
rhdong Nov 19, 2024
e29d74d
skip the test on half when cusparse version is unsupported.
rhdong Nov 20, 2024
4d0fc8e
Revert "benchmark: support pre-filter on CAGRA"
rhdong Nov 20, 2024
1bcba66
if (params.threshold_to_bf >= 1.0) { return false; }
rhdong Nov 20, 2024
0cafa23
Merge branch 'branch-24.12' into rhdong/cagra-bf
rhdong Nov 21, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 26 additions & 4 deletions cpp/include/cuvs/neighbors/cagra.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,13 @@ struct index : cuvs::neighbors::index {
RAFT_EXPECTS(dataset.extent(0) == knn_graph.extent(0),
"Dataset and knn_graph must have equal number of rows");
update_graph(res, knn_graph);
if constexpr (raft::is_device_mdspan_v<decltype(dataset)>) {
contiguous_dataset_ =
raft::make_device_matrix_view(dataset.data_handle(), dataset.extent(0), dataset.extent(1));
} else {
contiguous_dataset_ =
raft::make_host_matrix_view(dataset.data_handle(), dataset.extent(0), dataset.extent(1));
}

raft::resource::sync_stream(res);
}
Expand All @@ -417,13 +424,19 @@ struct index : cuvs::neighbors::index {
void update_dataset(raft::resources const& res,
raft::device_matrix_view<const T, int64_t, raft::row_major> dataset)
{
dataset_ = make_aligned_dataset(res, dataset, 16);
contiguous_dataset_ = dataset;
dataset_ = make_aligned_dataset(res, dataset, 16);
}

/** Set the dataset reference explicitly to a device matrix view with padding. */
void update_dataset(raft::resources const& res,
raft::device_matrix_view<const T, int64_t, raft::layout_stride> dataset)
{
contiguous_dataset_ = std::monostate{};
if (dataset.stride(0) == dataset.extent(1) && dataset.stride(1) == 1) {
contiguous_dataset_ =
raft::make_device_matrix_view(dataset.data_handle(), dataset.extent(0), dataset.extent(1));
}
dataset_ = make_aligned_dataset(res, dataset, 16);
}

Expand All @@ -436,7 +449,8 @@ struct index : cuvs::neighbors::index {
void update_dataset(raft::resources const& res,
raft::host_matrix_view<const T, int64_t, raft::row_major> dataset)
{
dataset_ = make_aligned_dataset(res, dataset, 16);
contiguous_dataset_ = dataset;
dataset_ = make_aligned_dataset(res, dataset, 16);
}

/**
Expand All @@ -447,14 +461,16 @@ struct index : cuvs::neighbors::index {
auto update_dataset(raft::resources const& res, DatasetT&& dataset)
-> std::enable_if_t<std::is_base_of_v<cuvs::neighbors::dataset<int64_t>, DatasetT>>
{
dataset_ = std::make_unique<DatasetT>(std::move(dataset));
contiguous_dataset_ = std::monostate{};
dataset_ = std::make_unique<DatasetT>(std::move(dataset));
}

template <typename DatasetT>
auto update_dataset(raft::resources const& res, std::unique_ptr<DatasetT>&& dataset)
-> std::enable_if_t<std::is_base_of_v<neighbors::dataset<int64_t>, DatasetT>>
{
dataset_ = std::move(dataset);
contiguous_dataset_ = std::monostate{};
dataset_ = std::move(dataset);
}

/**
Expand Down Expand Up @@ -492,11 +508,17 @@ struct index : cuvs::neighbors::index {
graph_view_ = graph_.view();
}

auto contiguous_dataset() const { return contiguous_dataset_; }

private:
cuvs::distance::DistanceType metric_;
raft::device_matrix<IdxT, int64_t, raft::row_major> graph_;
raft::device_matrix_view<const IdxT, int64_t, raft::row_major> graph_view_;
std::unique_ptr<neighbors::dataset<int64_t>> dataset_;
std::variant<std::monostate,
raft::device_matrix_view<const T, int64_t, raft::row_major>,
raft::host_matrix_view<const T, int64_t, raft::row_major>>
contiguous_dataset_ = std::monostate{};
};
/**
* @}
Expand Down
57 changes: 57 additions & 0 deletions cpp/src/neighbors/detail/cagra/cagra_search.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,11 @@
#include <raft/core/host_mdspan.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resources.hpp>
#include <raft/linalg/unary_op.cuh>

#include <cuvs/distance/distance.hpp>

#include <cuvs/neighbors/brute_force.hpp>
#include <cuvs/neighbors/cagra.hpp>

// TODO: Fix these when ivf methods are moved over
Expand Down Expand Up @@ -140,6 +142,61 @@ void search_main(raft::resources const& res,
raft::device_matrix_view<DistanceT, int64_t, raft::row_major> distances,
CagraSampleFilterT sample_filter = CagraSampleFilterT())
{
if constexpr (!std::is_same_v<CagraSampleFilterT,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you pull this out into a separate function that can be invoked here please? This search_main function is gettig pretty massive.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

cuvs::neighbors::filtering::none_sample_filter> &&
(std::is_same_v<T, float> || std::is_same_v<T, half>)) {
auto n_queries = queries.extent(0);
auto n_dataset = index.size();

auto bitset_filter_view = sample_filter.bitset_view_;
auto dataset_view = index.contiguous_dataset();

auto sparsity = bitset_filter_view.sparsity(res);
constexpr double threshold_to_bf = 0.9;

// TODO: Support host dataset in `brute_force::build`
if (sparsity >= threshold_to_bf &&
std::holds_alternative<raft::device_matrix_view<const T, int64_t, raft::row_major>>(
dataset_view)) {
using bitmap_view_t = cuvs::core::bitmap_view<const uint32_t, int64_t>;

auto stream = raft::resource::get_cuda_stream(res);
auto bitmap_n_elements =
bitmap_view_t::eval_n_elements(bitset_filter_view.size() * n_queries);

rmm::device_uvector<uint32_t> raw_bitmap(bitmap_n_elements, stream);
rmm::device_uvector<int64_t> raw_neighbors(neighbors.size(), stream);

bitset_filter_view.repeat(res, n_queries, raw_bitmap.data());

auto brute_force_filter = bitmap_view_t(raw_bitmap.data(), n_queries, n_dataset);

auto brute_force_neighbors = raft::make_device_matrix_view<int64_t, int64_t, raft::row_major>(
raw_neighbors.data(), neighbors.extent(0), neighbors.extent(1));
auto brute_force_dataset =
std::get_if<raft::device_matrix_view<const T, int64_t, raft::row_major>>(&dataset_view);

if (brute_force_dataset) {
RAFT_LOG_DEBUG("CAGRA is switching to brute force with sparsity:%d", sparsity);
auto brute_force_idx =
cuvs::neighbors::brute_force::build(res, *brute_force_dataset, index.metric());
cuvs::neighbors::brute_force::search(
res,
brute_force_idx,
queries,
brute_force_neighbors,
distances,
cuvs::neighbors::filtering::bitmap_filter(brute_force_filter));
raft::linalg::unaryOp(neighbors.data_handle(),
brute_force_neighbors.data_handle(),
neighbors.size(),
raft::cast_op<InternalIdxT>(),
raft::resource::get_cuda_stream(res));
return;
}
}
}

auto stream = raft::resource::get_cuda_stream(res);
const auto& graph = index.graph();
auto graph_internal = raft::make_device_matrix_view<const InternalIdxT, int64_t, raft::row_major>(
Expand Down
68 changes: 49 additions & 19 deletions cpp/test/neighbors/ann_cagra.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,12 @@ namespace cuvs::neighbors::cagra {
namespace {

struct test_cagra_sample_filter {
static constexpr unsigned offset = 300;
inline _RAFT_HOST_DEVICE auto operator()(
// query index
const uint32_t query_ix,
// the index of the current sample inside the current inverted list
const uint32_t sample_ix) const
const uint32_t sample_ix,
const uint32_t offset) const
{
return sample_ix >= offset;
}
Expand Down Expand Up @@ -276,6 +276,7 @@ struct AnnCagraInputs {
bool include_serialized_dataset;
// std::optional<double>
double min_recall; // = std::nullopt;
uint32_t filter_offset = 300;
std::optional<float> ivf_pq_search_refine_ratio = std::nullopt;
std::optional<vpq_params> compression = std::nullopt;

Expand Down Expand Up @@ -702,21 +703,20 @@ class AnnCagraFilterTest : public ::testing::TestWithParam<AnnCagraInputs> {
{
rmm::device_uvector<DistanceT> distances_naive_dev(queries_size, stream_);
rmm::device_uvector<IdxT> indices_naive_dev(queries_size, stream_);
auto* database_filtered_ptr = database.data() + test_cagra_sample_filter::offset * ps.dim;
cuvs::neighbors::naive_knn<DistanceT, DataT, IdxT>(
handle_,
distances_naive_dev.data(),
indices_naive_dev.data(),
search_queries.data(),
database_filtered_ptr,
ps.n_queries,
ps.n_rows - test_cagra_sample_filter::offset,
ps.dim,
ps.k,
ps.metric);
auto* database_filtered_ptr = database.data() + ps.filter_offset * ps.dim;
cuvs::neighbors::naive_knn<DistanceT, DataT, IdxT>(handle_,
distances_naive_dev.data(),
indices_naive_dev.data(),
search_queries.data(),
database_filtered_ptr,
ps.n_queries,
ps.n_rows - ps.filter_offset,
ps.dim,
ps.k,
ps.metric);
raft::linalg::addScalar(indices_naive_dev.data(),
indices_naive_dev.data(),
IdxT(test_cagra_sample_filter::offset),
IdxT(ps.filter_offset),
queries_size,
stream_);
raft::update_host(distances_naive.data(), distances_naive_dev.data(), queries_size, stream_);
Expand Down Expand Up @@ -787,7 +787,7 @@ class AnnCagraFilterTest : public ::testing::TestWithParam<AnnCagraInputs> {
auto dists_out_view = raft::make_device_matrix_view<DistanceT, int64_t>(
distances_dev.data(), ps.n_queries, ps.k);
auto removed_indices =
raft::make_device_vector<int64_t, int64_t>(handle_, test_cagra_sample_filter::offset);
raft::make_device_vector<int64_t, int64_t>(handle_, ps.filter_offset);
thrust::sequence(
raft::resource::get_thrust_policy(handle_),
thrust::device_pointer_cast(removed_indices.data_handle()),
Expand All @@ -813,8 +813,9 @@ class AnnCagraFilterTest : public ::testing::TestWithParam<AnnCagraInputs> {
bool unacceptable_node = false;
for (int q = 0; q < ps.n_queries; q++) {
for (int i = 0; i < ps.k; i++) {
const auto n = indices_Cagra[q * ps.k + i];
unacceptable_node = unacceptable_node | !test_cagra_sample_filter()(q, n);
const auto n = indices_Cagra[q * ps.k + i];
unacceptable_node =
unacceptable_node | !test_cagra_sample_filter()(q, n, ps.filter_offset);
}
}
EXPECT_FALSE(unacceptable_node);
Expand Down Expand Up @@ -1002,6 +1003,7 @@ inline std::vector<AnnCagraInputs> generate_inputs()
{false, true},
{false},
{0.99},
{uint32_t(300)},
{1.0f, 2.0f, 3.0f});
inputs.insert(inputs.end(), inputs2.begin(), inputs2.end());

Expand All @@ -1028,6 +1030,34 @@ inline std::vector<AnnCagraInputs> generate_inputs()
return inputs;
}

const std::vector<AnnCagraInputs> inputs = generate_inputs();
inline std::vector<AnnCagraInputs> generate_bf_inputs()
{
// Add test cases for brute force as sparsity >= 0.9.
std::vector<AnnCagraInputs> inputs_for_brute_force;
auto inputs_original = raft::util::itertools::product<AnnCagraInputs>(
{100},
{10000, 100000},
{1, 8, 17},
{1, 16, 256}, // k
{graph_build_algo::IVF_PQ, graph_build_algo::NN_DESCENT},
{search_algo::SINGLE_CTA, search_algo::MULTI_CTA, search_algo::MULTI_KERNEL},
{0, 1, 10, 100},
{0},
{256},
{1},
{cuvs::distance::DistanceType::L2Expanded},
{false},
{true},
{1.0});
for (auto input : inputs_original) {
input.filter_offset = 0.90 * input.n_rows;
inputs_for_brute_force.push_back(input);
}

return inputs_for_brute_force;
}

const std::vector<AnnCagraInputs> inputs = generate_inputs();
const std::vector<AnnCagraInputs> inputs_brute_force = generate_bf_inputs();

} // namespace cuvs::neighbors::cagra
3 changes: 3 additions & 0 deletions cpp/test/neighbors/ann_cagra/test_float_uint32_t.cu
Original file line number Diff line number Diff line change
Expand Up @@ -34,5 +34,8 @@ INSTANTIATE_TEST_CASE_P(AnnCagraAddNodesTest,
AnnCagraAddNodesTestF_U32,
::testing::ValuesIn(inputs));
INSTANTIATE_TEST_CASE_P(AnnCagraFilterTest, AnnCagraFilterTestF_U32, ::testing::ValuesIn(inputs));
INSTANTIATE_TEST_CASE_P(AnnCagraFilterToBruteForceTest,
AnnCagraFilterTestF_U32,
::testing::ValuesIn(inputs_brute_force));

} // namespace cuvs::neighbors::cagra
Loading