Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve multi-CTA algorithm #492

Open
wants to merge 20 commits into
base: branch-25.02
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
6223fd2
[Improved Multi-CTA algo] Address low recall issue of multi-CTA algo …
anaruse Sep 26, 2024
8ff6991
Merge branch 'branch-24.12' into improved_multi_cta_algo
anaruse Dec 5, 2024
37e26c1
fix style
tfeher Dec 5, 2024
3665d45
Merge branch 'branch-24.12' into improved_multi_cta_algo
anaruse Dec 5, 2024
018e792
Merge branch 'branch-25.02' into improved_multi_cta_algo
achirkin Dec 9, 2024
ab1130b
Check if CAGRA search returns enough valid indices during add_nodes
achirkin Dec 9, 2024
bedd224
Resolving various issues with the new multi-CTA algorithm
anaruse Dec 20, 2024
ea8c273
Add comments in add_nodes.cuh
anaruse Dec 23, 2024
5025481
Limit tht number of warnings output
anaruse Dec 23, 2024
b61126a
Avoid invalid results in search results as much as possible
anaruse Dec 25, 2024
588bd0c
Improve the accuracy of the new multi-CTA algo by revising the usase …
anaruse Dec 30, 2024
228a1ae
Reduce the number of shared memory access
anaruse Jan 6, 2025
776f2f5
Remove unused code
anaruse Jan 8, 2025
9d262f7
Merge branch 'branch-25.02' into improved_multi_cta_algo
cjnolet Jan 8, 2025
192c0a9
Update cpp/src/neighbors/detail/cagra/device_common.hpp
anaruse Jan 10, 2025
d19a6c4
Merge branch 'branch-25.02' into improved_multi_cta_algo
cjnolet Jan 16, 2025
b5c31b3
Merge branch 'branch-25.02' into improved_multi_cta_algo
anaruse Jan 17, 2025
81e4b39
Fixed data type issues
anaruse Jan 17, 2025
cdc4bc4
Merge branch 'branch-25.02' into improved_multi_cta_algo
cjnolet Jan 23, 2025
dd371dc
Merge branch 'branch-25.02' into improved_multi_cta_algo
cjnolet Jan 24, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 42 additions & 9 deletions cpp/src/neighbors/detail/cagra/add_nodes.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,23 @@ void add_node_core(
raft::resource::get_cuda_stream(handle));
raft::resource::sync_stream(handle);

// Check search results
for (std::size_t vec_i = 0; vec_i < batch.size(); vec_i++) {
std::uint32_t invalid_edges = 0;
for (std::uint32_t i = 0; i < base_degree; i++) {
if (host_neighbor_indices(vec_i, i) >= old_size) { invalid_edges++; }
}
if (invalid_edges > 0) {
RAFT_LOG_WARN(
"Invalid edges found in search results "
"(vec_i:%lu, invalid_edges:%lu, degree:%lu, base_degree:%lu)",
(uint64_t)vec_i,
(uint64_t)invalid_edges,
(uint64_t)degree,
(uint64_t)base_degree);
anaruse marked this conversation as resolved.
Show resolved Hide resolved
}
}

