Skip to content

Commit

Permalink
[BUG] Fix CAGRA filter (#489)
Browse files Browse the repository at this point in the history
Ref : #472

## The cause of the bug
The bitonic sort was used on an array that was not a power of 2 long. In the current search implementation, the bitonic sort is used to move the invalid elements to the end of the buffer as:
https://github.com/rapidsai/cuvs/blob/5062594138a40231475299c7bac61083b0669fd1/cpp/src/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh#L758-L763
https://github.com/rapidsai/cuvs/blob/5062594138a40231475299c7bac61083b0669fd1/cpp/src/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh#L644-L649

The problem is that the (max) array length (=`MAX_ITOPK + MAX_CANDIDATES`) is not always the power of two.
These bitonic sorts are called even if no elements are filtered out unless `cuvs::neighbors::filtering::none_sample_filter` is specified as the filter, so #472 occurs.

## Fix
This PR changes the filtering process so that the bitonic sort is not used to move the invalid elements to the end of the buffer.

Authors:
  - tsuki (https://github.com/enp1s0)

Approvers:
  - Artem M. Chirkin (https://github.com/achirkin)

URL: #489
  • Loading branch information
enp1s0 authored Dec 4, 2024
1 parent fbbca05 commit acbd097
Show file tree
Hide file tree
Showing 3 changed files with 153 additions and 51 deletions.
16 changes: 13 additions & 3 deletions cpp/src/neighbors/detail/cagra/search_single_cta.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -129,17 +129,27 @@ struct search : search_plan_impl<DataT, IndexT, DistanceT, SAMPLE_FILTER_T> {
(sizeof(INDEX_T) + sizeof(DISTANCE_T)) * result_buffer_size_32 +
sizeof(INDEX_T) * hashmap::get_size(small_hash_bitlen) + sizeof(INDEX_T) * search_width +
sizeof(std::uint32_t) * topk_ws_size + sizeof(std::uint32_t);
smem_size = base_smem_size;

std::uint32_t additional_smem_size = 0;
if (num_itopk_candidates > 256) {
// Tentatively calculate the required share memory size when radix
// sort based topk is used, assuming the block size is the maximum.
if (itopk_size <= 256) {
smem_size += topk_by_radix_sort<256, INDEX_T>::smem_size * sizeof(std::uint32_t);
additional_smem_size += topk_by_radix_sort<256, INDEX_T>::smem_size * sizeof(std::uint32_t);
} else {
smem_size += topk_by_radix_sort<512, INDEX_T>::smem_size * sizeof(std::uint32_t);
additional_smem_size += topk_by_radix_sort<512, INDEX_T>::smem_size * sizeof(std::uint32_t);
}
}

if (!std::is_same_v<SAMPLE_FILTER_T, cuvs::neighbors::filtering::none_sample_filter>) {
// For filtering postprocess
using scan_op_t = cub::WarpScan<unsigned>;
additional_smem_size =
std::max<std::uint32_t>(additional_smem_size, sizeof(scan_op_t::TempStorage));
}

smem_size = base_smem_size + additional_smem_size;

uint32_t block_size = thread_block_size;
if (block_size == 0) {
block_size = min_block_size;
Expand Down
182 changes: 139 additions & 43 deletions cpp/src/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ RAFT_DEVICE_INLINE_FUNCTION void pickup_next_parents(std::uint32_t* const termin
}

template <unsigned MAX_CANDIDATES, class IdxT = void>
RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_1st(
RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_and_full(
float* candidate_distances, // [num_candidates]
IdxT* candidate_indices, // [num_candidates]
const std::uint32_t num_candidates,
Expand Down Expand Up @@ -215,7 +215,7 @@ RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_1st(
}

template <unsigned MAX_ITOPK, class IdxT = void>
RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_2nd(
RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_and_merge(
float* itopk_distances, // [num_itopk]
IdxT* itopk_indices, // [num_itopk]
const std::uint32_t num_itopk,
Expand Down Expand Up @@ -424,7 +424,7 @@ RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_2nd(
template <unsigned MAX_ITOPK,
unsigned MAX_CANDIDATES,
class IdxT>
RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort(
RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_and_merge(
float* itopk_distances, // [num_itopk]
IdxT* itopk_indices, // [num_itopk]
const std::uint32_t num_itopk,
Expand All @@ -437,20 +437,62 @@ RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort(
const unsigned MULTI_WARPS_2)
{
// The results in candidate_distances/indices are sorted by bitonic sort.
topk_by_bitonic_sort_1st<MAX_CANDIDATES, IdxT>(
topk_by_bitonic_sort_and_full<MAX_CANDIDATES, IdxT>(
candidate_distances, candidate_indices, num_candidates, num_itopk, MULTI_WARPS_1);

// The results sorted above are merged with the internal intermediate top-k
// results so far using bitonic merge.
topk_by_bitonic_sort_2nd<MAX_ITOPK, IdxT>(itopk_distances,
itopk_indices,
num_itopk,
candidate_distances,
candidate_indices,
num_candidates,
work_buf,
first,
MULTI_WARPS_2);
topk_by_bitonic_sort_and_merge<MAX_ITOPK, IdxT>(itopk_distances,
itopk_indices,
num_itopk,
candidate_distances,
candidate_indices,
num_candidates,
work_buf,
first,
MULTI_WARPS_2);
}

// This function move the invalid index element to the end of the itopk list.
// Require : array_length % 32 == 0 && The invalid entry is only one.
template <class IdxT>
RAFT_DEVICE_INLINE_FUNCTION void move_invalid_to_end_of_list(IdxT* const index_array,
float* const distance_array,
const std::uint32_t array_length)
{
constexpr std::uint32_t warp_size = 32;
constexpr std::uint32_t invalid_index = utils::get_max_value<IdxT>();
const std::uint32_t lane_id = threadIdx.x % warp_size;

if (threadIdx.x >= warp_size) { return; }

bool found_invalid = false;
if (array_length % warp_size == 0) {
for (std::uint32_t i = lane_id; i < array_length; i += warp_size) {
const auto index = index_array[i];
const auto distance = distance_array[i];

if (found_invalid) {
index_array[i - 1] = index;
distance_array[i - 1] = distance;
} else {
// Check if the index is invalid
const auto I_found_invalid = (index == invalid_index);
const auto who_has_invalid = raft::ballot(I_found_invalid);
// if a value that is loaded by a smaller lane id thread, shift the array
if (who_has_invalid << (warp_size - lane_id)) {
index_array[i - 1] = index;
distance_array[i - 1] = distance;
}

found_invalid = who_has_invalid;
}
}
}
if (lane_id == 0) {
index_array[array_length - 1] = invalid_index;
distance_array[array_length - 1] = utils::get_max_value<float>();
}
}

template <class INDEX_T>
Expand Down Expand Up @@ -589,10 +631,10 @@ __device__ void search_core(
// sort
if constexpr (TOPK_BY_BITONIC_SORT) {
// [Notice]
// It is good to use multiple warps in topk_by_bitonic_sort() when
// It is good to use multiple warps in topk_by_bitonic_sort_and_merge() when
// batch size is small (short-latency), but it might not be always good
// when batch size is large (high-throughput).
// topk_by_bitonic_sort() consists of two operations:
// topk_by_bitonic_sort_and_merge() consists of two operations:
// if MAX_CANDIDATES is greater than 128, the first operation uses two warps;
// if MAX_ITOPK is greater than 256, the second operation used two warps.
const unsigned multi_warps_1 = ((blockDim.x >= 64) && (MAX_CANDIDATES > 128)) ? 1 : 0;
Expand All @@ -601,9 +643,9 @@ __device__ void search_core(
// reset small-hash table.
if ((iter + 1) % small_hash_reset_interval == 0) {
// Depending on the block size and the number of warps used in
// topk_by_bitonic_sort(), determine which warps are used to reset
// topk_by_bitonic_sort_and_merge(), determine which warps are used to reset
// the small hash and whether they are performed in overlap with
// topk_by_bitonic_sort().
// topk_by_bitonic_sort_and_merge().
_CLK_START();
unsigned hash_start_tid;
if (blockDim.x == 32) {
Expand All @@ -627,28 +669,28 @@ __device__ void search_core(

// topk with bitonic sort
_CLK_START();
if (std::is_same<SAMPLE_FILTER_T, cuvs::neighbors::filtering::none_sample_filter>::value ||
*filter_flag == 0) {
topk_by_bitonic_sort<MAX_ITOPK, MAX_CANDIDATES>(result_distances_buffer,
result_indices_buffer,
internal_topk,
result_distances_buffer + internal_topk,
result_indices_buffer + internal_topk,
search_width * graph_degree,
topk_ws,
(iter == 0),
multi_warps_1,
multi_warps_2);
__syncthreads();
} else {
topk_by_bitonic_sort_1st<MAX_ITOPK + MAX_CANDIDATES>(
result_distances_buffer,
result_indices_buffer,
internal_topk + search_width * graph_degree,
internal_topk,
false);
if (!(std::is_same<SAMPLE_FILTER_T, cuvs::neighbors::filtering::none_sample_filter>::value ||
*filter_flag == 0)) {
// Move the filtered out index to the end of the itopk list
for (unsigned i = 0; i < search_width; i++) {
move_invalid_to_end_of_list(
result_indices_buffer, result_distances_buffer, internal_topk);
}

if (threadIdx.x == 0) { *terminate_flag = 0; }
}
topk_by_bitonic_sort_and_merge<MAX_ITOPK, MAX_CANDIDATES>(
result_distances_buffer,
result_indices_buffer,
internal_topk,
result_distances_buffer + internal_topk,
result_indices_buffer + internal_topk,
search_width * graph_degree,
topk_ws,
(iter == 0),
multi_warps_1,
multi_warps_2);
__syncthreads();
_CLK_REC(clk_topk);
} else {
_CLK_START();
Expand Down Expand Up @@ -755,12 +797,66 @@ __device__ void search_core(
}

__syncthreads();
topk_by_bitonic_sort_1st<MAX_ITOPK + MAX_CANDIDATES>(
result_distances_buffer,
result_indices_buffer,
internal_topk + search_width * graph_degree,
top_k,
false);
// Move invalid index items to the end of the buffer without sorting the entire buffer
using scan_op_t = cub::WarpScan<unsigned>;
auto& temp_storage = *reinterpret_cast<typename scan_op_t::TempStorage*>(smem_work_ptr);

constexpr std::uint32_t warp_size = 32;
if (threadIdx.x < warp_size) {
std::uint32_t num_found_valid = 0;
for (std::uint32_t buffer_offset = 0; buffer_offset < internal_topk;
buffer_offset += warp_size) {
// Calculate the new buffer index
const auto src_position = buffer_offset + threadIdx.x;
const std::uint32_t is_valid_index =
(result_indices_buffer[src_position] & (~index_msb_1_mask)) == invalid_index ? 0 : 1;
std::uint32_t new_position;
scan_op_t(temp_storage).InclusiveSum(is_valid_index, new_position);
if (is_valid_index) {
const auto dst_position = num_found_valid + (new_position - 1);
result_indices_buffer[dst_position] = result_indices_buffer[src_position];
result_distances_buffer[dst_position] = result_distances_buffer[src_position];
}

// Calculate the largest valid position within a warp and bcast it for the next iteration
num_found_valid += new_position;
for (std::uint32_t offset = (warp_size >> 1); offset > 0; offset >>= 1) {
const auto v = raft::shfl_xor(num_found_valid, offset);
if ((threadIdx.x & offset) == 0) { num_found_valid = v; }
}

// If the enough number of items are found, do early termination
if (num_found_valid >= top_k) { break; }
}

if (num_found_valid < top_k) {
// Fill the remaining buffer with invalid values so that `topk_by_bitonic_sort_and_merge` is
// usable in the next step
for (std::uint32_t i = num_found_valid + threadIdx.x; i < internal_topk; i += warp_size) {
result_indices_buffer[i] = invalid_index;
result_distances_buffer[i] = utils::get_max_value<DISTANCE_T>();
}
}
}

// If the sufficient number of valid indexes are not in the internal topk, pick up from the
// candidate list.
if (top_k > internal_topk || result_indices_buffer[top_k - 1] == invalid_index) {
__syncthreads();
const unsigned multi_warps_1 = ((blockDim.x >= 64) && (MAX_CANDIDATES > 128)) ? 1 : 0;
const unsigned multi_warps_2 = ((blockDim.x >= 64) && (MAX_ITOPK > 256)) ? 1 : 0;
topk_by_bitonic_sort_and_merge<MAX_ITOPK, MAX_CANDIDATES>(
result_distances_buffer,
result_indices_buffer,
internal_topk,
result_distances_buffer + internal_topk,
result_indices_buffer + internal_topk,
search_width * graph_degree,
topk_ws,
(iter == 0),
multi_warps_1,
multi_warps_2);
}
__syncthreads();
}

Expand Down
6 changes: 1 addition & 5 deletions cpp/test/neighbors/ann_cagra.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -758,11 +758,7 @@ class AnnCagraFilterTest : public ::testing::TestWithParam<AnnCagraInputs> {
search_params.algo = ps.algo;
search_params.max_queries = ps.max_queries;
search_params.team_size = ps.team_size;

// TODO: setting search_params.itopk_size here breaks the filter tests, but is required for
// k>1024 skip these tests until fixed
if (ps.k >= 1024) { GTEST_SKIP(); }
// search_params.itopk_size = ps.itopk_size;
search_params.itopk_size = ps.itopk_size;

auto database_view = raft::make_device_matrix_view<const DataT, int64_t>(
(const DataT*)database.data(), ps.n_rows, ps.dim);
Expand Down

0 comments on commit acbd097

Please sign in to comment.