Skip to content

Commit

Permalink
Avoid invalid results in search results as much as possible
Browse files Browse the repository at this point in the history
  • Loading branch information
anaruse committed Dec 24, 2024
1 parent 5025481 commit 5a3519a
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 55 deletions.
7 changes: 4 additions & 3 deletions cpp/src/neighbors/detail/cagra/add_nodes.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -127,14 +127,15 @@ void add_node_core(
raft::resource::sync_stream(handle);

// Check search results
int num_warnings = 0;
constexpr int max_warnings = 3;
int num_warnings = 0;
for (std::size_t vec_i = 0; vec_i < batch.size(); vec_i++) {
std::uint32_t invalid_edges = 0;
for (std::uint32_t i = 0; i < base_degree; i++) {
if (host_neighbor_indices(vec_i, i) >= old_size) { invalid_edges++; }
}
if (invalid_edges > 0) {
if (num_warnings < 3) {
if (num_warnings < max_warnings) {
RAFT_LOG_WARN(
"Invalid edges found in search results "
"(vec_i:%lu, invalid_edges:%lu, degree:%lu, base_degree:%lu)",
Expand All @@ -146,7 +147,7 @@ void add_node_core(
num_warnings += 1;
}
}
if (num_warnings > 0) {
if (num_warnings > max_warnings) {
RAFT_LOG_WARN("The number of queries that contain invalid search results: %d", num_warnings);
}

Expand Down
2 changes: 1 addition & 1 deletion cpp/src/neighbors/detail/cagra/search_multi_cta.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ struct search : public search_plan_impl<DataT, IndexT, DistanceT, SAMPLE_FILTER_
search_width = 1;
num_cta_per_query =
max(params.search_width, raft::ceildiv(params.itopk_size, (size_t)multi_cta_itopk_size));
result_buffer_size = itopk_size + search_width * graph_degree;
result_buffer_size = itopk_size + (search_width * graph_degree);
typedef raft::Pow2<32> AlignBytes;
unsigned result_buffer_size_32 = AlignBytes::roundUp(result_buffer_size);
// constexpr unsigned max_result_buffer_size = 256;
Expand Down
97 changes: 46 additions & 51 deletions cpp/src/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -116,12 +116,9 @@ RAFT_DEVICE_INLINE_FUNCTION void pickup_next_parents(
}

template <unsigned MAX_ELEMENTS, class INDEX_T>
RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort(
float* distances, // [num_elements]
INDEX_T* indices, // [num_elements]
const uint32_t num_elements,
const uint32_t num_itopk // num_itopk <= num_elements
)
RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort(float* distances, // [num_elements]
INDEX_T* indices, // [num_elements]
const uint32_t num_elements)
{
const unsigned warp_id = threadIdx.x / 32;
if (warp_id > 0) { return; }
Expand Down Expand Up @@ -239,11 +236,11 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel(
extern __shared__ uint8_t smem[];

// Layout of result_buffer
// +----------------+-------------------------------+---------+
// | internal_top_k | neighbors of parent nodes | padding |
// | <itopk_size> | <search_width * graph_degree> | upto 32 |
// +----------------+-------------------------------+---------+
// |<--- result_buffer_size --->|
// +----------------+---------+-------------------------------+
// | internal_top_k | padding | neighbors of parent nodes |
// | <itopk_size> | upto 32 | <search_width * graph_degree> |
// +----------------+---------+-------------------------------+
// |<--- result_buffer_size_32 --->|
const auto result_buffer_size = itopk_size + (search_width * graph_degree);
const auto result_buffer_size_32 = raft::round_up_safe<uint32_t>(result_buffer_size, 32);
assert(result_buffer_size_32 <= MAX_ELEMENTS);
Expand Down Expand Up @@ -283,7 +280,7 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel(
device::compute_distance_to_random_nodes(result_indices_buffer,
result_distances_buffer,
*dataset_desc,
result_buffer_size,
graph_degree * search_width,
num_distilation,
rand_xor_mask,
local_seed_ptr,
Expand All @@ -301,10 +298,8 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel(
while (1) {
_CLK_START();
// Topk with bitonic sort (1st warp only)
topk_by_bitonic_sort<MAX_ELEMENTS, INDEX_T>(result_distances_buffer,
result_indices_buffer,
itopk_size + (search_width * graph_degree),
itopk_size);
topk_by_bitonic_sort<MAX_ELEMENTS, INDEX_T>(
result_distances_buffer, result_indices_buffer, result_buffer_size_32);
__syncthreads();
_CLK_REC(clk_topk);

Expand All @@ -320,21 +315,18 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel(
itopk_size,
local_traversed_hashmap_ptr,
traversed_hash_bitlen);
#if 0
if (parent_indices_buffer[0] == invalid_index) {
// Try again if no parent is found
move_valid_entries_to_head<INDEX_T, DISTANCE_T>(result_indices_buffer,
result_distances_buffer,
result_buffer_size_32);
pickup_next_parents<INDEX_T, DISTANCE_T>(parent_indices_buffer,
search_width,
result_indices_buffer,
result_distances_buffer,
itopk_size,
local_traversed_hashmap_ptr,
traversed_hash_bitlen);
}
#endif
if (parent_indices_buffer[0] == invalid_index) {
// Try again if no parent is found
move_valid_entries_to_head<INDEX_T, DISTANCE_T>(
result_indices_buffer, result_distances_buffer, result_buffer_size_32);
pickup_next_parents<INDEX_T, DISTANCE_T>(parent_indices_buffer,
search_width,
result_indices_buffer,
result_distances_buffer,
itopk_size,
local_traversed_hashmap_ptr,
traversed_hash_bitlen);
}
} else {
// [Other warps] Reset visited hashmap
hashmap::init<INDEX_T>(local_visited_hashmap_ptr, visited_hash_bitlen, 32);
Expand All @@ -355,31 +347,35 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel(
} else {
// [Other warps] Remove entries kicked out of the itopk list from the
// traversed hash table.
for (unsigned i = threadIdx.x - 32; i < search_width * graph_degree; i += blockDim.x - 32) {
INDEX_T index = result_indices_buffer[itopk_size + i];
for (unsigned i = itopk_size + threadIdx.x - 32; i < result_buffer_size_32;
i += blockDim.x - 32) {
INDEX_T index = result_indices_buffer[i];
if (index == invalid_index) { continue; }
if (index & index_msb_1_mask) {
hashmap::remove<INDEX_T>(
local_traversed_hashmap_ptr, traversed_hash_bitlen, index & ~index_msb_1_mask);
result_indices_buffer[i] = invalid_index;
result_distances_buffer[i] = utils::get_max_value<DISTANCE_T>();
}
}
}
__syncthreads();

_CLK_START();
// compute the norms between child nodes and query node
device::compute_distance_to_child_nodes(result_indices_buffer + itopk_size,
result_distances_buffer + itopk_size,
*dataset_desc,
knn_graph,
graph_degree,
local_visited_hashmap_ptr,
visited_hash_bitlen,
local_traversed_hashmap_ptr,
traversed_hash_bitlen,
parent_indices_buffer,
result_indices_buffer,
search_width);
device::compute_distance_to_child_nodes(
result_indices_buffer + result_buffer_size_32 - graph_degree * search_width,
result_distances_buffer + result_buffer_size_32 - graph_degree * search_width,
*dataset_desc,
knn_graph,
graph_degree,
local_visited_hashmap_ptr,
visited_hash_bitlen,
local_traversed_hashmap_ptr,
traversed_hash_bitlen,
parent_indices_buffer,
result_indices_buffer,
search_width);
__syncthreads();
_CLK_REC(clk_compute_distance);

Expand Down Expand Up @@ -428,7 +424,8 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel(
if (index & index_msb_1_mask) {
is_valid = true;
index &= ~index_msb_1_mask;
} else if (hashmap::insert<INDEX_T, 1>(
} else if ((offset < itopk_size) &&
hashmap::insert<INDEX_T, 1>(
local_traversed_hashmap_ptr, traversed_hash_bitlen, index)) {
// If a node that is not used as a parent can be inserted into
// the traversed hash table, it is considered a valid result.
Expand All @@ -444,15 +441,13 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel(
if (result_distances_ptr != nullptr) {
result_distances_ptr[k] = result_distances_buffer[i];
}
} else if ((index & index_msb_1_mask) == 0) {
// If a node that was successfully inserted in the traversed
// hash table is not output as a result, the hash table is
// restored using hash remove.
} else {
// If it is valid and registered in the traversed hash table but is
// not output as a result, it is removed from the hash table.
hashmap::remove<INDEX_T>(local_traversed_hashmap_ptr, traversed_hash_bitlen, index);
}
}
offset += __popc(mask);
if (offset >= itopk_size) break;
}
// If the number of outputs is insufficient, fill in with invalid results.
for (uint32_t i = offset + threadIdx.x; i < itopk_size; i += 32) {
Expand Down

0 comments on commit 5a3519a

Please sign in to comment.