[Feat] CAGRA filtering with BFKNN when sparsity matching threshold #378
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,14 +21,17 @@ | |
#include "sample_filter_utils.cuh" | ||
#include "search_plan.cuh" | ||
#include "search_single_cta_inst.cuh" | ||
#include "utils.hpp" | ||
|
||
#include <raft/core/device_mdspan.hpp> | ||
#include <raft/core/host_mdspan.hpp> | ||
#include <raft/core/resource/cuda_stream.hpp> | ||
#include <raft/core/resources.hpp> | ||
#include <raft/linalg/unary_op.cuh> | ||
|
||
#include <cuvs/distance/distance.hpp> | ||
|
||
#include <cuvs/neighbors/brute_force.hpp> | ||
#include <cuvs/neighbors/cagra.hpp> | ||
|
||
// TODO: Fix these when ivf methods are moved over | ||
|
@@ -108,6 +111,115 @@ void search_main_core(raft::resources const& res, | |
} | ||
} | ||
|
||
/** | ||
* @brief Performs ANN search using brute force when filter sparsity exceeds a specified threshold. | ||
* | ||
* This function switches to a brute force search approach to improve recall rate when the | ||
* `sample_filter` function filters out a high proportion of samples, resulting in a sparsity level | ||
* (proportion of filtered-out samples) at or above the specified `threshold_to_bf`. | ||
* | ||
* @tparam T data element type | ||
* @tparam IdxT type of database vector indices | ||
* @tparam InternalIdxT during search we map IdxT to InternalIdxT, this way we do not need | ||
* separate kernels for int/uint. | ||
* | ||
* @param[in] res raft resources | ||
* @param[in] params configure the search | ||
* @param[in] strided_dataset CAGRA strided dataset | ||
* @param[in] metric distance type | ||
* @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()] | ||
* @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset | ||
* [n_queries, k] | ||
* @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, | ||
* k] | ||
* @param[in] sample_filter a device filter function that greenlights samples for a given query | ||
* @param[in] threshold_to_bf A sparsity threshold in the range [0, 1]; brute force is used when | ||
* sparsity is at or above this threshold | ||
* | ||
* @return true If the brute force search was applied successfully. | ||
* @return false If the brute force search was not applied. | ||
*/ | ||
template <typename T, | ||
typename InternalIdxT, | ||
typename CagraSampleFilterT, | ||
typename IdxT = uint32_t, | ||
typename DistanceT = float> | ||
bool search_using_brute_force( | ||
raft::resources const& res, | ||
const search_params& params, | ||
const strided_dataset<T, IdxT>& strided_dataset, | ||
cuvs::distance::DistanceType metric, | ||
raft::device_matrix_view<const T, int64_t, raft::row_major> queries, | ||
raft::device_matrix_view<InternalIdxT, int64_t, raft::row_major> neighbors, | ||
raft::device_matrix_view<DistanceT, int64_t, raft::row_major> distances, | ||
CagraSampleFilterT& sample_filter, | ||
double threshold_to_bf = 0.9) | ||
{ | ||
|
||
auto n_queries = queries.extent(0); | ||
auto n_dataset = strided_dataset.n_rows(); | ||
|
||
auto bitset_filter_view = sample_filter.bitset_view_; | ||
Review comment: What happens here if the 2D bitmap can't be converted to a 1D bitset without losing information? |
||
auto sparsity = bitset_filter_view.sparsity(res); | ||
Review comment: Isn't the number of positive bits in the bitmap also needed to compute this? But then we compute |
||
|
||
if (sparsity < threshold_to_bf) { return false; } | ||
|
||
// TODO: Support host dataset in `brute_force::build` | ||
RAFT_LOG_DEBUG("CAGRA is switching to brute force with sparsity:%f", sparsity); | ||
using bitmap_view_t = cuvs::core::bitmap_view<const uint32_t, int64_t>; | ||
|
||
auto stream = raft::resource::get_cuda_stream(res); | ||
auto bitmap_n_elements = bitmap_view_t::eval_n_elements(bitset_filter_view.size() * n_queries); | ||
|
||
rmm::device_uvector<uint32_t> raw_bitmap(bitmap_n_elements, stream); | ||
rmm::device_uvector<int64_t> raw_neighbors(neighbors.size(), stream); | ||
|
||
bitset_filter_view.repeat(res, n_queries, raw_bitmap.data()); | ||
|
||
auto brute_force_filter = bitmap_view_t(raw_bitmap.data(), n_queries, n_dataset); | ||
|
||
auto brute_force_neighbors = raft::make_device_matrix_view<int64_t, int64_t, raft::row_major>( | ||
raw_neighbors.data(), neighbors.extent(0), neighbors.extent(1)); | ||
auto brute_force_dataset = raft::make_device_matrix_view<const T, int64_t, raft::row_major>( | ||
strided_dataset.view().data_handle(), strided_dataset.n_rows(), strided_dataset.stride()); | ||
|
||
auto brute_force_idx = cuvs::neighbors::brute_force::build(res, brute_force_dataset, metric); | ||
Review comment: Is this being called each and every time a user performs a search? There's overhead in this call, and the built brute-force index should be cached, because for many common distances the build computes a set of norms. |
||
|
||
auto brute_force_queries = queries; | ||
auto padding_queries = raft::make_device_matrix<T, int64_t>(res, 0, 0); | ||
|
||
// Happens when the original dataset is a strided matrix. | ||
if (brute_force_dataset.extent(1) != queries.extent(1)) { | ||
padding_queries = raft::make_device_mdarray<T, int64_t>( | ||
res, | ||
raft::resource::get_workspace_resource(res), | ||
raft::make_extents<int64_t>(n_queries, brute_force_dataset.extent(1))); | ||
// Copy the queries and fill the padded elements with zeros | ||
raft::linalg::map_offset( | ||
res, | ||
padding_queries.view(), | ||
[queries, stride = brute_force_dataset.extent(1)] __device__(int64_t i) { | ||
auto row_ix = i / stride; | ||
auto el_ix = i % stride; | ||
return el_ix < queries.extent(1) ? queries(row_ix, el_ix) : T{0}; | ||
}); | ||
brute_force_queries = raft::make_device_matrix_view<const T, int64_t, raft::row_major>( | ||
padding_queries.data_handle(), padding_queries.extent(0), padding_queries.extent(1)); | ||
} | ||
cuvs::neighbors::brute_force::search( | ||
res, | ||
brute_force_idx, | ||
brute_force_queries, | ||
brute_force_neighbors, | ||
distances, | ||
cuvs::neighbors::filtering::bitmap_filter(brute_force_filter)); | ||
raft::linalg::unaryOp(neighbors.data_handle(), | ||
brute_force_neighbors.data_handle(), | ||
neighbors.size(), | ||
raft::cast_op<InternalIdxT>(), | ||
raft::resource::get_cuda_stream(res)); | ||
return true; | ||
} | ||
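The control flow of `search_using_brute_force` can be illustrated with a host-only sketch. This is not the cuVS implementation: `bitset_sparsity`, `should_use_brute_force`, and `pad_queries` are hypothetical helpers, and the bit layout (bit `i` stored in word `i / 32`, a set bit meaning "row allowed") is an assumption made for illustration. It shows the sparsity test that gates the fallback and the zero-padding of queries to the dataset stride.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical host-side helper: fraction of rows whose bit is 0 (filtered out).
// Assumed layout: bit i lives in word i / 32 at position i % 32; set bit = allowed.
inline double bitset_sparsity(const std::vector<std::uint32_t>& words, std::size_t n_rows)
{
  if (n_rows == 0) { return 0.0; }
  std::size_t n_allowed = 0;
  for (std::size_t i = 0; i < n_rows; ++i) {
    n_allowed += (words[i / 32] >> (i % 32)) & 1u;
  }
  return 1.0 - static_cast<double>(n_allowed) / static_cast<double>(n_rows);
}

// Mirrors the early return in the diff: `if (sparsity < threshold_to_bf) return false;`
inline bool should_use_brute_force(double sparsity, double threshold_to_bf)
{
  return !(sparsity < threshold_to_bf);
}

// Hypothetical host-side analogue of the map_offset lambda: copy row-major queries
// of width `dim` into rows of width `stride`, filling the padded tail with zeros.
inline std::vector<float> pad_queries(const std::vector<float>& queries,
                                      std::size_t n_queries,
                                      std::size_t dim,
                                      std::size_t stride)
{
  std::vector<float> padded(n_queries * stride);
  for (std::size_t i = 0; i < padded.size(); ++i) {
    std::size_t row = i / stride;
    std::size_t col = i % stride;
    padded[i] = col < dim ? queries[row * dim + col] : 0.0f;
  }
  return padded;
}
```

With one allowed row out of ten, `bitset_sparsity` returns roughly 0.9, so a threshold of 0.5 selects brute force while a sparsity of 0.2 against the default 0.9 keeps the CAGRA path.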
|
||
/** | ||
* @brief Search ANN using the constructed index. | ||
* | ||
|
@@ -126,6 +238,7 @@ void search_main_core(raft::resources const& res, | |
* [n_queries, k] | ||
* @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, | ||
* k] | ||
* @param[in] sample_filter a device filter function that greenlights samples for a given query | ||
*/ | ||
template <typename T, | ||
typename InternalIdxT, | ||
|
@@ -150,6 +263,14 @@ void search_main(raft::resources const& res, | |
// Dispatch search parameters based on the dataset kind. | ||
if (auto* strided_dset = dynamic_cast<const strided_dataset<T, ds_idx_type>*>(&index.data()); | ||
strided_dset != nullptr) { | ||
if constexpr (!std::is_same_v<CagraSampleFilterT, | ||
cuvs::neighbors::filtering::none_sample_filter> && | ||
(std::is_same_v<T, float> || std::is_same_v<T, half>)) { | ||
bool bf_search_done = search_using_brute_force( | ||
res, params, *strided_dset, index.metric(), queries, neighbors, distances, sample_filter); | ||
if (bf_search_done) return; | ||
} | ||
|
||
// Search using a plain (strided) row-major dataset | ||
auto desc = dataset_descriptor_init_with_cache<T, InternalIdxT, DistanceT>( | ||
res, params, *strided_dset, index.metric()); | ||
|
Review comment:
Could you please make it possible for the user to enable/disable this feature (see #252 (comment) for the reasoning):
- Add `threshold_to_bf` as a CAGRA search parameter and set it to 1.0 by default there.
- If `threshold_to_bf >= 1.0`, then disable further checks and proceed with CAGRA search immediately (i.e. no need to run the sparsity check).
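A minimal host-only sketch of the gating requested above, under stated assumptions: `search_params_sketch`, `search_path`, and `choose_search_path` are hypothetical names, not part of the cuVS API. The sparsity computation is passed in as a callback so that it is never invoked when the threshold disables the fallback.

```cpp
#include <functional>

enum class search_path { cagra, brute_force };

// Hypothetical parameter struct; a default of 1.0 disables the brute-force
// fallback, as proposed in the review comment.
struct search_params_sketch {
  double threshold_to_bf = 1.0;
};

// Decide which path to take. The sparsity callback is only evaluated when the
// threshold is below 1.0, so the default configuration pays no extra cost.
inline search_path choose_search_path(const search_params_sketch& params,
                                      const std::function<double()>& compute_sparsity)
{
  if (params.threshold_to_bf >= 1.0) { return search_path::cagra; }
  return compute_sparsity() < params.threshold_to_bf ? search_path::cagra
                                                     : search_path::brute_force;
}
```

Passing a callback rather than a precomputed value keeps the "no sparsity check" guarantee explicit at the call site.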
Review comment:
Most users with a filter are going to be specifying the filter in batch, and will know the sparsity of the filter. I suggest that, instead of turning this feature off by default, we allow the user-specified filter to know its own nnz unless updated.
Turning this off by default undermines the fundamental benefits of this feature.
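One way to read the "filter knows its own nnz unless updated" suggestion is sketched below. This is an illustration only: `counted_filter_sketch` is a hypothetical host-side type, not a cuVS class, and the bit layout (bit `i` in word `i / 32`, set bit = allowed) is assumed. The count of allowed rows is supplied by the user or computed once, cached, and invalidated whenever the filter is modified.

```cpp
#include <cstddef>
#include <cstdint>
#include <optional>
#include <utility>
#include <vector>

// Hypothetical filter wrapper that caches its number of allowed rows (nnz).
struct counted_filter_sketch {
  std::vector<std::uint32_t> words;
  std::size_t n_rows;
  std::optional<std::size_t> nnz;  // cached count of set bits, if known

  counted_filter_sketch(std::vector<std::uint32_t> w,
                        std::size_t rows,
                        std::optional<std::size_t> known_nnz = std::nullopt)
    : words(std::move(w)), n_rows(rows), nnz(known_nnz)
  {
  }

  // Any update invalidates the cached count.
  void set(std::size_t row, bool allowed)
  {
    std::uint32_t mask = 1u << (row % 32);
    if (allowed) {
      words[row / 32] |= mask;
    } else {
      words[row / 32] &= ~mask;
    }
    nnz.reset();
  }

  // Uses the cached count when available; otherwise counts once and caches.
  double sparsity()
  {
    if (n_rows == 0) { return 0.0; }
    if (!nnz.has_value()) {
      std::size_t count = 0;
      for (std::size_t i = 0; i < n_rows; ++i) {
        count += (words[i / 32] >> (i % 32)) & 1u;
      }
      nnz = count;
    }
    return 1.0 - static_cast<double>(*nnz) / static_cast<double>(n_rows);
  }
};
```

In this sketch a user who builds the filter in batch passes `known_nnz` up front, so the sparsity test costs nothing at search time; only a modified filter triggers a recount.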
Review comment:
Most users are not specifying a filter, and when they do, it's expected the filter is going to be heavy. This should not impact all users.