From acbd097ed15afe186367b5a46a5d4b366ac9d804 Mon Sep 17 00:00:00 2001 From: tsuki <12711693+enp1s0@users.noreply.github.com> Date: Wed, 4 Dec 2024 18:06:17 +0900 Subject: [PATCH] [BUG] Fix CAGRA filter (#489) Ref : https://github.com/rapidsai/cuvs/issues/472 ## The cause of the bug The bitonic sort was used on an array that was not a power of 2 long. In the current search implementation, the bitonic sort is used to move the invalid elements to the end of the buffer as: https://github.com/rapidsai/cuvs/blob/5062594138a40231475299c7bac61083b0669fd1/cpp/src/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh#L758-L763 https://github.com/rapidsai/cuvs/blob/5062594138a40231475299c7bac61083b0669fd1/cpp/src/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh#L644-L649 The problem is that the (max) array length (=`MAX_ITOPK + MAX_CANDIDATES`) is not always the power of two. These bitonic sorts are called even if no elements are filtered out unless `cuvs::neighbors::filtering::none_sample_filter` is specified as the filter, so #472 occurs. ## Fix This PR changes the filtering process so that the bitonic sort is not used to move the invalid elements to the end of the buffer. Authors: - tsuki (https://github.com/enp1s0) Approvers: - Artem M. Chirkin (https://github.com/achirkin) URL: https://github.com/rapidsai/cuvs/pull/489 --- .../detail/cagra/search_single_cta.cuh | 16 +- .../cagra/search_single_cta_kernel-inl.cuh | 182 +++++++++++++----- cpp/test/neighbors/ann_cagra.cuh | 6 +- 3 files changed, 153 insertions(+), 51 deletions(-) diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta.cuh b/cpp/src/neighbors/detail/cagra/search_single_cta.cuh index 2bed19009..fa71dbaf9 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta.cuh +++ b/cpp/src/neighbors/detail/cagra/search_single_cta.cuh @@ -129,17 +129,27 @@ struct search : search_plan_impl { (sizeof(INDEX_T) + sizeof(DISTANCE_T)) * result_buffer_size_32 + sizeof(INDEX_T) * hashmap::get_size(small_hash_bitlen) + sizeof(INDEX_T) * search_width + sizeof(std::uint32_t) * topk_ws_size + sizeof(std::uint32_t); - smem_size = base_smem_size; + + std::uint32_t additional_smem_size = 0; if (num_itopk_candidates > 256) { // Tentatively calculate the required share memory size when radix // sort based topk is used, assuming the block size is the maximum. if (itopk_size <= 256) { - smem_size += topk_by_radix_sort<256, INDEX_T>::smem_size * sizeof(std::uint32_t); + additional_smem_size += topk_by_radix_sort<256, INDEX_T>::smem_size * sizeof(std::uint32_t); } else { - smem_size += topk_by_radix_sort<512, INDEX_T>::smem_size * sizeof(std::uint32_t); + additional_smem_size += topk_by_radix_sort<512, INDEX_T>::smem_size * sizeof(std::uint32_t); } } + if (!std::is_same_v) { + // For filtering postprocess + using scan_op_t = cub::WarpScan; + additional_smem_size = + std::max(additional_smem_size, sizeof(scan_op_t::TempStorage)); + } + + smem_size = base_smem_size + additional_smem_size; + uint32_t block_size = thread_block_size; if (block_size == 0) { block_size = min_block_size; diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh b/cpp/src/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh index 79cb6bc10..678ed0cb4 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh @@ -111,7 +111,7 @@ RAFT_DEVICE_INLINE_FUNCTION void pickup_next_parents(std::uint32_t* const termin } template -RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_1st( +RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_and_full( float* candidate_distances, // [num_candidates] IdxT* candidate_indices, // [num_candidates] const std::uint32_t num_candidates, @@ -215,7 +215,7 @@ RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_1st( } template -RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_2nd( +RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_and_merge( float* itopk_distances, // [num_itopk] IdxT* itopk_indices, // [num_itopk] const std::uint32_t num_itopk, @@ -424,7 +424,7 @@ RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_2nd( template -RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort( +RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_and_merge( float* itopk_distances, // [num_itopk] IdxT* itopk_indices, // [num_itopk] const std::uint32_t num_itopk, @@ -437,20 +437,62 @@ RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort( const unsigned MULTI_WARPS_2) { // The results in candidate_distances/indices are sorted by bitonic sort. - topk_by_bitonic_sort_1st( + topk_by_bitonic_sort_and_full( candidate_distances, candidate_indices, num_candidates, num_itopk, MULTI_WARPS_1); // The results sorted above are merged with the internal intermediate top-k // results so far using bitonic merge. - topk_by_bitonic_sort_2nd(itopk_distances, - itopk_indices, - num_itopk, - candidate_distances, - candidate_indices, - num_candidates, - work_buf, - first, - MULTI_WARPS_2); + topk_by_bitonic_sort_and_merge(itopk_distances, + itopk_indices, + num_itopk, + candidate_distances, + candidate_indices, + num_candidates, + work_buf, + first, + MULTI_WARPS_2); +} + +// This function move the invalid index element to the end of the itopk list. +// Require : array_length % 32 == 0 && The invalid entry is only one. +template +RAFT_DEVICE_INLINE_FUNCTION void move_invalid_to_end_of_list(IdxT* const index_array, + float* const distance_array, + const std::uint32_t array_length) +{ + constexpr std::uint32_t warp_size = 32; + constexpr std::uint32_t invalid_index = utils::get_max_value(); + const std::uint32_t lane_id = threadIdx.x % warp_size; + + if (threadIdx.x >= warp_size) { return; } + + bool found_invalid = false; + if (array_length % warp_size == 0) { + for (std::uint32_t i = lane_id; i < array_length; i += warp_size) { + const auto index = index_array[i]; + const auto distance = distance_array[i]; + + if (found_invalid) { + index_array[i - 1] = index; + distance_array[i - 1] = distance; + } else { + // Check if the index is invalid + const auto I_found_invalid = (index == invalid_index); + const auto who_has_invalid = raft::ballot(I_found_invalid); + // if a value that is loaded by a smaller lane id thread, shift the array + if (who_has_invalid << (warp_size - lane_id)) { + index_array[i - 1] = index; + distance_array[i - 1] = distance; + } + + found_invalid = who_has_invalid; + } + } + } + if (lane_id == 0) { + index_array[array_length - 1] = invalid_index; + distance_array[array_length - 1] = utils::get_max_value(); + } } template @@ -589,10 +631,10 @@ __device__ void search_core( // sort if constexpr (TOPK_BY_BITONIC_SORT) { // [Notice] - // It is good to use multiple warps in topk_by_bitonic_sort() when + // It is good to use multiple warps in topk_by_bitonic_sort_and_merge() when // batch size is small (short-latency), but it might not be always good // when batch size is large (high-throughput). - // topk_by_bitonic_sort() consists of two operations: + // topk_by_bitonic_sort_and_merge() consists of two operations: // if MAX_CANDIDATES is greater than 128, the first operation uses two warps; // if MAX_ITOPK is greater than 256, the second operation used two warps. const unsigned multi_warps_1 = ((blockDim.x >= 64) && (MAX_CANDIDATES > 128)) ? 1 : 0; @@ -601,9 +643,9 @@ __device__ void search_core( // reset small-hash table. if ((iter + 1) % small_hash_reset_interval == 0) { // Depending on the block size and the number of warps used in - // topk_by_bitonic_sort(), determine which warps are used to reset + // topk_by_bitonic_sort_and_merge(), determine which warps are used to reset // the small hash and whether they are performed in overlap with - // topk_by_bitonic_sort(). + // topk_by_bitonic_sort_and_merge(). _CLK_START(); unsigned hash_start_tid; if (blockDim.x == 32) { @@ -627,28 +669,28 @@ __device__ void search_core( // topk with bitonic sort _CLK_START(); - if (std::is_same::value || - *filter_flag == 0) { - topk_by_bitonic_sort(result_distances_buffer, - result_indices_buffer, - internal_topk, - result_distances_buffer + internal_topk, - result_indices_buffer + internal_topk, - search_width * graph_degree, - topk_ws, - (iter == 0), - multi_warps_1, - multi_warps_2); - __syncthreads(); - } else { - topk_by_bitonic_sort_1st( - result_distances_buffer, - result_indices_buffer, - internal_topk + search_width * graph_degree, - internal_topk, - false); + if (!(std::is_same::value || + *filter_flag == 0)) { + // Move the filtered out index to the end of the itopk list + for (unsigned i = 0; i < search_width; i++) { + move_invalid_to_end_of_list( + result_indices_buffer, result_distances_buffer, internal_topk); + } + if (threadIdx.x == 0) { *terminate_flag = 0; } } + topk_by_bitonic_sort_and_merge( + result_distances_buffer, + result_indices_buffer, + internal_topk, + result_distances_buffer + internal_topk, + result_indices_buffer + internal_topk, + search_width * graph_degree, + topk_ws, + (iter == 0), + multi_warps_1, + multi_warps_2); + __syncthreads(); _CLK_REC(clk_topk); } else { _CLK_START(); @@ -755,12 +797,66 @@ __device__ void search_core( } __syncthreads(); - topk_by_bitonic_sort_1st( - result_distances_buffer, - result_indices_buffer, - internal_topk + search_width * graph_degree, - top_k, - false); + // Move invalid index items to the end of the buffer without sorting the entire buffer + using scan_op_t = cub::WarpScan; + auto& temp_storage = *reinterpret_cast(smem_work_ptr); + + constexpr std::uint32_t warp_size = 32; + if (threadIdx.x < warp_size) { + std::uint32_t num_found_valid = 0; + for (std::uint32_t buffer_offset = 0; buffer_offset < internal_topk; + buffer_offset += warp_size) { + // Calculate the new buffer index + const auto src_position = buffer_offset + threadIdx.x; + const std::uint32_t is_valid_index = + (result_indices_buffer[src_position] & (~index_msb_1_mask)) == invalid_index ? 0 : 1; + std::uint32_t new_position; + scan_op_t(temp_storage).InclusiveSum(is_valid_index, new_position); + if (is_valid_index) { + const auto dst_position = num_found_valid + (new_position - 1); + result_indices_buffer[dst_position] = result_indices_buffer[src_position]; + result_distances_buffer[dst_position] = result_distances_buffer[src_position]; + } + + // Calculate the largest valid position within a warp and bcast it for the next iteration + num_found_valid += new_position; + for (std::uint32_t offset = (warp_size >> 1); offset > 0; offset >>= 1) { + const auto v = raft::shfl_xor(num_found_valid, offset); + if ((threadIdx.x & offset) == 0) { num_found_valid = v; } + } + + // If the enough number of items are found, do early termination + if (num_found_valid >= top_k) { break; } + } + + if (num_found_valid < top_k) { + // Fill the remaining buffer with invalid values so that `topk_by_bitonic_sort_and_merge` is + // usable in the next step + for (std::uint32_t i = num_found_valid + threadIdx.x; i < internal_topk; i += warp_size) { + result_indices_buffer[i] = invalid_index; + result_distances_buffer[i] = utils::get_max_value(); + } + } + } + + // If the sufficient number of valid indexes are not in the internal topk, pick up from the + // candidate list. + if (top_k > internal_topk || result_indices_buffer[top_k - 1] == invalid_index) { + __syncthreads(); + const unsigned multi_warps_1 = ((blockDim.x >= 64) && (MAX_CANDIDATES > 128)) ? 1 : 0; + const unsigned multi_warps_2 = ((blockDim.x >= 64) && (MAX_ITOPK > 256)) ? 1 : 0; + topk_by_bitonic_sort_and_merge( + result_distances_buffer, + result_indices_buffer, + internal_topk, + result_distances_buffer + internal_topk, + result_indices_buffer + internal_topk, + search_width * graph_degree, + topk_ws, + (iter == 0), + multi_warps_1, + multi_warps_2); + } __syncthreads(); } diff --git a/cpp/test/neighbors/ann_cagra.cuh b/cpp/test/neighbors/ann_cagra.cuh index 660246c67..8d5701439 100644 --- a/cpp/test/neighbors/ann_cagra.cuh +++ b/cpp/test/neighbors/ann_cagra.cuh @@ -758,11 +758,7 @@ class AnnCagraFilterTest : public ::testing::TestWithParam { search_params.algo = ps.algo; search_params.max_queries = ps.max_queries; search_params.team_size = ps.team_size; - - // TODO: setting search_params.itopk_size here breaks the filter tests, but is required for - // k>1024 skip these tests until fixed - if (ps.k >= 1024) { GTEST_SKIP(); } - // search_params.itopk_size = ps.itopk_size; + search_params.itopk_size = ps.itopk_size; auto database_view = raft::make_device_matrix_view( (const DataT*)database.data(), ps.n_rows, ps.dim);