// Step 2: rank-based reordering
#pragma omp parallel
{
Expand All @@ -136,9 +153,13 @@ void add_node_core(
for (std::uint32_t i = 0; i < base_degree; i++) {
std::uint32_t detourable_node_count = 0;
const auto a_id = host_neighbor_indices(vec_i, i);
if (a_id >= idx.size()) {
detourable_node_count_list[i] = std::make_pair(a_id, base_degree + 1);
anaruse marked this conversation as resolved.
Show resolved Hide resolved
continue;
}
for (std::uint32_t j = 0; j < i; j++) {
const auto b0_id = host_neighbor_indices(vec_i, j);
assert(b0_id < idx.size());
if (b0_id >= idx.size()) { continue; }
for (std::uint32_t k = 0; k < degree; k++) {
const auto b1_id = updated_graph(b0_id, k);
if (a_id == b1_id) {
Expand All @@ -149,6 +170,7 @@ void add_node_core(
}
detourable_node_count_list[i] = std::make_pair(a_id, detourable_node_count);
}

std::sort(detourable_node_count_list.begin(),
detourable_node_count_list.end(),
[&](const std::pair<IdxT, std::size_t> a, const std::pair<IdxT, std::size_t> b) {
Expand All @@ -170,13 +192,18 @@ void add_node_core(
const auto target_new_node_id = old_size + batch.offset() + vec_i;
for (std::size_t i = 0; i < num_rev_edges; i++) {
const auto target_node_id = updated_graph(old_size + batch.offset() + vec_i, i);

if (target_node_id >= new_size) {
RAFT_FAIL("Invalid node ID found in updated_graph (%u)\n", target_node_id);
}
IdxT replace_id = new_size;
IdxT replace_id_j = 0;
std::size_t replace_num_incoming_edges = 0;
for (std::int32_t j = degree - 1; j >= static_cast<std::int32_t>(rev_edge_search_range);
j--) {
const auto neighbor_id = updated_graph(target_node_id, j);
const auto neighbor_id = updated_graph(target_node_id, j);
if (neighbor_id >= new_size) {
RAFT_FAIL("Invalid node ID found in updated_graph (%u)\n", neighbor_id);
}
const std::size_t num_incoming_edges = host_num_incoming_edges(neighbor_id);
if (num_incoming_edges > replace_num_incoming_edges) {
// Check duplication
Expand All @@ -195,10 +222,6 @@ void add_node_core(
replace_id_j = j;
}
}
if (replace_id >= new_size) {
std::fprintf(stderr, "Invalid rev edge index (%u)\n", replace_id);
return;
}
updated_graph(target_node_id, replace_id_j) = target_new_node_id;
rev_edges[i] = replace_id;
}
Expand All @@ -210,13 +233,15 @@ void add_node_core(
const auto rank_based_list_ptr =
updated_graph.data_handle() + (old_size + batch.offset() + vec_i) * degree;
const auto rev_edges_return_list_ptr = rev_edges.data();
while (num_add < degree) {
while ((num_add < degree) &&
((rank_base_i < degree) || (rev_edges_return_i < num_rev_edges))) {
const auto node_list_ptr =
interleave_switch == 0 ? rank_based_list_ptr : rev_edges_return_list_ptr;
auto& node_list_index = interleave_switch == 0 ? rank_base_i : rev_edges_return_i;
const auto max_node_list_index = interleave_switch == 0 ? degree : num_rev_edges;
for (; node_list_index < max_node_list_index; node_list_index++) {
const auto candidate = node_list_ptr[node_list_index];
if (candidate >= new_size) { continue; }
// Check duplication
bool dup = false;
for (std::uint32_t j = 0; j < num_add; j++) {
Expand All @@ -233,6 +258,12 @@ void add_node_core(
}
interleave_switch = 1 - interleave_switch;
}
if (num_add < degree) {
RAFT_FAIL("Number of edges is not enough (target_new_node_id:%u, num_add:%u, degree:%u)",
target_new_node_id,
num_add,
degree);
}
for (std::uint32_t i = 0; i < degree; i++) {
updated_graph(target_new_node_id, i) = temp[i];
}
Expand All @@ -248,7 +279,9 @@ void add_graph_nodes(
raft::host_matrix_view<IdxT, std::int64_t> updated_graph_view,
const cagra::extend_params& params)
{
assert(input_updated_dataset_view.extent(0) >= index.size());
if (input_updated_dataset_view.extent(0) < index.size()) {
RAFT_FAIL("Updated dataset must be not smaller than the previous index state.");
}

const std::size_t initial_dataset_size = index.size();
const std::size_t new_dataset_size = input_updated_dataset_view.extent(0);
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/neighbors/detail/cagra/cagra_search.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ void search_main_core(raft::resources const& res,
using CagraSampleFilterT_s = typename CagraSampleFilterT_Selector<CagraSampleFilterT>::type;
std::unique_ptr<search_plan_impl<DataT, IndexT, DistanceT, CagraSampleFilterT_s>> plan =
factory<DataT, IndexT, DistanceT, CagraSampleFilterT_s>::create(
res, params, dataset_desc, queries.extent(1), graph.extent(1), topk);
res, params, dataset_desc, queries.extent(1), graph.extent(0), graph.extent(1), topk);

plan->check(topk);

Expand Down
39 changes: 28 additions & 11 deletions cpp/src/neighbors/detail/cagra/device_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,9 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_random_nodes(
const IndexT* __restrict__ seed_ptr, // [num_seeds]
const uint32_t num_seeds,
IndexT* __restrict__ visited_hash_ptr,
const uint32_t hash_bitlen,
const uint32_t visited_hash_bitlen,
IndexT* __restrict__ traversed_hash_ptr,
const uint32_t traversed_hash_bitlen,
const uint32_t block_id = 0,
const uint32_t num_blocks = 1)
{
Expand Down Expand Up @@ -145,14 +147,21 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_random_nodes(

const unsigned lane_id = threadIdx.x & ((1u << team_size_bits) - 1u);
if (valid_i && lane_id == 0) {
if (best_index_team_local != raft::upper_bound<IndexT>() &&
hashmap::insert(visited_hash_ptr, hash_bitlen, best_index_team_local)) {
result_distances_ptr[i] = best_norm2_team_local;
result_indices_ptr[i] = best_index_team_local;
} else {
result_distances_ptr[i] = raft::upper_bound<DistanceT>();
result_indices_ptr[i] = raft::upper_bound<IndexT>();
if (best_index_team_local != raft::upper_bound<IndexT>()) {
if (hashmap::insert(visited_hash_ptr, visited_hash_bitlen, best_index_team_local) == 0) {
// Deactivate this entry as insertion into visited hash table has failed.
best_norm2_team_local = raft::upper_bound<DistanceT>();
best_index_team_local = raft::upper_bound<IndexT>();
} else if ((traversed_hash_ptr != nullptr) &&
hashmap::search<IndexT, 1>(
traversed_hash_ptr, traversed_hash_bitlen, best_index_team_local)) {
// Deactivate this entry as it has been already used by otehrs.
anaruse marked this conversation as resolved.
Show resolved Hide resolved
best_norm2_team_local = raft::upper_bound<DistanceT>();
best_index_team_local = raft::upper_bound<IndexT>();
}
}
result_distances_ptr[i] = best_norm2_team_local;
result_indices_ptr[i] = best_index_team_local;
}
}
}
Expand All @@ -168,13 +177,15 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_child_nodes(
const uint32_t knn_k,
// hashmap
IndexT* __restrict__ visited_hashmap_ptr,
const uint32_t hash_bitlen,
const uint32_t visited_hash_bitlen,
IndexT* __restrict__ traversed_hashmap_ptr,
const uint32_t traversed_hash_bitlen,
const IndexT* __restrict__ parent_indices,
const IndexT* __restrict__ internal_topk_list,
const uint32_t search_width)
{
constexpr IndexT index_msb_1_mask = utils::gen_index_msb_1_mask<IndexT>::value;
constexpr IndexT invalid_index = raft::upper_bound<IndexT>();
constexpr IndexT invalid_index = ~static_cast<IndexT>(0);

// Read child indices of parents from knn graph and check if the distance
// computaiton is necessary.
Expand All @@ -186,7 +197,13 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_child_nodes(
child_id = knn_graph[(i % knn_k) + (static_cast<int64_t>(knn_k) * parent_id)];
}
if (child_id != invalid_index) {
if (hashmap::insert(visited_hashmap_ptr, hash_bitlen, child_id) == 0) {
if (hashmap::insert(visited_hashmap_ptr, visited_hash_bitlen, child_id) == 0) {
// Deactivate this entry as insertion into visited hash table has failed.
child_id = invalid_index;
} else if ((traversed_hashmap_ptr != nullptr) &&
hashmap::search<IndexT, 1>(
traversed_hashmap_ptr, traversed_hash_bitlen, child_id)) {
// Deactivate this entry as this has been already used by others.
child_id = invalid_index;
}
}
Expand Down
9 changes: 5 additions & 4 deletions cpp/src/neighbors/detail/cagra/factory.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,11 @@ class factory {
search_params const& params,
const dataset_descriptor_host<DataT, IndexT, DistanceT>& dataset_desc,
int64_t dim,
int64_t dataset_size,
int64_t graph_degree,
uint32_t topk)
{
search_plan_impl_base plan(params, dim, graph_degree, topk);
search_plan_impl_base plan(params, dim, dataset_size, graph_degree, topk);
return dispatch_kernel(res, plan, dataset_desc);
}

Expand All @@ -56,15 +57,15 @@ class factory {
if (plan.algo == search_algo::SINGLE_CTA) {
return std::make_unique<
single_cta_search::search<DataT, IndexT, DistanceT, CagraSampleFilterT>>(
res, plan, dataset_desc, plan.dim, plan.graph_degree, plan.topk);
res, plan, dataset_desc, plan.dim, plan.dataset_size, plan.graph_degree, plan.topk);
} else if (plan.algo == search_algo::MULTI_CTA) {
return std::make_unique<
multi_cta_search::search<DataT, IndexT, DistanceT, CagraSampleFilterT>>(
res, plan, dataset_desc, plan.dim, plan.graph_degree, plan.topk);
res, plan, dataset_desc, plan.dim, plan.dataset_size, plan.graph_degree, plan.topk);
} else {
return std::make_unique<
multi_kernel_search::search<DataT, IndexT, DistanceT, CagraSampleFilterT>>(
res, plan, dataset_desc, plan.dim, plan.graph_degree, plan.topk);
res, plan, dataset_desc, plan.dim, plan.dataset_size, plan.graph_degree, plan.topk);
}
}
};
Expand Down
87 changes: 73 additions & 14 deletions cpp/src/neighbors/detail/cagra/hashmap.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@

#include <cstdint>

#define HASHMAP_LINEAR_PROBING

// #pragma GCC diagnostic push
// #pragma GCC diagnostic ignored
// #pragma GCC diagnostic pop
Expand All @@ -38,19 +40,19 @@ RAFT_DEVICE_INLINE_FUNCTION void init(IdxT* const table,
{
if (threadIdx.x < FIRST_TID) return;
for (unsigned i = threadIdx.x - FIRST_TID; i < get_size(bitlen); i += blockDim.x - FIRST_TID) {
table[i] = utils::get_max_value<IdxT>();
table[i] = ~static_cast<IdxT>(0);
}
}

template <class IdxT>
template <class IdxT, unsigned SUPPORT_REMOVE = 0>
RAFT_DEVICE_INLINE_FUNCTION uint32_t insert(IdxT* const table,
const uint32_t bitlen,
const IdxT key)
{
// Open addressing is used for collision resolution
const uint32_t size = get_size(bitlen);
const uint32_t bit_mask = size - 1;
#if 1
#ifdef HASHMAP_LINEAR_PROBING
// Linear probing
IdxT index = (key ^ (key >> bitlen)) & bit_mask;
constexpr uint32_t stride = 1;
Expand All @@ -59,32 +61,89 @@ RAFT_DEVICE_INLINE_FUNCTION uint32_t insert(IdxT* const table,
uint32_t index = key & bit_mask;
const uint32_t stride = (key >> bitlen) * 2 + 1;
#endif
constexpr IdxT hashval_empty = ~static_cast<IdxT>(0);
const IdxT removed_key = key | utils::gen_index_msb_1_mask<IdxT>::value;
for (unsigned i = 0; i < size; i++) {
const IdxT old = atomicCAS(&table[index], ~static_cast<IdxT>(0), key);
if (old == ~static_cast<IdxT>(0)) {
const IdxT old = atomicCAS(&table[index], hashval_empty, key);
if (old == hashval_empty) {
return 1;
} else if (old == key) {
return 0;
} else if (SUPPORT_REMOVE) {
// Checks if this key has been removed before.
const uint32_t old = atomicCAS(&table[index], removed_key, key);
if (old == removed_key) {
return 1;
} else if (old == key) {
return 0;
}
}
index = (index + stride) & bit_mask;
}
return 0;
}

template <unsigned TEAM_SIZE, class IdxT>
RAFT_DEVICE_INLINE_FUNCTION uint32_t insert(IdxT* const table,
const uint32_t bitlen,
const IdxT key)
template <class IdxT, unsigned SUPPORT_REMOVE = 0>
RAFT_DEVICE_INLINE_FUNCTION uint32_t search(IdxT* table, const uint32_t bitlen, const IdxT key)
{
IdxT ret = 0;
if (threadIdx.x % TEAM_SIZE == 0) { ret = insert(table, bitlen, key); }
for (unsigned offset = 1; offset < TEAM_SIZE; offset *= 2) {
ret |= __shfl_xor_sync(0xffffffff, ret, offset);
const uint32_t size = get_size(bitlen);
const uint32_t bit_mask = size - 1;
#ifdef HASHMAP_LINEAR_PROBING
// Linear probing
IdxT index = (key ^ (key >> bitlen)) & bit_mask;
constexpr uint32_t stride = 1;
#else
// Double hashing
IdxT index = key & bit_mask;
const uint32_t stride = (key >> bitlen) * 2 + 1;
#endif
constexpr IdxT hashval_empty = ~static_cast<IdxT>(0);
const IdxT removed_key = key | utils::gen_index_msb_1_mask<IdxT>::value;
for (unsigned i = 0; i < size; i++) {
const IdxT val = table[index];
if (val == key) {
return 1;
} else if (val == hashval_empty) {
return 0;
} else if (SUPPORT_REMOVE) {
// Check if this key has been removed.
if (val == removed_key) { return 0; }
}
index = (index + stride) & bit_mask;
}
return ret;
return 0;
}

template <class IdxT>
RAFT_DEVICE_INLINE_FUNCTION uint32_t remove(IdxT* table, const uint32_t bitlen, const IdxT key)
{
const uint32_t size = get_size(bitlen);
const uint32_t bit_mask = size - 1;
#ifdef HASHMAP_LINEAR_PROBING
// Linear probing
IdxT index = (key ^ (key >> bitlen)) & bit_mask;
constexpr uint32_t stride = 1;
#else
// Double hashing
IdxT index = key & bit_mask;
const uint32_t stride = (key >> bitlen) * 2 + 1;
#endif
constexpr IdxT hashval_empty = ~static_cast<IdxT>(0);
const IdxT removed_key = key | utils::gen_index_msb_1_mask<IdxT>::value;
for (unsigned i = 0; i < size; i++) {
// To remove a key, set the MSB to 1.
const uint32_t old = atomicCAS(&table[index], key, removed_key);
if (old == key) {
return 1;
} else if (old == hashval_empty) {
return 0;
}
index = (index + stride) & bit_mask;
}
return 0;
}

template <class IdxT, unsigned SUPPORT_REMOVE = 0>
RAFT_DEVICE_INLINE_FUNCTION uint32_t
insert(unsigned team_size, IdxT* const table, const uint32_t bitlen, const IdxT key)
{
Expand Down
Loading
Loading