Skip to content

Commit

Permalink
Reduce the number of shared memory access
Browse files Browse the repository at this point in the history
  • Loading branch information
anaruse committed Jan 6, 2025
1 parent 588bd0c commit 228a1ae
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 34 deletions.
28 changes: 10 additions & 18 deletions cpp/src/neighbors/detail/cagra/device_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -186,8 +186,8 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_child_nodes(
const IndexT* __restrict__ parent_indices,
const IndexT* __restrict__ internal_topk_list,
const uint32_t search_width,
IndexT* __restrict__ temp_indices_ptr = nullptr,
int* __restrict__ result_position = nullptr)
int* __restrict__ result_position = nullptr,
const int max_result_position = 0)
{
constexpr IndexT index_msb_1_mask = utils::gen_index_msb_1_mask<IndexT>::value;
constexpr IndexT invalid_index = ~static_cast<IndexT>(0);
Expand All @@ -214,8 +214,9 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_child_nodes(
}
if (STATIC_RESULT_POSITION) {
result_child_indices_ptr[i] = child_id;
} else {
temp_indices_ptr[i] = child_id;
} else if (child_id != invalid_index) {
int j = atomicSub(result_position, 1) - 1;
result_child_indices_ptr[j] = child_id;
}
}
__syncthreads();
Expand All @@ -227,11 +228,11 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_child_nodes(
const auto compute_distance = dataset_desc.compute_distance_impl;
const auto args = dataset_desc.args.load();
const bool lead_lane = (threadIdx.x & ((1u << team_size_bits) - 1u)) == 0;
const uint32_t ofst = STATIC_RESULT_POSITION ? 0 : result_position[0];
for (uint32_t i = threadIdx.x >> team_size_bits; i < max_i; i += blockDim.x >> team_size_bits) {
const bool valid_i = i < num_k;
const auto child_id =
valid_i ? (STATIC_RESULT_POSITION ? result_child_indices_ptr[i] : temp_indices_ptr[i])
: invalid_index;
const auto j = i + ofst;
const bool valid_i = STATIC_RESULT_POSITION ? (j < num_k) : (j < max_result_position);
const auto child_id = valid_i ? result_child_indices_ptr[j] : invalid_index;

// We should be calling `dataset_desc.compute_distance(..)` here as follows:
// > const auto child_dist = dataset_desc.compute_distance(child_id, child_id != invalid_index);
Expand All @@ -244,16 +245,7 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_child_nodes(
__syncwarp();

// Store the distance
if (valid_i && lead_lane) {
if (STATIC_RESULT_POSITION) {
result_child_distances_ptr[i] = child_dist;
} else if (child_id != invalid_index) {
// Only valid results are stored in order from the back of the buffer
int j = atomicSub(result_position, 1) - 1;
result_child_indices_ptr[j] = child_id;
result_child_distances_ptr[j] = child_dist;
}
}
if (valid_i && lead_lane) { result_child_distances_ptr[j] = child_dist; }
}
}

Expand Down
1 change: 0 additions & 1 deletion cpp/src/neighbors/detail/cagra/search_multi_cta.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,6 @@ struct search : public search_plan_impl<DataT, IndexT, DistanceT, SAMPLE_FILTER_
dataset_desc.smem_ws_size_in_bytes +
(sizeof(INDEX_T) + sizeof(DISTANCE_T)) * (result_buffer_size_32) +
sizeof(INDEX_T) * hashmap::get_size(small_hash_bitlen) + // local_visited_hashmap_ptr
sizeof(INDEX_T) * (search_width * graph_degree) + // temp_indices_buffer
sizeof(INDEX_T) * search_width + // parent_indices_buffer
sizeof(int); // result_position
RAFT_LOG_DEBUG("# smem_size: %u", smem_size);
Expand Down
25 changes: 10 additions & 15 deletions cpp/src/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -258,10 +258,8 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel(
reinterpret_cast<DISTANCE_T*>(result_indices_buffer + result_buffer_size_32);
auto* __restrict__ local_visited_hashmap_ptr =
reinterpret_cast<INDEX_T*>(result_distances_buffer + result_buffer_size_32);
auto* __restrict__ temp_indices_buffer =
reinterpret_cast<INDEX_T*>(local_visited_hashmap_ptr + hashmap::get_size(visited_hash_bitlen));
auto* __restrict__ parent_indices_buffer =
reinterpret_cast<INDEX_T*>(temp_indices_buffer + graph_degree);
reinterpret_cast<INDEX_T*>(local_visited_hashmap_ptr + hashmap::get_size(visited_hash_bitlen));
auto* __restrict__ result_position = reinterpret_cast<int*>(parent_indices_buffer + 1);

INDEX_T* const local_traversed_hashmap_ptr =
Expand Down Expand Up @@ -332,22 +330,19 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel(
if ((parent_indices_buffer[0] == invalid_index) && (iter >= min_iteration)) { break; }

_CLK_START();
// Restore visited hashmap by putting nodes on result buffer in it.
for (unsigned i = threadIdx.x; i < result_buffer_size_32; i += blockDim.x) {
INDEX_T index = result_indices_buffer[i];
if (index == invalid_index) { continue; }
index &= ~index_msb_1_mask;
hashmap::insert(local_visited_hashmap_ptr, visited_hash_bitlen, index);
}
// Remove nodes kicked out of the itopk list from the traversed hash table.
for (unsigned i = itopk_size + threadIdx.x; i < result_buffer_size_32; i += blockDim.x) {
INDEX_T index = result_indices_buffer[i];
if (index == invalid_index) { continue; }
if (index & index_msb_1_mask) {
if ((i >= itopk_size) && (index & index_msb_1_mask)) {
// Remove nodes kicked out of the itopk list from the traversed hash table.
hashmap::remove<INDEX_T>(
local_traversed_hashmap_ptr, traversed_hash_bitlen, index & ~index_msb_1_mask);
result_indices_buffer[i] = invalid_index;
result_distances_buffer[i] = utils::get_max_value<DISTANCE_T>();
} else {
// Restore visited hashmap by putting nodes on result buffer in it.
index &= ~index_msb_1_mask;
hashmap::insert(local_visited_hashmap_ptr, visited_hash_bitlen, index);
}
}
// Initialize buffer for compute_distance_to_child_nodes.
Expand All @@ -368,9 +363,9 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel(
parent_indices_buffer,
result_indices_buffer,
1,
temp_indices_buffer,
result_position);
__syncthreads();
result_position,
result_buffer_size_32);
// __syncthreads();

// Check the state of the nodes in the result buffer which were not updated
// by the compute_distance_to_child_nodes above, and if it cannot be used as
Expand Down

0 comments on commit 228a1ae

Please sign in to comment.