From 228a1aebb73071491ffac22e42e75ef90df27aaf Mon Sep 17 00:00:00 2001 From: Akira Naruse Date: Mon, 6 Jan 2025 17:28:53 +0900 Subject: [PATCH] Reduce the number of shared memory access --- .../neighbors/detail/cagra/device_common.hpp | 28 +++++++------------ .../detail/cagra/search_multi_cta.cuh | 1 - .../cagra/search_multi_cta_kernel-inl.cuh | 25 +++++++---------- 3 files changed, 20 insertions(+), 34 deletions(-) diff --git a/cpp/src/neighbors/detail/cagra/device_common.hpp b/cpp/src/neighbors/detail/cagra/device_common.hpp index 0e004e233..2c2c67fcd 100644 --- a/cpp/src/neighbors/detail/cagra/device_common.hpp +++ b/cpp/src/neighbors/detail/cagra/device_common.hpp @@ -186,8 +186,8 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_child_nodes( const IndexT* __restrict__ parent_indices, const IndexT* __restrict__ internal_topk_list, const uint32_t search_width, - IndexT* __restrict__ temp_indices_ptr = nullptr, - int* __restrict__ result_position = nullptr) + int* __restrict__ result_position = nullptr, + const int max_result_position = 0) { constexpr IndexT index_msb_1_mask = utils::gen_index_msb_1_mask::value; constexpr IndexT invalid_index = ~static_cast(0); @@ -214,8 +214,9 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_child_nodes( } if (STATIC_RESULT_POSITION) { result_child_indices_ptr[i] = child_id; - } else { - temp_indices_ptr[i] = child_id; + } else if (child_id != invalid_index) { + int j = atomicSub(result_position, 1) - 1; + result_child_indices_ptr[j] = child_id; } } __syncthreads(); @@ -227,11 +228,11 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_child_nodes( const auto compute_distance = dataset_desc.compute_distance_impl; const auto args = dataset_desc.args.load(); const bool lead_lane = (threadIdx.x & ((1u << team_size_bits) - 1u)) == 0; + const uint32_t ofst = STATIC_RESULT_POSITION ? 0 : result_position[0]; for (uint32_t i = threadIdx.x >> team_size_bits; i < max_i; i += blockDim.x >> team_size_bits) { - const bool valid_i = i < num_k; - const auto child_id = - valid_i ? (STATIC_RESULT_POSITION ? result_child_indices_ptr[i] : temp_indices_ptr[i]) - : invalid_index; + const auto j = i + ofst; + const bool valid_i = STATIC_RESULT_POSITION ? (j < num_k) : (j < max_result_position); + const auto child_id = valid_i ? result_child_indices_ptr[j] : invalid_index; // We should be calling `dataset_desc.compute_distance(..)` here as follows: // > const auto child_dist = dataset_desc.compute_distance(child_id, child_id != invalid_index); @@ -244,16 +245,7 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_child_nodes( __syncwarp(); // Store the distance - if (valid_i && lead_lane) { - if (STATIC_RESULT_POSITION) { - result_child_distances_ptr[i] = child_dist; - } else if (child_id != invalid_index) { - // Only valid results are stored in order from the back of the buffer - int j = atomicSub(result_position, 1) - 1; - result_child_indices_ptr[j] = child_id; - result_child_distances_ptr[j] = child_dist; - } - } + if (valid_i && lead_lane) { result_child_distances_ptr[j] = child_dist; } } } diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta.cuh b/cpp/src/neighbors/detail/cagra/search_multi_cta.cuh index d21a0407c..ac361e800 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta.cuh +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta.cuh @@ -130,7 +130,6 @@ struct search : public search_plan_impl(result_indices_buffer + result_buffer_size_32); auto* __restrict__ local_visited_hashmap_ptr = reinterpret_cast(result_distances_buffer + result_buffer_size_32); - auto* __restrict__ temp_indices_buffer = - reinterpret_cast(local_visited_hashmap_ptr + hashmap::get_size(visited_hash_bitlen)); auto* __restrict__ parent_indices_buffer = - reinterpret_cast(temp_indices_buffer + graph_degree); + reinterpret_cast(local_visited_hashmap_ptr + hashmap::get_size(visited_hash_bitlen)); auto* __restrict__ result_position = reinterpret_cast(parent_indices_buffer + 1); INDEX_T* const local_traversed_hashmap_ptr = @@ -332,22 +330,19 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel( if ((parent_indices_buffer[0] == invalid_index) && (iter >= min_iteration)) { break; } _CLK_START(); - // Restore visited hashmap by putting nodes on result buffer in it. for (unsigned i = threadIdx.x; i < result_buffer_size_32; i += blockDim.x) { INDEX_T index = result_indices_buffer[i]; if (index == invalid_index) { continue; } - index &= ~index_msb_1_mask; - hashmap::insert(local_visited_hashmap_ptr, visited_hash_bitlen, index); - } - // Remove nodes kicked out of the itopk list from the traversed hash table. - for (unsigned i = itopk_size + threadIdx.x; i < result_buffer_size_32; i += blockDim.x) { - INDEX_T index = result_indices_buffer[i]; - if (index == invalid_index) { continue; } - if (index & index_msb_1_mask) { + if ((i >= itopk_size) && (index & index_msb_1_mask)) { + // Remove nodes kicked out of the itopk list from the traversed hash table. hashmap::remove( local_traversed_hashmap_ptr, traversed_hash_bitlen, index & ~index_msb_1_mask); result_indices_buffer[i] = invalid_index; result_distances_buffer[i] = utils::get_max_value(); + } else { + // Restore visited hashmap by putting nodes on result buffer in it. + index &= ~index_msb_1_mask; + hashmap::insert(local_visited_hashmap_ptr, visited_hash_bitlen, index); } } // Initialize buffer for compute_distance_to_child_nodes. @@ -368,9 +363,9 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel( parent_indices_buffer, result_indices_buffer, 1, - temp_indices_buffer, - result_position); - __syncthreads(); + result_position, + result_buffer_size_32); + // __syncthreads(); // Check the state of the nodes in the result buffer which were not updated // by the compute_distance_to_child_nodes above, and if it cannot be used as