Skip to content

Commit

Permalink
[Improved Multi-CTA algo] Address low recall issue of multi-CTA algo …
Browse files Browse the repository at this point in the history
…when the number of results is large

Fix some issues

Fix lower recall issue with new multi-cta algo

Removing redundant code and changing some parameters

Update cpp/src/neighbors/detail/cagra/search_plan.cuh

Co-authored-by: Tamas Bela Feher <[email protected]>

Remove an unnecessary line and satisfy clang-format
  • Loading branch information
anaruse committed Dec 5, 2024
1 parent 9fb21ad commit 6223fd2
Show file tree
Hide file tree
Showing 12 changed files with 347 additions and 205 deletions.
2 changes: 1 addition & 1 deletion cpp/src/neighbors/detail/cagra/cagra_search.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ void search_main_core(raft::resources const& res,
using CagraSampleFilterT_s = typename CagraSampleFilterT_Selector<CagraSampleFilterT>::type;
std::unique_ptr<search_plan_impl<DataT, IndexT, DistanceT, CagraSampleFilterT_s>> plan =
factory<DataT, IndexT, DistanceT, CagraSampleFilterT_s>::create(
res, params, dataset_desc, queries.extent(1), graph.extent(1), topk);
res, params, dataset_desc, queries.extent(1), graph.extent(0), graph.extent(1), topk);

plan->check(topk);

Expand Down
37 changes: 27 additions & 10 deletions cpp/src/neighbors/detail/cagra/device_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,9 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_random_nodes(
const IndexT* __restrict__ seed_ptr, // [num_seeds]
const uint32_t num_seeds,
IndexT* __restrict__ visited_hash_ptr,
const uint32_t hash_bitlen,
const uint32_t visited_hash_bitlen,
IndexT* __restrict__ traversed_hash_ptr,
const uint32_t traversed_hash_bitlen,
const uint32_t block_id = 0,
const uint32_t num_blocks = 1)
{
Expand Down Expand Up @@ -145,14 +147,21 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_random_nodes(

const unsigned lane_id = threadIdx.x & ((1u << team_size_bits) - 1u);
if (valid_i && lane_id == 0) {
if (best_index_team_local != raft::upper_bound<IndexT>() &&
hashmap::insert(visited_hash_ptr, hash_bitlen, best_index_team_local)) {
result_distances_ptr[i] = best_norm2_team_local;
result_indices_ptr[i] = best_index_team_local;
} else {
result_distances_ptr[i] = raft::upper_bound<DistanceT>();
result_indices_ptr[i] = raft::upper_bound<IndexT>();
if (best_index_team_local != raft::upper_bound<IndexT>()) {
if (hashmap::insert(visited_hash_ptr, visited_hash_bitlen, best_index_team_local) == 0) {
// Deactivate this entry as insertion into visited hash table has failed.
best_norm2_team_local = raft::upper_bound<DistanceT>();
best_index_team_local = raft::upper_bound<IndexT>();
} else if ((traversed_hash_ptr != nullptr) &&
hashmap::search<IndexT, 1>(
traversed_hash_ptr, traversed_hash_bitlen, best_index_team_local)) {
// Deactivate this entry as it has been already used by otehrs.
best_norm2_team_local = raft::upper_bound<DistanceT>();
best_index_team_local = raft::upper_bound<IndexT>();
}
}
result_distances_ptr[i] = best_norm2_team_local;
result_indices_ptr[i] = best_index_team_local;
}
}
}
Expand All @@ -168,7 +177,9 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_child_nodes(
const uint32_t knn_k,
// hashmap
IndexT* __restrict__ visited_hashmap_ptr,
const uint32_t hash_bitlen,
const uint32_t visited_hash_bitlen,
IndexT* __restrict__ traversed_hashmap_ptr,
const uint32_t traversed_hash_bitlen,
const IndexT* __restrict__ parent_indices,
const IndexT* __restrict__ internal_topk_list,
const uint32_t search_width)
Expand All @@ -186,7 +197,13 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_child_nodes(
child_id = knn_graph[(i % knn_k) + (static_cast<int64_t>(knn_k) * parent_id)];
}
if (child_id != invalid_index) {
if (hashmap::insert(visited_hashmap_ptr, hash_bitlen, child_id) == 0) {
if (hashmap::insert(visited_hashmap_ptr, visited_hash_bitlen, child_id) == 0) {
// Deactivate this entry as insertion into visited hash table has failed.
child_id = invalid_index;
} else if ((traversed_hashmap_ptr != nullptr) &&
hashmap::search<IndexT, 1>(
traversed_hashmap_ptr, traversed_hash_bitlen, child_id)) {
// Deactivate this entry as this has been already used by others.
child_id = invalid_index;
}
}
Expand Down
9 changes: 5 additions & 4 deletions cpp/src/neighbors/detail/cagra/factory.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,11 @@ class factory {
search_params const& params,
const dataset_descriptor_host<DataT, IndexT, DistanceT>& dataset_desc,
int64_t dim,
int64_t dataset_size,
int64_t graph_degree,
uint32_t topk)
{
search_plan_impl_base plan(params, dim, graph_degree, topk);
search_plan_impl_base plan(params, dim, dataset_size, graph_degree, topk);
return dispatch_kernel(res, plan, dataset_desc);
}

Expand All @@ -56,15 +57,15 @@ class factory {
if (plan.algo == search_algo::SINGLE_CTA) {
return std::make_unique<
single_cta_search::search<DataT, IndexT, DistanceT, CagraSampleFilterT>>(
res, plan, dataset_desc, plan.dim, plan.graph_degree, plan.topk);
res, plan, dataset_desc, plan.dim, plan.dataset_size, plan.graph_degree, plan.topk);
} else if (plan.algo == search_algo::MULTI_CTA) {
return std::make_unique<
multi_cta_search::search<DataT, IndexT, DistanceT, CagraSampleFilterT>>(
res, plan, dataset_desc, plan.dim, plan.graph_degree, plan.topk);
res, plan, dataset_desc, plan.dim, plan.dataset_size, plan.graph_degree, plan.topk);
} else {
return std::make_unique<
multi_kernel_search::search<DataT, IndexT, DistanceT, CagraSampleFilterT>>(
res, plan, dataset_desc, plan.dim, plan.graph_degree, plan.topk);
res, plan, dataset_desc, plan.dim, plan.dataset_size, plan.graph_degree, plan.topk);
}
}
};
Expand Down
87 changes: 74 additions & 13 deletions cpp/src/neighbors/detail/cagra/hashmap.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@

