-
Notifications
You must be signed in to change notification settings - Fork 73
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Improve multi-CTA algorithm #492
Open
anaruse
wants to merge
14
commits into
rapidsai:branch-25.02
Choose a base branch
from
anaruse:improved_multi_cta_algo
base: branch-25.02
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
14 commits
Select commit
Hold shift + click to select a range
6223fd2
[Improved Multi-CTA algo] Address low recall issue of multi-CTA algo …
anaruse 8ff6991
Merge branch 'branch-24.12' into improved_multi_cta_algo
anaruse 37e26c1
fix style
tfeher 3665d45
Merge branch 'branch-24.12' into improved_multi_cta_algo
anaruse 018e792
Merge branch 'branch-25.02' into improved_multi_cta_algo
achirkin ab1130b
Check if CAGRA search returns enough valid indices during add_nodes
achirkin bedd224
Resolving various issues with the new multi-CTA algorithm
anaruse ea8c273
Add comments in add_nodes.cuh
anaruse 5025481
Limit the number of warnings output
anaruse b61126a
Avoid invalid results in search results as much as possible
anaruse 588bd0c
Improve the accuracy of the new multi-CTA algo by revising the usage …
anaruse 228a1ae
Reduce the number of shared memory access
anaruse 776f2f5
Remove unused code
anaruse 9d262f7
Merge branch 'branch-25.02' into improved_multi_cta_algo
cjnolet File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -109,7 +109,9 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_random_nodes( | |||||
const IndexT* __restrict__ seed_ptr, // [num_seeds] | ||||||
const uint32_t num_seeds, | ||||||
IndexT* __restrict__ visited_hash_ptr, | ||||||
const uint32_t hash_bitlen, | ||||||
const uint32_t visited_hash_bitlen, | ||||||
IndexT* __restrict__ traversed_hash_ptr, | ||||||
const uint32_t traversed_hash_bitlen, | ||||||
const uint32_t block_id = 0, | ||||||
const uint32_t num_blocks = 1) | ||||||
{ | ||||||
|
@@ -145,19 +147,29 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_random_nodes( | |||||
|
||||||
const unsigned lane_id = threadIdx.x & ((1u << team_size_bits) - 1u); | ||||||
if (valid_i && lane_id == 0) { | ||||||
if (best_index_team_local != raft::upper_bound<IndexT>() && | ||||||
hashmap::insert(visited_hash_ptr, hash_bitlen, best_index_team_local)) { | ||||||
result_distances_ptr[i] = best_norm2_team_local; | ||||||
result_indices_ptr[i] = best_index_team_local; | ||||||
} else { | ||||||
result_distances_ptr[i] = raft::upper_bound<DistanceT>(); | ||||||
result_indices_ptr[i] = raft::upper_bound<IndexT>(); | ||||||
if (best_index_team_local != raft::upper_bound<IndexT>()) { | ||||||
if (hashmap::insert(visited_hash_ptr, visited_hash_bitlen, best_index_team_local) == 0) { | ||||||
// Deactivate this entry as insertion into visited hash table has failed. | ||||||
best_norm2_team_local = raft::upper_bound<DistanceT>(); | ||||||
best_index_team_local = raft::upper_bound<IndexT>(); | ||||||
} else if ((traversed_hash_ptr != nullptr) && | ||||||
hashmap::search<IndexT, 1>( | ||||||
traversed_hash_ptr, traversed_hash_bitlen, best_index_team_local)) { | ||||||
// Deactivate this entry as it has been already used by others. | ||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
best_norm2_team_local = raft::upper_bound<DistanceT>(); | ||||||
best_index_team_local = raft::upper_bound<IndexT>(); | ||||||
} | ||||||
} | ||||||
result_distances_ptr[i] = best_norm2_team_local; | ||||||
result_indices_ptr[i] = best_index_team_local; | ||||||
} | ||||||
} | ||||||
} | ||||||
|
||||||
template <typename IndexT, typename DistanceT, typename DATASET_DESCRIPTOR_T> | ||||||
template <typename IndexT, | ||||||
typename DistanceT, | ||||||
typename DATASET_DESCRIPTOR_T, | ||||||
int STATIC_RESULT_POSITION = 1> | ||||||
RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_child_nodes( | ||||||
IndexT* __restrict__ result_child_indices_ptr, | ||||||
DistanceT* __restrict__ result_child_distances_ptr, | ||||||
|
@@ -168,13 +180,17 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_child_nodes( | |||||
const uint32_t knn_k, | ||||||
// hashmap | ||||||
IndexT* __restrict__ visited_hashmap_ptr, | ||||||
const uint32_t hash_bitlen, | ||||||
const uint32_t visited_hash_bitlen, | ||||||
IndexT* __restrict__ traversed_hashmap_ptr, | ||||||
const uint32_t traversed_hash_bitlen, | ||||||
const IndexT* __restrict__ parent_indices, | ||||||
const IndexT* __restrict__ internal_topk_list, | ||||||
const uint32_t search_width) | ||||||
const uint32_t search_width, | ||||||
int* __restrict__ result_position = nullptr, | ||||||
const int max_result_position = 0) | ||||||
{ | ||||||
constexpr IndexT index_msb_1_mask = utils::gen_index_msb_1_mask<IndexT>::value; | ||||||
constexpr IndexT invalid_index = raft::upper_bound<IndexT>(); | ||||||
constexpr IndexT invalid_index = ~static_cast<IndexT>(0); | ||||||
|
||||||
// Read child indices of parents from knn graph and check if the distance | ||||||
// computation is necessary. | ||||||
|
@@ -186,11 +202,22 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_child_nodes( | |||||
child_id = knn_graph[(i % knn_k) + (static_cast<int64_t>(knn_k) * parent_id)]; | ||||||
} | ||||||
if (child_id != invalid_index) { | ||||||
if (hashmap::insert(visited_hashmap_ptr, hash_bitlen, child_id) == 0) { | ||||||
if (hashmap::insert(visited_hashmap_ptr, visited_hash_bitlen, child_id) == 0) { | ||||||
// Deactivate this entry as insertion into visited hash table has failed. | ||||||
child_id = invalid_index; | ||||||
} else if ((traversed_hashmap_ptr != nullptr) && | ||||||
hashmap::search<IndexT, 1>( | ||||||
traversed_hashmap_ptr, traversed_hash_bitlen, child_id)) { | ||||||
// Deactivate this entry as this has been already used by others. | ||||||
child_id = invalid_index; | ||||||
} | ||||||
} | ||||||
result_child_indices_ptr[i] = child_id; | ||||||
if (STATIC_RESULT_POSITION) { | ||||||
result_child_indices_ptr[i] = child_id; | ||||||
} else if (child_id != invalid_index) { | ||||||
int j = atomicSub(result_position, 1) - 1; | ||||||
result_child_indices_ptr[j] = child_id; | ||||||
} | ||||||
} | ||||||
__syncthreads(); | ||||||
|
||||||
|
@@ -201,9 +228,11 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_child_nodes( | |||||
const auto compute_distance = dataset_desc.compute_distance_impl; | ||||||
const auto args = dataset_desc.args.load(); | ||||||
const bool lead_lane = (threadIdx.x & ((1u << team_size_bits) - 1u)) == 0; | ||||||
const uint32_t ofst = STATIC_RESULT_POSITION ? 0 : result_position[0]; | ||||||
for (uint32_t i = threadIdx.x >> team_size_bits; i < max_i; i += blockDim.x >> team_size_bits) { | ||||||
const bool valid_i = i < num_k; | ||||||
const auto child_id = valid_i ? result_child_indices_ptr[i] : invalid_index; | ||||||
const auto j = i + ofst; | ||||||
const bool valid_i = STATIC_RESULT_POSITION ? (j < num_k) : (j < max_result_position); | ||||||
const auto child_id = valid_i ? result_child_indices_ptr[j] : invalid_index; | ||||||
|
||||||
// We should be calling `dataset_desc.compute_distance(..)` here as follows: | ||||||
// > const auto child_dist = dataset_desc.compute_distance(child_id, child_id != invalid_index); | ||||||
|
@@ -213,9 +242,10 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_child_nodes( | |||||
(child_id != invalid_index) ? compute_distance(args, child_id) | ||||||
: (lead_lane ? raft::upper_bound<DistanceT>() : 0), | ||||||
team_size_bits); | ||||||
__syncwarp(); | ||||||
|
||||||
// Store the distance | ||||||
if (valid_i && lead_lane) { result_child_distances_ptr[i] = child_dist; } | ||||||
if (valid_i && lead_lane) { result_child_distances_ptr[j] = child_dist; } | ||||||
} | ||||||
} | ||||||
|
||||||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I like that this check is easily understandable now, but are you sure it's ok to add it here in terms of performance? Taking into account that the same complexity loop below (step 2: rank-based reordering) apparently has been worth parallelizing using OMP?