[BUG] Fix CAGRA filter #489

Merged
merged 19 commits into from
Dec 4, 2024

Conversation

enp1s0
Member

@enp1s0 enp1s0 commented Nov 23, 2024

Ref: #472

The cause of the bug

The bitonic sort was applied to an array whose length is not a power of two. In the current search implementation, the bitonic sort is used to move the invalid elements to the end of the buffer, as follows:

topk_by_bitonic_sort_1st<MAX_ITOPK + MAX_CANDIDATES>(
  result_distances_buffer,
  result_indices_buffer,
  internal_topk + search_width * graph_degree,
  top_k,
  false);

topk_by_bitonic_sort_1st<MAX_ITOPK + MAX_CANDIDATES>(
  result_distances_buffer,
  result_indices_buffer,
  internal_topk + search_width * graph_degree,
  internal_topk,
  false);

The problem is that the (max) array length (=MAX_ITOPK + MAX_CANDIDATES) is not always a power of two.
These bitonic sorts are called even if no elements are filtered out, unless cuvs::neighbors::filtering::none_sample_filter is specified as the filter, which is why #472 occurs.

Fix

This PR changes the filtering process so that the bitonic sort is not used to move the invalid elements to the end of the buffer.

@enp1s0 enp1s0 requested a review from a team as a code owner November 23, 2024 15:58
@enp1s0 enp1s0 self-assigned this Nov 23, 2024
@github-actions github-actions bot added the cpp label Nov 23, 2024
@enp1s0 enp1s0 added bug Something isn't working non-breaking Introduces a non-breaking change labels Nov 23, 2024
@enp1s0 enp1s0 changed the title Fix CAGRA filter [BUG] Fix CAGRA filter Nov 23, 2024
@lowener
Contributor

lowener commented Nov 24, 2024

Can you add a test that would prevent regression?

Contributor

@achirkin achirkin left a comment

Thanks for the PR, @enp1s0!
I'm a little bit confused by the description. Do I understand it right that this PR contains two fixes: (1) make the bitonic sort array always a power-of-two, (2) move filtered elements to the end of the topk buffer?
The big chunk of the PR addresses (2), but that should be irrelevant for #472, because in that bug no elements are filtered out.
Therefore, I think, it would be really beneficial to construct a reproducer for #472 as a test case in this PR and make sure it's fixed with the introduced change.

@achirkin
Contributor

achirkin commented Nov 26, 2024

Also, (1) did you have a chance to check if this affects the QPS? (2) do we need a similar fix for multi-cta and multi-kernel versions of CAGRA?

@enp1s0
Member Author

enp1s0 commented Nov 26, 2024

@achirkin, thank you for your comment, and I'm sorry for the bad PR description. I updated it.

Do I understand it right that this PR contains two fixes: (1) make the bitonic sort array always a power-of-two, (2) move filtered elements to the end of the topk buffer?

No, this PR changes the filtering process so that the bitonic sort is not used to move the invalid elements to the end of the buffer. In the current search implementation, the bitonic sort is used to move the invalid elements, as follows:

topk_by_bitonic_sort_1st<MAX_ITOPK + MAX_CANDIDATES>(
  result_distances_buffer,
  result_indices_buffer,
  internal_topk + search_width * graph_degree,
  top_k,
  false);

topk_by_bitonic_sort_1st<MAX_ITOPK + MAX_CANDIDATES>(
  result_distances_buffer,
  result_indices_buffer,
  internal_topk + search_width * graph_degree,
  internal_topk,
  false);

The problem is that the (max) array length (=MAX_ITOPK + MAX_CANDIDATES) is not always a power of two.
The second bitonic sort is called even if no elements are filtered out, unless cuvs::neighbors::filtering::none_sample_filter is specified as the filter, which is why #472 occurs.

Although, as you mentioned, making the bitonic sort array always a power of two is an alternative way to fix this issue, I didn't do it because 1) the array elements other than the filtered-out nodes are already sorted, and 2) padding the array to a power-of-two length would require additional registers that hold no useful data.
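
To illustrate the cost described in point 2, here is a minimal host-side C++ sketch of the padding alternative. It is not the cuVS kernel code: it assumes a textbook bitonic network, and `next_power_of_two`, `bitonic_sort_pow2`, and `bitonic_sort_padded` are hypothetical names introduced only for this example.

```cpp
#include <cassert>
#include <cstddef>
#include <limits>
#include <utility>
#include <vector>

// Hypothetical helper: smallest power of two that is >= n (for n >= 1).
inline std::size_t next_power_of_two(std::size_t n)
{
  std::size_t p = 1;
  while (p < n) { p <<= 1; }
  return p;
}

// Textbook in-place bitonic sort (ascending). Only valid for power-of-two lengths.
inline void bitonic_sort_pow2(std::vector<float>& a)
{
  const std::size_t n = a.size();
  assert(n != 0 && (n & (n - 1)) == 0);
  for (std::size_t k = 2; k <= n; k <<= 1) {
    for (std::size_t j = k >> 1; j > 0; j >>= 1) {
      for (std::size_t i = 0; i < n; ++i) {
        const std::size_t partner = i ^ j;
        if (partner <= i) { continue; }  // each pair is handled once
        const bool ascending = (i & k) == 0;
        if ((a[i] > a[partner]) == ascending) { std::swap(a[i], a[partner]); }
      }
    }
  }
}

// Sorts an arbitrary-length array by padding it with +inf up to the next power of two.
// The padded slots carry no data; they exist only to satisfy the sorting network.
inline std::vector<float> bitonic_sort_padded(std::vector<float> a)
{
  const std::size_t n = a.size();
  if (n == 0) { return a; }
  a.resize(next_power_of_two(n), std::numeric_limits<float>::infinity());
  bitonic_sort_pow2(a);
  a.resize(n);  // the padding sorts to the tail and is dropped here
  return a;
}
```

For a 192-element buffer the padded length is 256, so 64 slots hold nothing but padding. In a register-resident sort, that padding is the extra register cost mentioned above.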

Also, this bug is the cause of a problem in the CAGRA filtering unit test:

// TODO: setting search_params.itopk_size here breaks the filter tests, but is required for

When itopk_size is not specified, the default value, 64, is used. The graph degree is also 64. Therefore, MAX_ITOPK (64) + MAX_CANDIDATES (64) equals 128, and the bitonic sort works correctly in this case. However, if the itopk size is set to another value, the sum may no longer be a power of two, and the bitonic sort does not work correctly.
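
The arithmetic can be checked at compile time. In this sketch, `is_power_of_two` is a hypothetical helper, and the 128 + 64 pairing is an assumed example of a non-default configuration, not a value taken from this PR.

```cpp
#include <cstdint>

// Hypothetical helper (not from the cuVS sources): true when n is a power of two.
constexpr bool is_power_of_two(std::uint32_t n) { return n != 0 && (n & (n - 1)) == 0; }

// Default case from the discussion: itopk size 64 and graph degree 64.
static_assert(is_power_of_two(64 + 64), "128 is a power of two");

// Assumed non-default pairing: 128 + 64 = 192 breaks the sorting network's requirement.
static_assert(!is_power_of_two(128 + 64), "192 is not a power of two");
```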

I think, it would be really beneficial to construct a reproducer for #472 as a test case in this PR and make sure it's fixed with the introduced change.

Yes, so I reenabled the test in this PR by changing the following lines to set the itopk size correctly. (@lowener)

// TODO: setting search_params.itopk_size here breaks the filter tests, but is required for
// k>1024 skip these tests until fixed
if (ps.k >= 1024) { GTEST_SKIP(); }
// search_params.itopk_size = ps.itopk_size;

did you have a chance to check if this affects the QPS?

I measured the search performance when no elements are filtered out (the same situation as #472):

[Figure: filtering-bug]

do we need a similar fix for multi-cta and multi-kernel versions of CAGRA?

No.

  • In the case of multi-CTA, the bitonic sort that moves the invalid elements already operates on a power-of-two array, so there is no need to change it. (We use a bitonic sort here because the array size is relatively small (32+graph_degree), so it does not add register pressure.)
  • In the case of multi-kernel, the _find_topk routine is called, and this bug is not related.

Contributor

@achirkin achirkin left a comment

Thanks @enp1s0 for the PR and the comprehensive answer. Now that I understand the logic of the change, it looks good to me overall.
Nevertheless, I have a few questions about the design below.

Comment on lines +675 to +678
for (unsigned i = 0; i < search_width; i++) {
  move_invalid_to_end_of_list(
    result_indices_buffer, result_distances_buffer, internal_topk);
}
Contributor

Do I understand correctly that this algorithm moves one index at a time and repeats this for each candidate in the list? Would that be O(search_width*(parent_list_buffer + search_width)) complexity?

Maybe we'd better do a prefix scan (to get the indices of the valid items) followed by a shift of all elements (e.g. we can use cub::WarpScan or cub::BlockScan for that)? Or do you think the performance difference would be negligible?

Member Author

The reason I didn't use the scan method is that the maximum number of filtered-out elements here is search_width, which is typically 1. In our experience, increasing search_width does not improve the search performance much, so it is likely to be a small number. So I use a function specialized for a single invalid element here and run it multiple times if search_width is not 1.
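
As a rough host-side illustration of this single-element approach (a sketch under stated assumptions, not the kernel code in this PR): `move_first_invalid_to_end` and `kInvalidIndex` are hypothetical names, the sentinel value is assumed, and the MSB masking applied to indices in the real kernel is ignored.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <vector>

// Assumed sentinel for a filtered-out entry (hypothetical; the kernel also masks the MSB).
constexpr std::uint32_t kInvalidIndex = std::numeric_limits<std::uint32_t>::max();

// Moves the first invalid entry in [0, length) to position length - 1 and shifts the
// entries behind it forward by one, so the order of the valid entries is preserved.
// Returns true if an invalid entry was found.
inline bool move_first_invalid_to_end(std::vector<std::uint32_t>& indices,
                                      std::vector<float>& distances,
                                      std::size_t length)
{
  assert(indices.size() == distances.size() && length <= indices.size());
  std::size_t pos = 0;
  while (pos < length && indices[pos] != kInvalidIndex) { ++pos; }
  if (pos == length) { return false; }
  for (std::size_t i = pos; i + 1 < length; ++i) {
    indices[i]   = indices[i + 1];
    distances[i] = distances[i + 1];
  }
  indices[length - 1]   = kInvalidIndex;
  distances[length - 1] = std::numeric_limits<float>::max();
  return true;
}
```

Calling the helper search_width times removes up to search_width invalid entries. Each call is one linear pass, so the repeated-call form stays cheap as long as search_width is small.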

Contributor

Makes sense, thanks for the explanation. For some reason I thought we could have more than search_width new candidates each iteration.

Comment on lines 803 to 839
constexpr std::uint32_t warp_size = 32;
if (threadIdx.x < warp_size) {
  std::uint32_t num_found_valid = 0;
  for (std::uint32_t buffer_offset = 0; buffer_offset < internal_topk;
       buffer_offset += warp_size) {
    // Calculate the new buffer index
    const auto src_position = buffer_offset + threadIdx.x;
    const std::uint32_t is_valid_index =
      (result_indices_buffer[src_position] & (~index_msb_1_mask)) == invalid_index ? 0 : 1;
    std::uint32_t new_position;
    scan_op_t(temp_storage).InclusiveSum(is_valid_index, new_position);
    if (is_valid_index) {
      const auto dst_position = num_found_valid + (new_position - 1);
      result_indices_buffer[dst_position] = result_indices_buffer[src_position];
      result_distances_buffer[dst_position] = result_distances_buffer[src_position];
    }

    // Calculate the largest valid position within a warp and bcast it for the next iteration
    num_found_valid += new_position;
    for (std::uint32_t offset = (warp_size >> 1); offset > 0; offset >>= 1) {
      const auto v = __shfl_xor_sync(~0u, num_found_valid, offset);
      if ((threadIdx.x & offset) == 0) { num_found_valid = v; }
    }

    // If the enough number of items are found, do early termination
    if (num_found_valid >= top_k) { break; }
  }

  if (num_found_valid < top_k) {
    // Fill the remaining buffer with invalid values so that `topk_by_bitonic_sort` is usable in
    // the next step
    for (std::uint32_t i = num_found_valid + threadIdx.x; i < internal_topk; i += warp_size) {
      result_indices_buffer[i] = invalid_index;
      result_distances_buffer[i] = utils::get_max_value<DISTANCE_T>();
    }
  }
}
Contributor

I've just realized you do here exactly what I wanted to suggest in the comment above. Could you please put this in a separate function for better readability? And then consider if it makes sense to re-use that in place of the move_invalid_to_end_of_list loop above?

Member Author

I use this code here because the number of filtered-out nodes is unknown, while the maximum number is known in the place above. If the maximum number of filtered-out nodes is known and small, we can use simpler code that does not need the scan operation used above (although I didn't compare the performance experimentally).
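
For readers following the scan-based variant, here is a sequential host-side sketch of the same idea. It is a simplification under stated assumptions, not the kernel: each 32-entry chunk stands in for one warp-wide pass, the running counter plays the role of the inclusive sum, `move_all_invalid_to_end` and `kInvalidIndex` are hypothetical names, and the MSB masking is omitted.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <vector>

constexpr std::uint32_t kInvalidIndex = std::numeric_limits<std::uint32_t>::max();  // assumed sentinel
constexpr std::size_t kWarpSize       = 32;

// Compacts the valid entries of [0, internal_topk) to the front, one warp-sized chunk at
// a time, and stops early once top_k valid entries have been found.
// Returns the number of valid entries found before stopping.
inline std::size_t move_all_invalid_to_end(std::vector<std::uint32_t>& indices,
                                           std::vector<float>& distances,
                                           std::size_t internal_topk,
                                           std::size_t top_k)
{
  assert(indices.size() == distances.size() && internal_topk <= indices.size());
  std::size_t num_found_valid = 0;
  for (std::size_t offset = 0; offset < internal_topk; offset += kWarpSize) {
    std::size_t running = 0;  // inclusive prefix sum of the validity flags in this chunk
    for (std::size_t lane = 0; lane < kWarpSize && offset + lane < internal_topk; ++lane) {
      const std::size_t src = offset + lane;
      if (indices[src] == kInvalidIndex) { continue; }
      ++running;
      // dst <= src always holds, so the in-place copy never overwrites an unread entry.
      const std::size_t dst = num_found_valid + (running - 1);
      indices[dst]          = indices[src];
      distances[dst]        = distances[src];
    }
    num_found_valid += running;
    if (num_found_valid >= top_k) { break; }  // enough valid entries: terminate early
  }
  if (num_found_valid < top_k) {
    // Mark the tail as invalid so a later sort-and-merge step sees a clean buffer.
    for (std::size_t i = num_found_valid; i < internal_topk; ++i) {
      indices[i]   = kInvalidIndex;
      distances[i] = std::numeric_limits<float>::max();
    }
  }
  return num_found_valid;
}
```

Unlike a helper that moves one entry per call, this form handles any number of invalid entries in a single pass over the buffer, which is why it suits the case where the count is unknown.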

Comment on lines +841 to +843
// If the sufficient number of valid indexes are not in the internal topk, pick up from the
// candidate list.
if (top_k > internal_topk || result_indices_buffer[top_k - 1] == invalid_index) {
Contributor

Do I understand it right, that we need this because we filter the result indices buffer after we move candidates from the internal workspace, which is larger?
If so, why shouldn't we first filter the internal workspace and only then copy the results instead?

Member Author

What the code here does is 1) pick up valid nodes (=not filtered-out nodes) from the candidate list by the bitonic sort and 2) concatenate the resulting list and the itopk valid nodes list. This operation is needed when the itopk valid node list alone does not contain enough valid indices for topk.

I think the function name topk_by_bitonic_sort is problematic because it is not a simple sort but takes a sorted list A and an unsorted list B and outputs merge_sort(A, bitonic_sort(B)). I'll change the name.
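
The merge_sort(A, bitonic_sort(B)) behaviour described above can be sketched on the host as follows. This is an illustration only: `sort_and_merge_topk` is a hypothetical name, `std::sort` stands in for the bitonic sort of B, and plain distances are used instead of (distance, index) pairs.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <iterator>
#include <vector>

// Returns the k smallest values of A and B combined, where A is already sorted in
// ascending order and B is unsorted.
inline std::vector<float> sort_and_merge_topk(const std::vector<float>& sorted_a,
                                              std::vector<float> unsorted_b,
                                              std::size_t k)
{
  assert(std::is_sorted(sorted_a.begin(), sorted_a.end()));
  std::sort(unsorted_b.begin(), unsorted_b.end());  // stand-in for the bitonic sort of B
  std::vector<float> merged;
  merged.reserve(sorted_a.size() + unsorted_b.size());
  std::merge(sorted_a.begin(),
             sorted_a.end(),
             unsorted_b.begin(),
             unsorted_b.end(),
             std::back_inserter(merged));
  merged.resize(std::min(k, merged.size()));  // keep only the top-k
  return merged;
}
```

Only B pays the sorting cost here; A is consumed as-is by the merge. That matches the point that the routine is more than a plain sort.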

@enp1s0
Member Author

enp1s0 commented Dec 3, 2024

@achirkin Thank you for your review! I fixed the code, so can you check it again?

Contributor

@achirkin achirkin left a comment

Thank you for the updates! LGTM

Comment on lines +806 to +840
  std::uint32_t num_found_valid = 0;
  for (std::uint32_t buffer_offset = 0; buffer_offset < internal_topk;
       buffer_offset += warp_size) {
    // Calculate the new buffer index
    const auto src_position = buffer_offset + threadIdx.x;
    const std::uint32_t is_valid_index =
      (result_indices_buffer[src_position] & (~index_msb_1_mask)) == invalid_index ? 0 : 1;
    std::uint32_t new_position;
    scan_op_t(temp_storage).InclusiveSum(is_valid_index, new_position);
    if (is_valid_index) {
      const auto dst_position = num_found_valid + (new_position - 1);
      result_indices_buffer[dst_position] = result_indices_buffer[src_position];
      result_distances_buffer[dst_position] = result_distances_buffer[src_position];
    }

    // Calculate the largest valid position within a warp and bcast it for the next iteration
    num_found_valid += new_position;
    for (std::uint32_t offset = (warp_size >> 1); offset > 0; offset >>= 1) {
      const auto v = raft::shfl_xor(num_found_valid, offset);
      if ((threadIdx.x & offset) == 0) { num_found_valid = v; }
    }

    // If the enough number of items are found, do early termination
    if (num_found_valid >= top_k) { break; }
  }

  if (num_found_valid < top_k) {
    // Fill the remaining buffer with invalid values so that `topk_by_bitonic_sort_and_merge` is
    // usable in the next step
    for (std::uint32_t i = num_found_valid + threadIdx.x; i < internal_topk; i += warp_size) {
      result_indices_buffer[i] = invalid_index;
      result_distances_buffer[i] = utils::get_max_value<DISTANCE_T>();
    }
  }
}
Contributor

Even if we use this routine only once, I still think it would be nice to move it out as a separate function alongside the move_invalid_to_end_of_list (so that we'd have two: move_first_invalid_to_end_of_list and move_all_invalid_to_end_of_list).
But given that we're very close to the code freeze and this PR is very important, I'd say we can postpone this grooming to the next release.

@achirkin
Contributor

achirkin commented Dec 4, 2024

/merge

@rapids-bot rapids-bot bot merged commit acbd097 into rapidsai:branch-24.12 Dec 4, 2024
55 checks passed
@enp1s0 enp1s0 deleted the fix-cagra-filter branch December 4, 2024 09:09