#include <cstdint>

#define HASHMAP_LINEAR_PROBING

// #pragma GCC diagnostic push
// #pragma GCC diagnostic ignored
// #pragma GCC diagnostic pop
Expand All @@ -42,15 +44,15 @@ RAFT_DEVICE_INLINE_FUNCTION void init(IdxT* const table,
}
}

template <class IdxT>
template <class IdxT, unsigned SUPPORT_REMOVE = 0>
RAFT_DEVICE_INLINE_FUNCTION uint32_t insert(IdxT* const table,
const uint32_t bitlen,
const IdxT key)
{
// Open addressing is used for collision resolution
const uint32_t size = get_size(bitlen);
const uint32_t bit_mask = size - 1;
#if 1
#ifdef HASHMAP_LINEAR_PROBING
// Linear probing
IdxT index = (key ^ (key >> bitlen)) & bit_mask;
constexpr uint32_t stride = 1;
Expand All @@ -59,32 +61,91 @@ RAFT_DEVICE_INLINE_FUNCTION uint32_t insert(IdxT* const table,
uint32_t index = key & bit_mask;
const uint32_t stride = (key >> bitlen) * 2 + 1;
#endif
constexpr IdxT hashval_empty = ~static_cast<IdxT>(0);
const IdxT removed_key = key | utils::gen_index_msb_1_mask<IdxT>::value;
for (unsigned i = 0; i < size; i++) {
const IdxT old = atomicCAS(&table[index], ~static_cast<IdxT>(0), key);
if (old == ~static_cast<IdxT>(0)) {
const IdxT old = atomicCAS(&table[index], hashval_empty, key);
if (old == hashval_empty) {
return 1;
} else if (old == key) {
return 0;
} else if (SUPPORT_REMOVE) {
// Checks if this key has been removed before.
const uint32_t old = atomicCAS(&table[index], removed_key, key);
if (old == removed_key) {
return 1;
} else if (old == key) {
return 0;
}
}
index = (index + stride) & bit_mask;
}
return 0;
}

template <unsigned TEAM_SIZE, class IdxT>
RAFT_DEVICE_INLINE_FUNCTION uint32_t insert(IdxT* const table,
const uint32_t bitlen,
const IdxT key)
template <class IdxT, unsigned SUPPORT_REMOVE = 0>
RAFT_DEVICE_INLINE_FUNCTION uint32_t search(IdxT* table, const uint32_t bitlen, const IdxT key)
{
IdxT ret = 0;
if (threadIdx.x % TEAM_SIZE == 0) { ret = insert(table, bitlen, key); }
for (unsigned offset = 1; offset < TEAM_SIZE; offset *= 2) {
ret |= __shfl_xor_sync(0xffffffff, ret, offset);
const uint32_t size = get_size(bitlen);
const uint32_t bit_mask = size - 1;
#ifdef HASHMAP_LINEAR_PROBING
// Linear probing
IdxT index = (key ^ (key >> bitlen)) & bit_mask;
constexpr uint32_t stride = 1;
#else
// Double hashing
IdxT index = key & bit_mask;
const uint32_t stride = (key >> bitlen) * 2 + 1;
#endif
constexpr IdxT hashval_empty = ~static_cast<IdxT>(0);
const IdxT removed_key = key | utils::gen_index_msb_1_mask<IdxT>::value;
for (unsigned i = 0; i < size; i++) {
const IdxT val = table[index];
if (val == key) {
return 1;
} else if (val == hashval_empty) {
return 0;
} else if (SUPPORT_REMOVE) {
// Check if this key has been removed.
if (val == removed_key) {
return 0;
}
}
index = (index + stride) & bit_mask;
}
return ret;
return 0;
}

template <class IdxT>
RAFT_DEVICE_INLINE_FUNCTION uint32_t remove(IdxT* table, const uint32_t bitlen, const IdxT key)
{
const uint32_t size = get_size(bitlen);
const uint32_t bit_mask = size - 1;
#ifdef HASHMAP_LINEAR_PROBING
// Linear probing
IdxT index = (key ^ (key >> bitlen)) & bit_mask;
constexpr uint32_t stride = 1;
#else
// Double hashing
IdxT index = key & bit_mask;
const uint32_t stride = (key >> bitlen) * 2 + 1;
#endif
constexpr IdxT hashval_empty = ~static_cast<IdxT>(0);
const IdxT removed_key = key | utils::gen_index_msb_1_mask<IdxT>::value;
for (unsigned i = 0; i < size; i++) {
// To remove a key, set the MSB to 1.
const uint32_t old = atomicCAS(&table[index], key, removed_key);
if (old == key) {
return 1;
} else if (old == hashval_empty) {
return 0;
}
index = (index + stride) & bit_mask;
}
return 0;
}

template <class IdxT, unsigned SUPPORT_REMOVE = 0>
RAFT_DEVICE_INLINE_FUNCTION uint32_t
insert(unsigned team_size, IdxT* const table, const uint32_t bitlen, const IdxT key)
{
Expand Down
16 changes: 9 additions & 7 deletions cpp/src/neighbors/detail/cagra/search_multi_cta.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -102,24 +102,24 @@ struct search : public search_plan_impl<DataT, IndexT, DistanceT, SAMPLE_FILTER_
search_params params,
const dataset_descriptor_host<DataT, IndexT, DistanceT>& dataset_desc,
int64_t dim,
int64_t dataset_size,
int64_t graph_degree,
uint32_t topk)
: base_type(res, params, dataset_desc, dim, graph_degree, topk),
: base_type(res, params, dataset_desc, dim, dataset_size, graph_degree, topk),
intermediate_indices(res),
intermediate_distances(res),
topk_workspace(res)

{
set_params(res, params);
}

void set_params(raft::resources const& res, const search_params& params)
{
constexpr unsigned muti_cta_itopk_size = 32;
this->itopk_size = muti_cta_itopk_size;
search_width = 1;
constexpr unsigned multi_cta_itopk_size = 32;
this->itopk_size = multi_cta_itopk_size;
search_width = 1;
num_cta_per_query =
max(params.search_width, raft::ceildiv(params.itopk_size, (size_t)muti_cta_itopk_size));
max(params.search_width, raft::ceildiv(params.itopk_size, (size_t)multi_cta_itopk_size));
result_buffer_size = itopk_size + search_width * graph_degree;
typedef raft::Pow2<32> AlignBytes;
unsigned result_buffer_size_32 = AlignBytes::roundUp(result_buffer_size);
Expand All @@ -128,7 +128,8 @@ struct search : public search_plan_impl<DataT, IndexT, DistanceT, SAMPLE_FILTER_

smem_size = dataset_desc.smem_ws_size_in_bytes +
(sizeof(INDEX_T) + sizeof(DISTANCE_T)) * result_buffer_size_32 +
sizeof(uint32_t) * search_width + sizeof(uint32_t);
sizeof(INDEX_T) * hashmap::get_size(small_hash_bitlen) +
sizeof(INDEX_T) * search_width;
RAFT_LOG_DEBUG("# smem_size: %u", smem_size);

//
Expand Down Expand Up @@ -222,6 +223,7 @@ struct search : public search_plan_impl<DataT, IndexT, DistanceT, SAMPLE_FILTER_
thread_block_size,
result_buffer_size,
smem_size,
small_hash_bitlen,
hash_bitlen,
hashmap.data(),
num_cta_per_query,
Expand Down
1 change: 1 addition & 0 deletions cpp/src/neighbors/detail/cagra/search_multi_cta_inst.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ namespace cuvs::neighbors::cagra::detail::multi_cta_search {
uint32_t block_size, \
uint32_t result_buffer_size, \
uint32_t smem_size, \
uint32_t small_hash_bitlen, \
int64_t hash_bitlen, \
IndexT* hashmap_ptr, \
uint32_t num_cta_per_query, \
Expand Down
Loading

0 comments on commit 6223fd2

Please sign in to comment.