[BUG] Fix CAGRA filter #489
Conversation
Can you add a test that would prevent regression?
Thanks for the PR, @enp1s0!
I'm a little bit confused with the description. Do I understand it right that this PR contains two fixes: (1) make the bitonic sort array always a power-of-two, (2) move filtered elements to the end of the topk buffer?
The big chunk of the PR addresses (2), but that should be irrelevant for #472, because in that bug no elements are filtered out.
Therefore, I think, it would be really beneficial to construct a reproducer for #472 as a test case in this PR and make sure it's fixed with the introduced change.
Also, (1) did you have a chance to check if this affects the QPS? (2) do we need a similar fix for multi-cta and multi-kernel versions of CAGRA?
@achirkin, thank you for your comment, and I'm sorry for the bad PR description. I updated it.
No, this PR changes the filtering process so that the bitonic sort is not used to move the invalid elements to the end of the buffer. In the current search implementation, the bitonic sort is used to move the invalid elements in:
cuvs/cpp/src/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh, lines 758 to 763 (at 5062594)
cuvs/cpp/src/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh, lines 644 to 649 (at 5062594)
The problem is that the (max) array length (= MAX_ITOPK + MAX_CANDIDATES) is not always a power of two. Although, as you mentioned, making the bitonic sort array always a power of two is an alternative way to fix this issue, I didn't do it because 1) the array elements other than the filtered-out nodes are already sorted, and 2) padding the bitonic sort array to a power of two would require additional registers that would otherwise go unused. Also, this bug is the cause of a problem in the CAGRA filtering unit test:
cuvs/cpp/test/neighbors/ann_cagra.cuh, line 762 (at 5062594)
When the itopk size is not specified, the default value, 64, is used. The graph degree is also 64. Therefore, MAX_ITOPK (64) + MAX_CANDIDATES (64) equals 128, which is a power of two, so the bitonic sort works correctly in this case. However, if the itopk size is set to another value, the bitonic sort does not work.
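(To make the power-of-two requirement concrete, here is a tiny standalone check; it is illustrative only, not cuvs code, and the itopk of 96 below is just an example of a non-default value, not one taken from the PR.)

```cpp
#include <cstdint>

// True iff n is a power of two -- the length a bitonic sorting network expects.
constexpr bool is_power_of_two(std::uint32_t n) { return n != 0 && (n & (n - 1)) == 0; }

static_assert(is_power_of_two(64 + 64),  "default: MAX_ITOPK (64) + MAX_CANDIDATES (64) = 128 = 2^7");
static_assert(!is_power_of_two(96 + 64), "e.g. an itopk of 96 gives 160, which is not a power of two");
```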
Yes, so I re-enabled the test in this PR by changing the following lines to set the itopk size correctly (@lowener):
cuvs/cpp/test/neighbors/ann_cagra.cuh, lines 762 to 765 (at 5062594)
I measured the performance of a search with no filtering (the same situation as #472).
No.
Thanks @enp1s0 for the PR and the comprehensive answer. Now that I understand the logic of the change, it looks good to me overall.
Nevertheless, I have a few questions about the design below.
(Resolved review thread on cpp/src/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh; outdated)
```cuda
for (unsigned i = 0; i < search_width; i++) {
  move_invalid_to_end_of_list(
    result_indices_buffer, result_distances_buffer, internal_topk);
}
```
Do I understand that right, that this algorithm moves one index at a time, and repeats this for each candidate in the list? That's O(search_width * (parent_list_buffer + search_width)) complexity?
Maybe we'd better do a prefix scan (to get the indices of the valid items) followed by a shift of all elements (e.g. we can use cub::WarpScan or cub::BlockScan for that)? Or do you think the performance difference would be negligible?
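For reference, a minimal sketch of the scan-then-shift compaction suggested here, using cub::BlockScan. The buffer names, the INVALID sentinel, and the single-block / n <= BLOCK_THREADS restriction are illustrative assumptions, not the PR's code:

```cuda
#include <cub/block/block_scan.cuh>
#include <cstdint>

constexpr int BLOCK_THREADS        = 128;          // assumes n <= BLOCK_THREADS
constexpr std::uint32_t INVALID    = 0xFFFFFFFFu;  // illustrative sentinel

// Compact valid (index, distance) pairs to the front of the buffers.
__global__ void compact_valid(std::uint32_t* indices, float* distances, int n)
{
  using BlockScan = cub::BlockScan<int, BLOCK_THREADS>;
  __shared__ typename BlockScan::TempStorage temp_storage;

  const int i             = threadIdx.x;
  const bool valid        = (i < n) && (indices[i] != INVALID);
  const std::uint32_t idx = valid ? indices[i] : INVALID;  // stage reads in registers
  const float dist        = valid ? distances[i] : 0.0f;

  int dst = 0;
  BlockScan(temp_storage).ExclusiveSum(valid ? 1 : 0, dst);  // dst = #valid items before i
  __syncthreads();  // all reads are staged before anyone overwrites the buffers

  if (valid) {
    indices[dst]   = idx;
    distances[dst] = dist;
  }
}
```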
The reason I didn't use the scan method is that the maximum number of filtered-out elements here is search_width, which is typically 1. In our experience, increasing search_width does not help improve the search performance much, so it is likely to be a small number. So I use the specialized function for the single-element case here and run it multiple times if search_width is not 1.
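A rough sketch of the kind of single-element helper described here (the name, signature, and single-thread parallelization are hypothetical; the PR's actual move_invalid_to_end_of_list may differ):

```cuda
#include <cfloat>
#include <cstdint>

// Move the first invalid entry to the end of the (otherwise sorted) list by
// shifting the elements behind it forward by one slot.
__device__ void move_one_invalid_to_end(std::uint32_t* indices,
                                        float* distances,
                                        std::uint32_t len,
                                        std::uint32_t invalid_index)
{
  if (threadIdx.x != 0) { return; }  // single-thread version for clarity
  for (std::uint32_t i = 0; i < len; ++i) {
    if (indices[i] != invalid_index) { continue; }
    for (std::uint32_t j = i; j + 1 < len; ++j) {  // shift the valid tail left by one
      indices[j]   = indices[j + 1];
      distances[j] = distances[j + 1];
    }
    indices[len - 1]   = invalid_index;  // park the invalid entry at the end
    distances[len - 1] = FLT_MAX;
    break;                               // this helper moves exactly one element
  }
}
```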
Makes sense, thanks for the explanation. For some reason I thought we could have more than search_width new candidates each iteration.
(Resolved review thread on cpp/src/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh; outdated)
```cuda
constexpr std::uint32_t warp_size = 32;
if (threadIdx.x < warp_size) {
  std::uint32_t num_found_valid = 0;
  for (std::uint32_t buffer_offset = 0; buffer_offset < internal_topk;
       buffer_offset += warp_size) {
    // Calculate the new buffer index
    const auto src_position = buffer_offset + threadIdx.x;
    const std::uint32_t is_valid_index =
      (result_indices_buffer[src_position] & (~index_msb_1_mask)) == invalid_index ? 0 : 1;
    std::uint32_t new_position;
    scan_op_t(temp_storage).InclusiveSum(is_valid_index, new_position);
    if (is_valid_index) {
      const auto dst_position = num_found_valid + (new_position - 1);
      result_indices_buffer[dst_position]   = result_indices_buffer[src_position];
      result_distances_buffer[dst_position] = result_distances_buffer[src_position];
    }

    // Calculate the largest valid position within a warp and broadcast it for the next iteration
    num_found_valid += new_position;
    for (std::uint32_t offset = (warp_size >> 1); offset > 0; offset >>= 1) {
      const auto v = __shfl_xor_sync(~0u, num_found_valid, offset);
      if ((threadIdx.x & offset) == 0) { num_found_valid = v; }
    }

    // If enough items have been found, terminate early
    if (num_found_valid >= top_k) { break; }
  }

  if (num_found_valid < top_k) {
    // Fill the remaining buffer with invalid values so that `topk_by_bitonic_sort` is usable in
    // the next step
    for (std::uint32_t i = num_found_valid + threadIdx.x; i < internal_topk; i += warp_size) {
      result_indices_buffer[i]   = invalid_index;
      result_distances_buffer[i] = utils::get_max_value<DISTANCE_T>();
    }
  }
}
```
I've just realized you do here exactly what I wanted to suggest in the comment above. Could you please put this in a separate function for better readability? And then consider if it makes sense to re-use that in place of the move_invalid_to_end_of_list loop above?
I use this code here because the number of filtered-out nodes is unknown, while the maximum number is known in the place above. If the maximum number of filtered-out nodes is known and small, we can use simpler code that does not require the scan operation, as above (although I didn't compare the performance experimentally).
```cuda
// If a sufficient number of valid indices are not in the internal topk, pick them up from the
// candidate list.
if (top_k > internal_topk || result_indices_buffer[top_k - 1] == invalid_index) {
```
Do I understand it right, that we need this because we filter the result indices buffer after we move candidates from the internal workspace, which is larger?
If so, why shouldn't we first filter the internal workspace and only then copy the results instead?
What the code here does is 1) pick up valid nodes (= nodes that are not filtered out) from the candidate list by the bitonic sort, and 2) concatenate the resulting list with the itopk valid-node list. This operation is needed when enough valid indices for the topk cannot be obtained from the itopk valid-node list alone.
I think the function name topk_by_bitonic_sort is problematic because it is not a simple sort: it takes a sorted list A and an unsorted list B and outputs merge_sort(A, bitonic_sort(B)). I'll change the name.
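As an aside, the shape of the renamed operation can be sketched on the host as follows (the function names and the assumption that A and B have equal power-of-two lengths are illustrative, not the PR's code): sort B descending as a stand-in for bitonic_sort(B), append it to the ascending A so the whole sequence is bitonic, then run one bitonic merge.

```cpp
#include <algorithm>
#include <cstddef>
#include <functional>
#include <vector>

// One bitonic merge over a bitonic sequence of power-of-two length n.
void bitonic_merge(std::vector<float>& v, std::size_t lo, std::size_t n)
{
  if (n < 2) { return; }
  const std::size_t half = n / 2;
  for (std::size_t i = lo; i < lo + half; ++i) {
    if (v[i] > v[i + half]) { std::swap(v[i], v[i + half]); }
  }
  bitonic_merge(v, lo, half);  // both halves are bitonic after the pass
  bitonic_merge(v, lo + half, half);
}

// merge_sort(A, bitonic_sort(B)): A is already sorted ascending, B is unsorted.
std::vector<float> sort_and_merge(const std::vector<float>& a, std::vector<float> b)
{
  std::sort(b.begin(), b.end(), std::greater<float>());  // stand-in for bitonic_sort(B)
  std::vector<float> out(a);                             // ascending ++ descending = bitonic
  out.insert(out.end(), b.begin(), b.end());
  bitonic_merge(out, 0, out.size());
  return out;
}
```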
Co-authored-by: Artem M. Chirkin <[email protected]>
@achirkin Thank you for your review! I fixed the code, so can you check it again?
Thank you for the updates! LGTM
```cuda
std::uint32_t num_found_valid = 0;
for (std::uint32_t buffer_offset = 0; buffer_offset < internal_topk;
     buffer_offset += warp_size) {
  // Calculate the new buffer index
  const auto src_position = buffer_offset + threadIdx.x;
  const std::uint32_t is_valid_index =
    (result_indices_buffer[src_position] & (~index_msb_1_mask)) == invalid_index ? 0 : 1;
  std::uint32_t new_position;
  scan_op_t(temp_storage).InclusiveSum(is_valid_index, new_position);
  if (is_valid_index) {
    const auto dst_position = num_found_valid + (new_position - 1);
    result_indices_buffer[dst_position]   = result_indices_buffer[src_position];
    result_distances_buffer[dst_position] = result_distances_buffer[src_position];
  }

  // Calculate the largest valid position within a warp and broadcast it for the next iteration
  num_found_valid += new_position;
  for (std::uint32_t offset = (warp_size >> 1); offset > 0; offset >>= 1) {
    const auto v = raft::shfl_xor(num_found_valid, offset);
    if ((threadIdx.x & offset) == 0) { num_found_valid = v; }
  }

  // If enough items have been found, terminate early
  if (num_found_valid >= top_k) { break; }
}

if (num_found_valid < top_k) {
  // Fill the remaining buffer with invalid values so that `topk_by_bitonic_sort_and_merge` is
  // usable in the next step
  for (std::uint32_t i = num_found_valid + threadIdx.x; i < internal_topk; i += warp_size) {
    result_indices_buffer[i]   = invalid_index;
    result_distances_buffer[i] = utils::get_max_value<DISTANCE_T>();
  }
}
}
```
Even if we use this routine only once, I still think it would be nice to move it out as a separate function alongside move_invalid_to_end_of_list (so that we'd have two: move_first_invalid_to_end_of_list and move_all_invalid_to_end_of_list).
But given that we're very close to the code freeze and this PR is very important, I'd say we can postpone this grooming to the next release.
/merge
Ref: #472

The cause of the bug

The bitonic sort was used on an array whose length was not a power of two. In the current search implementation, the bitonic sort is used to move the invalid elements to the end of the buffer in:

cuvs/cpp/src/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh, lines 758 to 763 (at 5062594)
cuvs/cpp/src/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh, lines 644 to 649 (at 5062594)

The problem is that the (max) array length (= MAX_ITOPK + MAX_CANDIDATES) is not always a power of two. These bitonic sorts are called even if no elements are filtered out, unless cuvs::neighbors::filtering::none_sample_filter is specified as the filter, so #472 occurs.

Fix

This PR changes the filtering process so that the bitonic sort is not used to move the invalid elements to the end of the buffer.
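To see the failure mode in isolation, here is a standalone demonstration (a textbook bitonic network, not the cuvs implementation) that such a network silently mis-sorts when the array length is not a power of two:

```cpp
#include <cstdio>
#include <vector>

// Textbook bitonic sorting network; only correct when v.size() is a power of two.
void bitonic_sort_network(std::vector<int>& v)
{
  const std::size_t n = v.size();
  for (std::size_t k = 2; k <= n; k <<= 1) {    // with n = 6 the k = 8 stage never runs
    for (std::size_t j = k >> 1; j > 0; j >>= 1) {
      for (std::size_t i = 0; i < n; ++i) {
        const std::size_t l = i ^ j;
        if (l <= i || l >= n) { continue; }     // out-of-range partners are silently skipped
        const bool ascending = ((i & k) == 0);
        if ((v[i] > v[l]) == ascending) { std::swap(v[i], v[l]); }
      }
    }
  }
}

int main()
{
  std::vector<int> v{5, 4, 3, 2, 1, 0};  // length 6: not a power of two
  bitonic_sort_network(v);
  for (int x : v) { std::printf("%d ", x); }  // prints "2 3 4 5 1 0" -- not sorted
  std::printf("\n");
  return 0;
}
```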