From ee98593fa758555820ad35c25c4787565dea0060 Mon Sep 17 00:00:00 2001 From: viclafargue Date: Fri, 8 Nov 2024 18:23:28 +0100 Subject: [PATCH 1/8] Account for RAFT update --- cpp/CMakeLists.txt | 1 - cpp/bench/ann/src/cuvs/cuvs_mg_cagra_wrapper.h | 2 +- .../ann/src/cuvs/cuvs_mg_ivf_flat_wrapper.h | 2 +- .../ann/src/cuvs/cuvs_mg_ivf_pq_wrapper.h | 2 +- cpp/src/neighbors/mg/generate_mg.py | 12 ++++++------ cpp/src/neighbors/mg/mg.cuh | 18 +++++++++--------- .../neighbors/mg/mg_cagra_float_uint32_t.cu | 4 ++-- cpp/src/neighbors/mg/mg_cagra_half_uint32_t.cu | 4 ++-- .../neighbors/mg/mg_cagra_int8_t_uint32_t.cu | 4 ++-- .../neighbors/mg/mg_cagra_uint8_t_uint32_t.cu | 4 ++-- cpp/src/neighbors/mg/mg_flat_float_int64_t.cu | 4 ++-- cpp/src/neighbors/mg/mg_flat_int8_t_int64_t.cu | 4 ++-- .../neighbors/mg/mg_flat_uint8_t_int64_t.cu | 4 ++-- cpp/src/neighbors/mg/mg_pq_float_int64_t.cu | 4 ++-- cpp/src/neighbors/mg/mg_pq_half_int64_t.cu | 4 ++-- cpp/src/neighbors/mg/mg_pq_int8_t_int64_t.cu | 4 ++-- cpp/src/neighbors/mg/mg_pq_uint8_t_int64_t.cu | 4 ++-- cpp/src/neighbors/mg/nccl_comm.cpp | 8 -------- cpp/test/neighbors/mg.cuh | 2 +- 19 files changed, 41 insertions(+), 50 deletions(-) delete mode 100644 cpp/src/neighbors/mg/nccl_comm.cpp diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index c493af488..ce925713e 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -298,7 +298,6 @@ if(BUILD_SHARED_LIBS) src/neighbors/mg/mg_cagra_int8_t_uint32_t.cu src/neighbors/mg/mg_cagra_uint8_t_uint32_t.cu src/neighbors/mg/omp_checks.cpp - src/neighbors/mg/nccl_comm.cpp ) endif() diff --git a/cpp/bench/ann/src/cuvs/cuvs_mg_cagra_wrapper.h b/cpp/bench/ann/src/cuvs/cuvs_mg_cagra_wrapper.h index 50c1ff4db..cf150bf98 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_mg_cagra_wrapper.h +++ b/cpp/bench/ann/src/cuvs/cuvs_mg_cagra_wrapper.h @@ -47,7 +47,7 @@ class cuvs_mg_cagra : public algo, public algo_gpu { index_params_.ivf_pq_build_params->metric = parse_metric_type(metric); // init nccl clique outside as to not affect benchmark - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle_); + const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle_); } void build(const T* dataset, size_t nrow) final; diff --git a/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_flat_wrapper.h b/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_flat_wrapper.h index 54a0d2fac..d313bbcbc 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_flat_wrapper.h +++ b/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_flat_wrapper.h @@ -41,7 +41,7 @@ class cuvs_mg_ivf_flat : public algo, public algo_gpu { { index_params_.metric = parse_metric_type(metric); // init nccl clique outside as to not affect benchmark - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle_); + const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle_); } void build(const T* dataset, size_t nrow) final; diff --git a/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_pq_wrapper.h b/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_pq_wrapper.h index 84aea7d4a..588e4798a 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_pq_wrapper.h +++ b/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_pq_wrapper.h @@ -41,7 +41,7 @@ class cuvs_mg_ivf_pq : public algo, public algo_gpu { { index_params_.metric = parse_metric_type(metric); // init nccl clique outside as to not affect benchmark - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle_); + const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle_); } void build(const T* dataset, size_t nrow) final; diff --git a/cpp/src/neighbors/mg/generate_mg.py b/cpp/src/neighbors/mg/generate_mg.py index af5e60545..b9fdd0aa9 100644 --- a/cpp/src/neighbors/mg/generate_mg.py +++ b/cpp/src/neighbors/mg/generate_mg.py @@ -57,7 +57,7 @@ const mg::index_params& index_params, \\ raft::host_matrix_view index_dataset) \\ { \\ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \\ + const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \\ index, T, IdxT> index(index_params.mode, clique.num_ranks_); \\ cuvs::neighbors::mg::detail::build(handle, index, \\ static_cast(&index_params), \\ @@ -105,7 +105,7 @@ index, T, IdxT> distribute_flat(const raft::device_resources& handle, \\ const std::string& filename) \\ { \\ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \\ + const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \\ auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \\ cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \\ return idx; \\ @@ -118,7 +118,7 @@ const mg::index_params& index_params, \\ raft::host_matrix_view index_dataset) \\ { \\ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \\ + const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \\ index, T, IdxT> index(index_params.mode, clique.num_ranks_); \\ cuvs::neighbors::mg::detail::build(handle, index, \\ static_cast(&index_params), \\ @@ -166,7 +166,7 @@ index, T, IdxT> distribute_pq(const raft::device_resources& handle, \\ const std::string& filename) \\ { \\ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \\ + const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \\ auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \\ cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \\ return idx; \\ @@ -179,7 +179,7 @@ const mg::index_params& index_params, \\ raft::host_matrix_view index_dataset) \\ { \\ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \\ + const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \\ index, T, IdxT> index(index_params.mode, clique.num_ranks_); \\ cuvs::neighbors::mg::detail::build(handle, index, \\ static_cast(&index_params), \\ @@ -219,7 +219,7 @@ index, T, IdxT> distribute_cagra(const raft::device_resources& handle, \\ const std::string& filename) \\ { \\ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \\ + const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \\ auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \\ cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \\ return idx; \\ diff --git a/cpp/src/neighbors/mg/mg.cuh b/cpp/src/neighbors/mg/mg.cuh index d3f635bc4..3679e89da 100644 --- a/cpp/src/neighbors/mg/mg.cuh +++ b/cpp/src/neighbors/mg/mg.cuh @@ -51,7 +51,7 @@ void deserialize_and_distribute(const raft::device_resources& handle, index& index, const std::string& filename) { - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); + const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); for (int rank = 0; rank < index.num_ranks_; rank++) { int dev_id = clique.device_ids_[rank]; const raft::device_resources& dev_res = clique.device_resources_[rank]; @@ -70,7 +70,7 @@ void deserialize(const raft::device_resources& handle, std::ifstream is(filename, std::ios::in | std::ios::binary); if (!is) { RAFT_FAIL("Cannot open file %s", filename.c_str()); } - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); + const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); index.mode_ = (cuvs::neighbors::mg::distribution_mode)deserialize_scalar(handle, is); index.num_ranks_ = deserialize_scalar(handle, is); @@ -98,7 +98,7 @@ void build(const raft::device_resources& handle, const cuvs::neighbors::index_params* index_params, raft::host_matrix_view index_dataset) { - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); + const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); if (index.mode_ == REPLICATED) { int64_t n_rows = index_dataset.extent(0); @@ -145,7 +145,7 @@ void extend(const raft::device_resources& handle, raft::host_matrix_view new_vectors, std::optional> new_indices) { - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); + const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); int64_t n_rows = new_vectors.extent(0); if (index.mode_ == REPLICATED) { @@ -191,7 +191,7 @@ void extend(const raft::device_resources& handle, } template -void sharded_search_with_direct_merge(const raft::comms::nccl_clique& clique, +void sharded_search_with_direct_merge(const raft::core::nccl_clique& clique, const index& index, const cuvs::neighbors::search_params* search_params, raft::host_matrix_view queries, @@ -325,7 +325,7 @@ void sharded_search_with_direct_merge(const raft::comms::nccl_clique& clique, } template -void sharded_search_with_tree_merge(const raft::comms::nccl_clique& clique, +void sharded_search_with_tree_merge(const raft::core::nccl_clique& clique, const index& index, const cuvs::neighbors::search_params* search_params, raft::host_matrix_view queries, @@ -460,7 +460,7 @@ void sharded_search_with_tree_merge(const raft::comms::nccl_clique& clique, } template -void run_search_batch(const raft::comms::nccl_clique& clique, +void run_search_batch(const raft::core::nccl_clique& clique, const index& index, int rank, const cuvs::neighbors::search_params* search_params, @@ -509,7 +509,7 @@ void search(const raft::device_resources& handle, raft::host_matrix_view distances, int64_t n_rows_per_batch) { - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); + const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); int64_t n_rows = queries.extent(0); int64_t n_cols = queries.extent(1); @@ -649,7 +649,7 @@ void serialize(const raft::device_resources& handle, std::ofstream of(filename, std::ios::out | std::ios::binary); if (!of) { RAFT_FAIL("Cannot open file %s", filename.c_str()); } - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); + const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); serialize_scalar(handle, of, (int)index.mode_); serialize_scalar(handle, of, index.num_ranks_); diff --git a/cpp/src/neighbors/mg/mg_cagra_float_uint32_t.cu b/cpp/src/neighbors/mg/mg_cagra_float_uint32_t.cu index b11610fb4..0f6154395 100644 --- a/cpp/src/neighbors/mg/mg_cagra_float_uint32_t.cu +++ b/cpp/src/neighbors/mg/mg_cagra_float_uint32_t.cu @@ -33,7 +33,7 @@ namespace cuvs::neighbors::mg { const mg::index_params& index_params, \ raft::host_matrix_view index_dataset) \ { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ + const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ cuvs::neighbors::mg::detail::build( \ handle, \ @@ -80,7 +80,7 @@ namespace cuvs::neighbors::mg { index, T, IdxT> distribute_cagra( \ const raft::device_resources& handle, const std::string& filename) \ { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ + const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \ return idx; \ diff --git a/cpp/src/neighbors/mg/mg_cagra_half_uint32_t.cu b/cpp/src/neighbors/mg/mg_cagra_half_uint32_t.cu index 8f76c69a3..ad041155b 100644 --- a/cpp/src/neighbors/mg/mg_cagra_half_uint32_t.cu +++ b/cpp/src/neighbors/mg/mg_cagra_half_uint32_t.cu @@ -33,7 +33,7 @@ namespace cuvs::neighbors::mg { const mg::index_params& index_params, \ raft::host_matrix_view index_dataset) \ { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ + const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ cuvs::neighbors::mg::detail::build( \ handle, \ @@ -80,7 +80,7 @@ namespace cuvs::neighbors::mg { index, T, IdxT> distribute_cagra( \ const raft::device_resources& handle, const std::string& filename) \ { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ + const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \ return idx; \ diff --git a/cpp/src/neighbors/mg/mg_cagra_int8_t_uint32_t.cu b/cpp/src/neighbors/mg/mg_cagra_int8_t_uint32_t.cu index 67b88d742..5099a0417 100644 --- a/cpp/src/neighbors/mg/mg_cagra_int8_t_uint32_t.cu +++ b/cpp/src/neighbors/mg/mg_cagra_int8_t_uint32_t.cu @@ -33,7 +33,7 @@ namespace cuvs::neighbors::mg { const mg::index_params& index_params, \ raft::host_matrix_view index_dataset) \ { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ + const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ cuvs::neighbors::mg::detail::build( \ handle, \ @@ -80,7 +80,7 @@ namespace cuvs::neighbors::mg { index, T, IdxT> distribute_cagra( \ const raft::device_resources& handle, const std::string& filename) \ { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ + const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \ return idx; \ diff --git a/cpp/src/neighbors/mg/mg_cagra_uint8_t_uint32_t.cu b/cpp/src/neighbors/mg/mg_cagra_uint8_t_uint32_t.cu index f72174923..df33190f9 100644 --- a/cpp/src/neighbors/mg/mg_cagra_uint8_t_uint32_t.cu +++ b/cpp/src/neighbors/mg/mg_cagra_uint8_t_uint32_t.cu @@ -33,7 +33,7 @@ namespace cuvs::neighbors::mg { const mg::index_params& index_params, \ raft::host_matrix_view index_dataset) \ { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ + const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ cuvs::neighbors::mg::detail::build( \ handle, \ @@ -80,7 +80,7 @@ namespace cuvs::neighbors::mg { index, T, IdxT> distribute_cagra( \ const raft::device_resources& handle, const std::string& filename) \ { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ + const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \ return idx; \ diff --git a/cpp/src/neighbors/mg/mg_flat_float_int64_t.cu b/cpp/src/neighbors/mg/mg_flat_float_int64_t.cu index 4495e2527..5fcf27301 100644 --- a/cpp/src/neighbors/mg/mg_flat_float_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_flat_float_int64_t.cu @@ -33,7 +33,7 @@ namespace cuvs::neighbors::mg { const mg::index_params& index_params, \ raft::host_matrix_view index_dataset) \ { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ + const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ cuvs::neighbors::mg::detail::build( \ handle, \ @@ -88,7 +88,7 @@ namespace cuvs::neighbors::mg { index, T, IdxT> distribute_flat( \ const raft::device_resources& handle, const std::string& filename) \ { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ + const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \ return idx; \ diff --git a/cpp/src/neighbors/mg/mg_flat_int8_t_int64_t.cu b/cpp/src/neighbors/mg/mg_flat_int8_t_int64_t.cu index 5494414a6..75128ea0c 100644 --- a/cpp/src/neighbors/mg/mg_flat_int8_t_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_flat_int8_t_int64_t.cu @@ -33,7 +33,7 @@ namespace cuvs::neighbors::mg { const mg::index_params& index_params, \ raft::host_matrix_view index_dataset) \ { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ + const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ cuvs::neighbors::mg::detail::build( \ handle, \ @@ -88,7 +88,7 @@ namespace cuvs::neighbors::mg { index, T, IdxT> distribute_flat( \ const raft::device_resources& handle, const std::string& filename) \ { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ + const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \ return idx; \ diff --git a/cpp/src/neighbors/mg/mg_flat_uint8_t_int64_t.cu b/cpp/src/neighbors/mg/mg_flat_uint8_t_int64_t.cu index 35df2146b..cafdbdcbb 100644 --- a/cpp/src/neighbors/mg/mg_flat_uint8_t_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_flat_uint8_t_int64_t.cu @@ -33,7 +33,7 @@ namespace cuvs::neighbors::mg { const mg::index_params& index_params, \ raft::host_matrix_view index_dataset) \ { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ + const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ cuvs::neighbors::mg::detail::build( \ handle, \ @@ -88,7 +88,7 @@ namespace cuvs::neighbors::mg { index, T, IdxT> distribute_flat( \ const raft::device_resources& handle, const std::string& filename) \ { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ + const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \ return idx; \ diff --git a/cpp/src/neighbors/mg/mg_pq_float_int64_t.cu b/cpp/src/neighbors/mg/mg_pq_float_int64_t.cu index c671740e6..8e7a71014 100644 --- a/cpp/src/neighbors/mg/mg_pq_float_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_pq_float_int64_t.cu @@ -33,7 +33,7 @@ namespace cuvs::neighbors::mg { const mg::index_params& index_params, \ raft::host_matrix_view index_dataset) \ { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ + const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ cuvs::neighbors::mg::detail::build( \ handle, \ @@ -88,7 +88,7 @@ namespace cuvs::neighbors::mg { index, T, IdxT> distribute_pq(const raft::device_resources& handle, \ const std::string& filename) \ { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ + const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \ return idx; \ diff --git a/cpp/src/neighbors/mg/mg_pq_half_int64_t.cu b/cpp/src/neighbors/mg/mg_pq_half_int64_t.cu index b167239c6..7fe19a899 100644 --- a/cpp/src/neighbors/mg/mg_pq_half_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_pq_half_int64_t.cu @@ -33,7 +33,7 @@ namespace cuvs::neighbors::mg { const mg::index_params& index_params, \ raft::host_matrix_view index_dataset) \ { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ + const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ cuvs::neighbors::mg::detail::build( \ handle, \ @@ -88,7 +88,7 @@ namespace cuvs::neighbors::mg { index, T, IdxT> distribute_pq(const raft::device_resources& handle, \ const std::string& filename) \ { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ + const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \ return idx; \ diff --git a/cpp/src/neighbors/mg/mg_pq_int8_t_int64_t.cu b/cpp/src/neighbors/mg/mg_pq_int8_t_int64_t.cu index 127baf8fd..ca56c90ed 100644 --- a/cpp/src/neighbors/mg/mg_pq_int8_t_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_pq_int8_t_int64_t.cu @@ -33,7 +33,7 @@ namespace cuvs::neighbors::mg { const mg::index_params& index_params, \ raft::host_matrix_view index_dataset) \ { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ + const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ cuvs::neighbors::mg::detail::build( \ handle, \ @@ -88,7 +88,7 @@ namespace cuvs::neighbors::mg { index, T, IdxT> distribute_pq(const raft::device_resources& handle, \ const std::string& filename) \ { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ + const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \ return idx; \ diff --git a/cpp/src/neighbors/mg/mg_pq_uint8_t_int64_t.cu b/cpp/src/neighbors/mg/mg_pq_uint8_t_int64_t.cu index 869e009a5..ed40fa230 100644 --- a/cpp/src/neighbors/mg/mg_pq_uint8_t_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_pq_uint8_t_int64_t.cu @@ -33,7 +33,7 @@ namespace cuvs::neighbors::mg { const mg::index_params& index_params, \ raft::host_matrix_view index_dataset) \ { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ + const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ cuvs::neighbors::mg::detail::build( \ handle, \ @@ -88,7 +88,7 @@ namespace cuvs::neighbors::mg { index, T, IdxT> distribute_pq(const raft::device_resources& handle, \ const std::string& filename) \ { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ + const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \ return idx; \ diff --git a/cpp/src/neighbors/mg/nccl_comm.cpp b/cpp/src/neighbors/mg/nccl_comm.cpp deleted file mode 100644 index c4556957a..000000000 --- a/cpp/src/neighbors/mg/nccl_comm.cpp +++ /dev/null @@ -1,8 +0,0 @@ -#include -#include - -namespace raft::comms { -void build_comms_nccl_only(raft::resources* handle, ncclComm_t nccl_comm, int num_ranks, int rank) -{ -} -} // namespace raft::comms diff --git a/cpp/test/neighbors/mg.cuh b/cpp/test/neighbors/mg.cuh index be30ca615..6b98e975a 100644 --- a/cpp/test/neighbors/mg.cuh +++ b/cpp/test/neighbors/mg.cuh @@ -632,7 +632,7 @@ class AnnMGTest : public ::testing::TestWithParam { private: raft::device_resources handle_; rmm::cuda_stream_view stream_; - raft::comms::nccl_clique clique_; + raft::core::nccl_clique clique_; AnnMGInputs ps; std::vector h_index_dataset; std::vector h_queries; From 3a99b406dbf71d58c02148b850524ff75248df52 Mon Sep 17 00:00:00 2001 From: viclafargue Date: Fri, 15 Nov 2024 14:25:38 +0100 Subject: [PATCH 2/8] use new device_resources_snmg --- .../ann/src/cuvs/cuvs_mg_cagra_wrapper.h | 16 +- .../ann/src/cuvs/cuvs_mg_ivf_flat_wrapper.h | 15 +- .../ann/src/cuvs/cuvs_mg_ivf_pq_wrapper.h | 15 +- cpp/include/cuvs/neighbors/mg.hpp | 494 +++++++++--------- cpp/src/neighbors/mg/generate_mg.py | 74 ++- cpp/src/neighbors/mg/mg.cuh | 40 +- .../neighbors/mg/mg_cagra_float_uint32_t.cu | 22 +- .../neighbors/mg/mg_cagra_half_uint32_t.cu | 22 +- .../neighbors/mg/mg_cagra_int8_t_uint32_t.cu | 22 +- .../neighbors/mg/mg_cagra_uint8_t_uint32_t.cu | 22 +- cpp/src/neighbors/mg/mg_flat_float_int64_t.cu | 26 +- .../neighbors/mg/mg_flat_int8_t_int64_t.cu | 26 +- .../neighbors/mg/mg_flat_uint8_t_int64_t.cu | 26 +- cpp/src/neighbors/mg/mg_pq_float_int64_t.cu | 128 +++-- cpp/src/neighbors/mg/mg_pq_half_int64_t.cu | 128 +++-- cpp/src/neighbors/mg/mg_pq_int8_t_int64_t.cu | 128 +++-- cpp/src/neighbors/mg/mg_pq_uint8_t_int64_t.cu | 128 +++-- cpp/test/neighbors/mg.cuh | 84 +-- 18 files changed, 687 insertions(+), 729 deletions(-) diff --git a/cpp/bench/ann/src/cuvs/cuvs_mg_cagra_wrapper.h b/cpp/bench/ann/src/cuvs/cuvs_mg_cagra_wrapper.h index cf150bf98..f5a394482 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_mg_cagra_wrapper.h +++ b/cpp/bench/ann/src/cuvs/cuvs_mg_cagra_wrapper.h @@ -18,7 +18,7 @@ #include "cuvs_ann_bench_utils.h" #include "cuvs_cagra_wrapper.h" #include -#include +#include namespace cuvs::bench { using namespace cuvs::neighbors; @@ -41,13 +41,10 @@ class cuvs_mg_cagra : public algo, public algo_gpu { }; cuvs_mg_cagra(Metric metric, int dim, const build_param& param, int concurrent_searches = 1) - : algo(metric, dim), index_params_(param) + : algo(metric, dim), index_params_(param), clique_() { index_params_.cagra_params.metric = parse_metric_type(metric); index_params_.ivf_pq_build_params->metric = parse_metric_type(metric); - - // init nccl clique outside as to not affect benchmark - const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle_); } void build(const T* dataset, size_t nrow) final; @@ -88,6 +85,7 @@ class cuvs_mg_cagra : public algo, public algo_gpu { private: raft::device_resources handle_; + raft::device_resources_snmg clique_; float refine_ratio_; build_param index_params_; cuvs::neighbors::mg::search_params search_params_; @@ -105,7 +103,7 @@ void cuvs_mg_cagra::build(const T* dataset, size_t nrow) auto dataset_view = raft::make_host_matrix_view(dataset, nrow, dim_); - auto idx = cuvs::neighbors::mg::build(handle_, build_params, dataset_view); + auto idx = cuvs::neighbors::mg::build(clique_, build_params, dataset_view); index_ = std::make_shared, T, IdxT>>( std::move(idx)); @@ -132,7 +130,7 @@ void cuvs_mg_cagra::set_search_dataset(const T* dataset, size_t nrow) template void cuvs_mg_cagra::save(const std::string& file) const { - cuvs::neighbors::mg::serialize(handle_, *index_, file); + cuvs::neighbors::mg::serialize(clique_, *index_, file); } template @@ -140,7 +138,7 @@ void cuvs_mg_cagra::load(const std::string& file) { index_ = std::make_shared, T, IdxT>>( - std::move(cuvs::neighbors::mg::deserialize_cagra(handle_, file))); + std::move(cuvs::neighbors::mg::deserialize_cagra(clique_, file))); } template @@ -164,7 +162,7 @@ void cuvs_mg_cagra::search_base( raft::make_host_matrix_view(distances, batch_size, k); cuvs::neighbors::mg::search( - handle_, *index_, search_params_, queries_view, neighbors_view, distances_view); + clique_, *index_, search_params_, queries_view, neighbors_view, distances_view); } template diff --git a/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_flat_wrapper.h b/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_flat_wrapper.h index d313bbcbc..05e68b26b 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_flat_wrapper.h +++ b/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_flat_wrapper.h @@ -19,7 +19,7 @@ #include "cuvs_ann_bench_utils.h" #include "cuvs_ivf_flat_wrapper.h" #include -#include +#include namespace cuvs::bench { using namespace cuvs::neighbors; @@ -37,11 +37,9 @@ class cuvs_mg_ivf_flat : public algo, public algo_gpu { }; cuvs_mg_ivf_flat(Metric metric, int dim, const build_param& param) - : algo(metric, dim), index_params_(param) + : algo(metric, dim), index_params_(param), clique_() { index_params_.metric = parse_metric_type(metric); - // init nccl clique outside as to not affect benchmark - const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle_); } void build(const T* dataset, size_t nrow) final; @@ -74,6 +72,7 @@ class cuvs_mg_ivf_flat : public algo, public algo_gpu { private: raft::device_resources handle_; + raft::device_resources_snmg clique_; build_param index_params_; cuvs::neighbors::mg::search_params search_params_; std::shared_ptr, T, IdxT>> @@ -85,7 +84,7 @@ void cuvs_mg_ivf_flat::build(const T* dataset, size_t nrow) { auto dataset_view = raft::make_host_matrix_view(dataset, IdxT(nrow), IdxT(dim_)); - auto idx = cuvs::neighbors::mg::build(handle_, index_params_, dataset_view); + auto idx = cuvs::neighbors::mg::build(clique_, index_params_, dataset_view); index_ = std::make_shared< cuvs::neighbors::mg::index, T, IdxT>>(std::move(idx)); } @@ -105,7 +104,7 @@ void cuvs_mg_ivf_flat::set_search_param(const search_param_base& param) template void cuvs_mg_ivf_flat::save(const std::string& file) const { - cuvs::neighbors::mg::serialize(handle_, *index_, file); + cuvs::neighbors::mg::serialize(clique_, *index_, file); } template @@ -113,7 +112,7 @@ void cuvs_mg_ivf_flat::load(const std::string& file) { index_ = std::make_shared< cuvs::neighbors::mg::index, T, IdxT>>( - std::move(cuvs::neighbors::mg::deserialize_flat(handle_, file))); + std::move(cuvs::neighbors::mg::deserialize_flat(clique_, file))); } template @@ -134,7 +133,7 @@ void cuvs_mg_ivf_flat::search( distances, IdxT(batch_size), IdxT(k)); cuvs::neighbors::mg::search( - handle_, *index_, search_params_, queries_view, neighbors_view, distances_view); + clique_, *index_, search_params_, queries_view, neighbors_view, distances_view); } } // namespace cuvs::bench \ No newline at end of file diff --git a/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_pq_wrapper.h b/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_pq_wrapper.h index 588e4798a..d430d27bf 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_pq_wrapper.h +++ b/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_pq_wrapper.h @@ -19,7 +19,7 @@ #include "cuvs_ann_bench_utils.h" #include "cuvs_ivf_pq_wrapper.h" #include -#include +#include namespace cuvs::bench { using namespace cuvs::neighbors; @@ -37,11 +37,9 @@ class cuvs_mg_ivf_pq : public algo, public algo_gpu { }; cuvs_mg_ivf_pq(Metric metric, int dim, const build_param& param) - : algo(metric, dim), index_params_(param) + : algo(metric, dim), index_params_(param), clique_() { index_params_.metric = parse_metric_type(metric); - // init nccl clique outside as to not affect benchmark - const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle_); } void build(const T* dataset, size_t nrow) final; @@ -74,6 +72,7 @@ class cuvs_mg_ivf_pq : public algo, public algo_gpu { private: raft::device_resources handle_; + raft::device_resources_snmg clique_; build_param index_params_; cuvs::neighbors::mg::search_params search_params_; std::shared_ptr, T, IdxT>> index_; @@ -84,7 +83,7 @@ void cuvs_mg_ivf_pq::build(const T* dataset, size_t nrow) { auto dataset_view = raft::make_host_matrix_view(dataset, IdxT(nrow), IdxT(dim_)); - auto idx = cuvs::neighbors::mg::build(handle_, index_params_, dataset_view); + auto idx = cuvs::neighbors::mg::build(clique_, index_params_, dataset_view); index_ = std::make_shared, T, IdxT>>( std::move(idx)); @@ -104,7 +103,7 @@ void cuvs_mg_ivf_pq::set_search_param(const search_param_base& param) template void cuvs_mg_ivf_pq::save(const std::string& file) const { - cuvs::neighbors::mg::serialize(handle_, *index_, file); + cuvs::neighbors::mg::serialize(clique_, *index_, file); } template @@ -112,7 +111,7 @@ void cuvs_mg_ivf_pq::load(const std::string& file) { index_ = std::make_shared, T, IdxT>>( - std::move(cuvs::neighbors::mg::deserialize_pq(handle_, file))); + std::move(cuvs::neighbors::mg::deserialize_pq(clique_, file))); } template @@ -133,7 +132,7 @@ void cuvs_mg_ivf_pq::search( distances, IdxT(batch_size), IdxT(k)); cuvs::neighbors::mg::search( - handle_, *index_, search_params_, queries_view, neighbors_view, distances_view); + clique_, *index_, search_params_, queries_view, neighbors_view, distances_view); } } // namespace cuvs::bench \ No newline at end of file diff --git a/cpp/include/cuvs/neighbors/mg.hpp b/cpp/include/cuvs/neighbors/mg.hpp index 4657fa8fb..86572adeb 100644 --- a/cpp/include/cuvs/neighbors/mg.hpp +++ b/cpp/include/cuvs/neighbors/mg.hpp @@ -21,7 +21,7 @@ #include #include -#include +#include #include #include @@ -101,7 +101,7 @@ using namespace raft; template struct index { index(distribution_mode mode, int num_ranks_); - index(const raft::device_resources& handle, const std::string& filename); + index(const raft::device_resources_snmg& clique, const std::string& filename); index(const index&) = delete; index(index&&) = default; @@ -124,18 +124,18 @@ struct index { * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index_params configure the index building * @param[in] index_dataset a row-major matrix on host [n_rows, dim] * * @return the constructed IVF-Flat MG index */ -auto build(const raft::device_resources& handle, +auto build(const raft::device_resources_snmg& clique, const mg::index_params& index_params, raft::host_matrix_view index_dataset) -> index, float, int64_t>; @@ -146,18 +146,18 @@ auto build(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index_params configure the index building * @param[in] index_dataset a row-major matrix on host [n_rows, dim] * * @return the constructed IVF-Flat MG index */ -auto build(const raft::device_resources& handle, +auto build(const raft::device_resources_snmg& clique, const mg::index_params& index_params, raft::host_matrix_view index_dataset) -> index, int8_t, int64_t>; @@ -168,18 +168,18 @@ auto build(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index_params configure the index building * @param[in] index_dataset a row-major matrix on host [n_rows, dim] * * @return the constructed IVF-Flat MG index */ -auto build(const raft::device_resources& handle, +auto build(const raft::device_resources_snmg& clique, const mg::index_params& index_params, raft::host_matrix_view index_dataset) -> index, uint8_t, int64_t>; @@ -190,18 +190,18 @@ auto build(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index_params configure the index building * @param[in] index_dataset a row-major matrix on host [n_rows, dim] * * @return the constructed IVF-PQ MG index */ -auto build(const raft::device_resources& handle, +auto build(const raft::device_resources_snmg& clique, const mg::index_params& index_params, raft::host_matrix_view index_dataset) -> index, float, int64_t>; @@ -212,18 +212,18 @@ auto build(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index_params configure the index building * @param[in] index_dataset a row-major matrix on host [n_rows, dim] * * @return the constructed IVF-PQ MG index */ -auto build(const raft::device_resources& handle, +auto build(const raft::device_resources_snmg& clique, const mg::index_params& index_params, raft::host_matrix_view index_dataset) -> index, half, int64_t>; @@ -234,18 +234,18 @@ auto build(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index_params configure the index building * @param[in] index_dataset a row-major matrix on host [n_rows, dim] * * @return the constructed IVF-PQ MG index */ -auto build(const raft::device_resources& handle, +auto build(const raft::device_resources_snmg& clique, const mg::index_params& index_params, raft::host_matrix_view index_dataset) -> index, int8_t, int64_t>; @@ -256,18 +256,18 @@ auto build(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index_params configure the index building * @param[in] index_dataset a row-major matrix on host [n_rows, dim] * * @return the constructed IVF-PQ MG index */ -auto build(const raft::device_resources& handle, +auto build(const raft::device_resources_snmg& clique, const mg::index_params& index_params, raft::host_matrix_view index_dataset) -> index, uint8_t, int64_t>; @@ -278,18 +278,18 @@ auto build(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index_params configure the index building * @param[in] index_dataset a row-major matrix on host [n_rows, dim] * * @return the constructed CAGRA MG index */ -auto build(const raft::device_resources& handle, +auto build(const raft::device_resources_snmg& clique, const mg::index_params& index_params, raft::host_matrix_view index_dataset) -> index, float, uint32_t>; @@ -300,18 +300,18 @@ auto build(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index_params configure the index building * @param[in] index_dataset a row-major matrix on host [n_rows, dim] * * @return the constructed CAGRA MG index */ -auto build(const raft::device_resources& handle, +auto build(const raft::device_resources_snmg& clique, const mg::index_params& index_params, raft::host_matrix_view index_dataset) -> index, half, uint32_t>; @@ -322,18 +322,18 @@ auto build(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index_params configure the index building * @param[in] index_dataset a row-major matrix on host [n_rows, dim] * * @return the constructed CAGRA MG index */ -auto build(const raft::device_resources& handle, +auto build(const raft::device_resources_snmg& clique, const mg::index_params& index_params, raft::host_matrix_view index_dataset) -> index, int8_t, uint32_t>; @@ -344,18 +344,18 @@ auto build(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index_params configure the index building * @param[in] index_dataset a row-major matrix on host [n_rows, dim] * * @return the constructed CAGRA MG index */ -auto build(const raft::device_resources& handle, +auto build(const raft::device_resources_snmg& clique, const mg::index_params& index_params, raft::host_matrix_view index_dataset) -> index, uint8_t, uint32_t>; @@ -368,20 +368,20 @@ auto build(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); - * cuvs::neighbors::mg::extend(handle, index, new_vectors, std::nullopt); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); + * cuvs::neighbors::mg::extend(clique, index, new_vectors, std::nullopt); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] new_vectors a row-major matrix on host [n_rows, dim] * @param[in] new_indices optional vector on host [n_rows], * `std::nullopt` means default continuous range `[0...n_rows)` * */ -void extend(const raft::device_resources& handle, +void extend(const raft::device_resources_snmg& clique, index, float, int64_t>& index, raft::host_matrix_view new_vectors, std::optional> new_indices); @@ -392,20 +392,20 @@ void extend(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); - * cuvs::neighbors::mg::extend(handle, index, new_vectors, std::nullopt); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); + * cuvs::neighbors::mg::extend(clique, index, new_vectors, std::nullopt); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] new_vectors a row-major matrix on host [n_rows, dim] * @param[in] new_indices optional vector on host [n_rows], * `std::nullopt` means default continuous range `[0...n_rows)` * */ -void extend(const raft::device_resources& handle, +void extend(const raft::device_resources_snmg& clique, index, int8_t, int64_t>& index, raft::host_matrix_view new_vectors, std::optional> new_indices); @@ -416,20 +416,20 @@ void extend(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); - * cuvs::neighbors::mg::extend(handle, index, new_vectors, std::nullopt); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); + * cuvs::neighbors::mg::extend(clique, index, new_vectors, std::nullopt); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] new_vectors a row-major matrix on host [n_rows, dim] * @param[in] new_indices optional vector on host [n_rows], * `std::nullopt` means default continuous range `[0...n_rows)` * */ -void extend(const raft::device_resources& handle, +void extend(const raft::device_resources_snmg& clique, index, uint8_t, int64_t>& index, raft::host_matrix_view new_vectors, std::optional> new_indices); @@ -440,20 +440,20 @@ void extend(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); - * cuvs::neighbors::mg::extend(handle, index, new_vectors, std::nullopt); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); + * cuvs::neighbors::mg::extend(clique, index, new_vectors, std::nullopt); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] new_vectors a row-major matrix on host [n_rows, dim] * @param[in] new_indices optional vector on host [n_rows], * `std::nullopt` means default continuous range `[0...n_rows)` * */ -void extend(const raft::device_resources& handle, +void extend(const raft::device_resources_snmg& clique, index, float, int64_t>& index, raft::host_matrix_view new_vectors, std::optional> new_indices); @@ -464,20 +464,20 @@ void extend(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); - * cuvs::neighbors::mg::extend(handle, index, new_vectors, std::nullopt); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); + * cuvs::neighbors::mg::extend(clique, index, new_vectors, std::nullopt); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] new_vectors a row-major matrix on host [n_rows, dim] * @param[in] new_indices optional vector on host [n_rows], * `std::nullopt` means default continuous range `[0...n_rows)` * */ -void extend(const raft::device_resources& handle, +void extend(const raft::device_resources_snmg& clique, index, half, int64_t>& index, raft::host_matrix_view new_vectors, std::optional> new_indices); @@ -488,20 +488,20 @@ void extend(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); - * cuvs::neighbors::mg::extend(handle, index, new_vectors, std::nullopt); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); + * cuvs::neighbors::mg::extend(clique, index, new_vectors, std::nullopt); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] new_vectors a row-major matrix on host [n_rows, dim] * @param[in] new_indices optional vector on host [n_rows], * `std::nullopt` means default continuous range `[0...n_rows)` * */ -void extend(const raft::device_resources& handle, +void extend(const raft::device_resources_snmg& clique, index, int8_t, int64_t>& index, raft::host_matrix_view new_vectors, std::optional> new_indices); @@ -512,20 +512,20 @@ void extend(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); - * cuvs::neighbors::mg::extend(handle, index, new_vectors, std::nullopt); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); + * cuvs::neighbors::mg::extend(clique, index, new_vectors, std::nullopt); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] new_vectors a row-major matrix on host [n_rows, dim] * @param[in] new_indices optional vector on host [n_rows], * `std::nullopt` means default continuous range `[0...n_rows)` * */ -void extend(const raft::device_resources& handle, +void extend(const raft::device_resources_snmg& clique, index, uint8_t, int64_t>& index, raft::host_matrix_view new_vectors, std::optional> new_indices); @@ -536,20 +536,20 @@ void extend(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); - * cuvs::neighbors::mg::extend(handle, index, new_vectors, std::nullopt); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); + * cuvs::neighbors::mg::extend(clique, index, new_vectors, std::nullopt); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] new_vectors a row-major matrix on host [n_rows, dim] * @param[in] new_indices optional vector on host [n_rows], * `std::nullopt` means default continuous range `[0...n_rows)` * */ -void extend(const raft::device_resources& handle, +void extend(const raft::device_resources_snmg& clique, index, float, uint32_t>& index, raft::host_matrix_view new_vectors, std::optional> new_indices); @@ -560,20 +560,20 @@ void extend(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); - * cuvs::neighbors::mg::extend(handle, index, new_vectors, std::nullopt); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); + * cuvs::neighbors::mg::extend(clique, index, new_vectors, std::nullopt); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] new_vectors a row-major matrix on host [n_rows, dim] * @param[in] new_indices optional vector on host [n_rows], * `std::nullopt` means default continuous range `[0...n_rows)` * */ -void extend(const raft::device_resources& handle, +void extend(const raft::device_resources_snmg& clique, index, half, uint32_t>& index, raft::host_matrix_view new_vectors, std::optional> new_indices); @@ -584,20 +584,20 @@ void extend(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); - * cuvs::neighbors::mg::extend(handle, index, new_vectors, std::nullopt); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); + * cuvs::neighbors::mg::extend(clique, index, new_vectors, std::nullopt); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] new_vectors a row-major matrix on host [n_rows, dim] * @param[in] new_indices optional vector on host [n_rows], * `std::nullopt` means default continuous range `[0...n_rows)` * */ -void extend(const raft::device_resources& handle, +void extend(const raft::device_resources_snmg& clique, index, int8_t, uint32_t>& index, raft::host_matrix_view new_vectors, std::optional> new_indices); @@ -608,20 +608,20 @@ void extend(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); - * cuvs::neighbors::mg::extend(handle, index, new_vectors, std::nullopt); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); + * cuvs::neighbors::mg::extend(clique, index, new_vectors, std::nullopt); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] new_vectors a row-major matrix on host [n_rows, dim] * @param[in] new_indices optional vector on host [n_rows], * `std::nullopt` means default continuous range `[0...n_rows)` * */ -void extend(const raft::device_resources& handle, +void extend(const raft::device_resources_snmg& clique, index, uint8_t, uint32_t>& index, raft::host_matrix_view new_vectors, std::optional> new_indices); @@ -634,15 +634,15 @@ void extend(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * cuvs::neighbors::mg::search_params search_params; - * cuvs::neighbors::mg::search(handle, index, search_params, queries, neighbors, + * cuvs::neighbors::mg::search(clique, index, search_params, queries, neighbors, * distances); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] search_params configure the index search * @param[in] queries a row-major matrix on host [n_rows, dim] @@ -651,7 +651,7 @@ void extend(const raft::device_resources& handle, * @param[in] n_rows_per_batch (optional) search batch size * */ -void search(const raft::device_resources& handle, +void search(const raft::device_resources_snmg& clique, const index, float, int64_t>& index, const mg::search_params& search_params, raft::host_matrix_view queries, @@ -665,15 +665,15 @@ void search(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * cuvs::neighbors::mg::search_params search_params; - * cuvs::neighbors::mg::search(handle, index, search_params, queries, neighbors, + * cuvs::neighbors::mg::search(clique, index, search_params, queries, neighbors, * distances); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] search_params configure the index search * @param[in] queries a row-major matrix on host [n_rows, dim] @@ -682,7 +682,7 @@ void search(const raft::device_resources& handle, * @param[in] n_rows_per_batch (optional) search batch size * */ -void search(const raft::device_resources& handle, +void search(const raft::device_resources_snmg& clique, const index, int8_t, int64_t>& index, const mg::search_params& search_params, raft::host_matrix_view queries, @@ -696,15 +696,15 @@ void search(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * cuvs::neighbors::mg::search_params search_params; - * cuvs::neighbors::mg::search(handle, index, search_params, queries, neighbors, + * cuvs::neighbors::mg::search(clique, index, search_params, queries, neighbors, * distances); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] search_params configure the index search * @param[in] queries a row-major matrix on host [n_rows, dim] @@ -713,7 +713,7 @@ void search(const raft::device_resources& handle, * @param[in] n_rows_per_batch (optional) search batch size * */ -void search(const raft::device_resources& handle, +void search(const raft::device_resources_snmg& clique, const index, uint8_t, int64_t>& index, const mg::search_params& search_params, raft::host_matrix_view queries, @@ -727,15 +727,15 @@ void search(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * cuvs::neighbors::mg::search_params search_params; - * cuvs::neighbors::mg::search(handle, index, search_params, queries, neighbors, + * cuvs::neighbors::mg::search(clique, index, search_params, queries, neighbors, * distances); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] search_params configure the index search * @param[in] queries a row-major matrix on host [n_rows, dim] @@ -744,7 +744,7 @@ void search(const raft::device_resources& handle, * @param[in] n_rows_per_batch (optional) search batch size * */ -void search(const raft::device_resources& handle, +void search(const raft::device_resources_snmg& clique, const index, float, int64_t>& index, const mg::search_params& search_params, raft::host_matrix_view queries, @@ -758,15 +758,15 @@ void search(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * cuvs::neighbors::mg::search_params search_params; - * cuvs::neighbors::mg::search(handle, index, search_params, queries, neighbors, + * cuvs::neighbors::mg::search(clique, index, search_params, queries, neighbors, * distances); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] search_params configure the index search * @param[in] queries a row-major matrix on host [n_rows, dim] @@ -775,7 +775,7 @@ void search(const raft::device_resources& handle, * @param[in] n_rows_per_batch (optional) search batch size * */ -void search(const raft::device_resources& handle, +void search(const raft::device_resources_snmg& clique, const index, half, int64_t>& index, const mg::search_params& search_params, raft::host_matrix_view queries, @@ -789,15 +789,15 @@ void search(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * cuvs::neighbors::mg::search_params search_params; - * cuvs::neighbors::mg::search(handle, index, search_params, queries, neighbors, + * cuvs::neighbors::mg::search(clique, index, search_params, queries, neighbors, * distances); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] search_params configure the index search * @param[in] queries a row-major matrix on host [n_rows, dim] @@ -806,7 +806,7 @@ void search(const raft::device_resources& handle, * @param[in] n_rows_per_batch (optional) search batch size * */ -void search(const raft::device_resources& handle, +void search(const raft::device_resources_snmg& clique, const index, int8_t, int64_t>& index, const mg::search_params& search_params, raft::host_matrix_view queries, @@ -820,15 +820,15 @@ void search(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * cuvs::neighbors::mg::search_params search_params; - * cuvs::neighbors::mg::search(handle, index, search_params, queries, neighbors, + * cuvs::neighbors::mg::search(clique, index, search_params, queries, neighbors, * distances); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] search_params configure the index search * @param[in] queries a row-major matrix on host [n_rows, dim] @@ -837,7 +837,7 @@ void search(const raft::device_resources& handle, * @param[in] n_rows_per_batch (optional) search batch size * */ -void search(const raft::device_resources& handle, +void search(const raft::device_resources_snmg& clique, const index, uint8_t, int64_t>& index, const mg::search_params& search_params, raft::host_matrix_view queries, @@ -851,15 +851,15 @@ void search(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * cuvs::neighbors::mg::search_params search_params; - * cuvs::neighbors::mg::search(handle, index, search_params, queries, neighbors, + * cuvs::neighbors::mg::search(clique, index, search_params, queries, neighbors, * distances); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] search_params configure the index search * @param[in] queries a row-major matrix on host [n_rows, dim] @@ -868,7 +868,7 @@ void search(const raft::device_resources& handle, * @param[in] n_rows_per_batch (optional) search batch size * */ -void search(const raft::device_resources& handle, +void search(const raft::device_resources_snmg& clique, const index, float, uint32_t>& index, const mg::search_params& search_params, raft::host_matrix_view queries, @@ -882,15 +882,15 @@ void search(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * cuvs::neighbors::mg::search_params search_params; - * cuvs::neighbors::mg::search(handle, index, search_params, queries, neighbors, + * cuvs::neighbors::mg::search(clique, index, search_params, queries, neighbors, * distances); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] search_params configure the index search * @param[in] queries a row-major matrix on host [n_rows, dim] @@ -899,7 +899,7 @@ void search(const raft::device_resources& handle, * @param[in] n_rows_per_batch (optional) search batch size * */ -void search(const raft::device_resources& handle, +void search(const raft::device_resources_snmg& clique, const index, half, uint32_t>& index, const mg::search_params& search_params, raft::host_matrix_view queries, @@ -913,15 +913,15 @@ void search(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * cuvs::neighbors::mg::search_params search_params; - * cuvs::neighbors::mg::search(handle, index, search_params, queries, neighbors, + * cuvs::neighbors::mg::search(clique, index, search_params, queries, neighbors, * distances); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] search_params configure the index search * @param[in] queries a row-major matrix on host [n_rows, dim] @@ -930,7 +930,7 @@ void search(const raft::device_resources& handle, * @param[in] n_rows_per_batch (optional) search batch size * */ -void search(const raft::device_resources& handle, +void search(const raft::device_resources_snmg& clique, const index, int8_t, uint32_t>& index, const mg::search_params& search_params, raft::host_matrix_view queries, @@ -944,15 +944,15 @@ void search(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * cuvs::neighbors::mg::search_params search_params; - * cuvs::neighbors::mg::search(handle, index, search_params, queries, neighbors, + * cuvs::neighbors::mg::search(clique, index, search_params, queries, neighbors, * distances); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] search_params configure the index search * @param[in] queries a row-major matrix on host [n_rows, dim] @@ -961,7 +961,7 @@ void search(const raft::device_resources& handle, * @param[in] n_rows_per_batch (optional) search batch size * */ -void search(const raft::device_resources& handle, +void search(const raft::device_resources_snmg& clique, const index, uint8_t, uint32_t>& index, const mg::search_params& search_params, raft::host_matrix_view queries, @@ -977,19 +977,19 @@ void search(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(handle, index, filename); + * cuvs::neighbors::mg::serialize(clique, index, filename); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] filename path to the file to be serialized * */ -void serialize(const raft::device_resources& handle, +void serialize(const raft::device_resources_snmg& clique, const index, float, int64_t>& index, const std::string& filename); @@ -999,19 +999,19 @@ void serialize(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(handle, index, filename); + * cuvs::neighbors::mg::serialize(clique, index, filename); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] filename path to the file to be serialized * */ -void serialize(const raft::device_resources& handle, +void serialize(const raft::device_resources_snmg& clique, const index, int8_t, int64_t>& index, const std::string& filename); @@ -1021,19 +1021,19 @@ void serialize(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(handle, index, filename); + * cuvs::neighbors::mg::serialize(clique, index, filename); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] filename path to the file to be serialized * */ -void serialize(const raft::device_resources& handle, +void serialize(const raft::device_resources_snmg& clique, const index, uint8_t, int64_t>& index, const std::string& filename); @@ -1043,19 +1043,19 @@ void serialize(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(handle, index, filename); + * cuvs::neighbors::mg::serialize(clique, index, filename); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] filename path to the file to be serialized * */ -void serialize(const raft::device_resources& handle, +void serialize(const raft::device_resources_snmg& clique, const index, float, int64_t>& index, const std::string& filename); @@ -1065,19 +1065,19 @@ void serialize(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(handle, index, filename); + * cuvs::neighbors::mg::serialize(clique, index, filename); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] filename path to the file to be serialized * */ -void serialize(const raft::device_resources& handle, +void serialize(const raft::device_resources_snmg& clique, const index, half, int64_t>& index, const std::string& filename); @@ -1087,19 +1087,19 @@ void serialize(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(handle, index, filename); + * cuvs::neighbors::mg::serialize(clique, index, filename); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] filename path to the file to be serialized * */ -void serialize(const raft::device_resources& handle, +void serialize(const raft::device_resources_snmg& clique, const index, int8_t, int64_t>& index, const std::string& filename); @@ -1109,19 +1109,19 @@ void serialize(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(handle, index, filename); + * cuvs::neighbors::mg::serialize(clique, index, filename); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] filename path to the file to be serialized * */ -void serialize(const raft::device_resources& handle, +void serialize(const raft::device_resources_snmg& clique, const index, uint8_t, int64_t>& index, const std::string& filename); @@ -1131,19 +1131,19 @@ void serialize(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(handle, index, filename); + * cuvs::neighbors::mg::serialize(clique, index, filename); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] filename path to the file to be serialized * */ -void serialize(const raft::device_resources& handle, +void serialize(const raft::device_resources_snmg& clique, const index, float, uint32_t>& index, const std::string& filename); @@ -1153,19 +1153,19 @@ void serialize(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(handle, index, filename); + * cuvs::neighbors::mg::serialize(clique, index, filename); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] filename path to the file to be serialized * */ -void serialize(const raft::device_resources& handle, +void serialize(const raft::device_resources_snmg& clique, const index, half, uint32_t>& index, const std::string& filename); @@ -1175,19 +1175,19 @@ void serialize(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(handle, index, filename); + * cuvs::neighbors::mg::serialize(clique, index, filename); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] filename path to the file to be serialized * */ -void serialize(const raft::device_resources& handle, +void serialize(const raft::device_resources_snmg& clique, const index, int8_t, uint32_t>& index, const std::string& filename); @@ -1197,19 +1197,19 @@ void serialize(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(handle, index, filename); + * cuvs::neighbors::mg::serialize(clique, index, filename); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] filename path to the file to be serialized * */ -void serialize(const raft::device_resources& handle, +void serialize(const raft::device_resources_snmg& clique, const index, uint8_t, uint32_t>& index, const std::string& filename); @@ -1221,21 +1221,21 @@ void serialize(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(handle, index, filename); - * auto new_index = cuvs::neighbors::mg::deserialize_flat(handle, filename); + * cuvs::neighbors::mg::serialize(clique, index, filename); + * auto new_index = cuvs::neighbors::mg::deserialize_flat(clique, filename); * * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] filename path to the file to be deserialized * */ template -auto deserialize_flat(const raft::device_resources& handle, const std::string& filename) +auto deserialize_flat(const raft::device_resources_snmg& clique, const std::string& filename) -> index, T, IdxT>; /// \ingroup mg_cpp_deserialize @@ -1244,20 +1244,20 @@ auto deserialize_flat(const raft::device_resources& handle, const std::string& f * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(handle, index, filename); - * auto new_index = cuvs::neighbors::mg::deserialize_pq(handle, filename); + * cuvs::neighbors::mg::serialize(clique, index, filename); + * auto new_index = cuvs::neighbors::mg::deserialize_pq(clique, filename); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] filename path to the file to be deserialized * */ template -auto deserialize_pq(const raft::device_resources& handle, const std::string& filename) +auto deserialize_pq(const raft::device_resources_snmg& clique, const std::string& filename) -> index, T, IdxT>; /// \ingroup mg_cpp_deserialize @@ -1266,21 +1266,21 @@ auto deserialize_pq(const raft::device_resources& handle, const std::string& fil * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(handle, index, filename); - * auto new_index = cuvs::neighbors::mg::deserialize_cagra(handle, filename); + * cuvs::neighbors::mg::serialize(clique, index, filename); + * auto new_index = cuvs::neighbors::mg::deserialize_cagra(clique, filename); * * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] filename path to the file to be deserialized * */ template -auto deserialize_cagra(const raft::device_resources& handle, const std::string& filename) +auto deserialize_cagra(const raft::device_resources_snmg& clique, const std::string& filename) -> index, T, IdxT>; /// \defgroup mg_cpp_distribute ANN MG local index distribution @@ -1292,21 +1292,21 @@ auto deserialize_cagra(const raft::device_resources& handle, const std::string& * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::ivf_flat::index_params index_params; - * auto index = cuvs::neighbors::ivf_flat::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); * const std::string filename = "local_index.cuvs"; - * cuvs::neighbors::ivf_flat::serialize(handle, filename, index); - * auto new_index = cuvs::neighbors::mg::distribute_flat(handle, filename); + * cuvs::neighbors::ivf_flat::serialize(clique, filename, index); + * auto new_index = cuvs::neighbors::mg::distribute_flat(clique, filename); * * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] filename path to the file to be deserialized : a local index * */ template -auto distribute_flat(const raft::device_resources& handle, const std::string& filename) +auto distribute_flat(const raft::device_resources_snmg& clique, const std::string& filename) -> index, T, IdxT>; /// \ingroup mg_cpp_distribute @@ -1316,20 +1316,20 @@ auto distribute_flat(const raft::device_resources& handle, const std::string& fi * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::ivf_pq::index_params index_params; - * auto index = cuvs::neighbors::ivf_pq::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); * const std::string filename = "local_index.cuvs"; - * cuvs::neighbors::ivf_pq::serialize(handle, filename, index); - * auto new_index = cuvs::neighbors::mg::distribute_pq(handle, filename); + * cuvs::neighbors::ivf_pq::serialize(clique, filename, index); + * auto new_index = cuvs::neighbors::mg::distribute_pq(clique, filename); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] filename path to the file to be deserialized : a local index * */ template -auto distribute_pq(const raft::device_resources& handle, const std::string& filename) +auto distribute_pq(const raft::device_resources_snmg& clique, const std::string& filename) -> index, T, IdxT>; /// \ingroup mg_cpp_distribute @@ -1339,21 +1339,21 @@ auto distribute_pq(const raft::device_resources& handle, const std::string& file * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::cagra::index_params index_params; - * auto index = cuvs::neighbors::cagra::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); * const std::string filename = "local_index.cuvs"; - * cuvs::neighbors::cagra::serialize(handle, filename, index); - * auto new_index = cuvs::neighbors::mg::distribute_cagra(handle, filename); + * cuvs::neighbors::cagra::serialize(clique, filename, index); + * auto new_index = cuvs::neighbors::mg::distribute_cagra(clique, filename); * * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] filename path to the file to be deserialized : a local index * */ template -auto distribute_cagra(const raft::device_resources& handle, const std::string& filename) +auto distribute_cagra(const raft::device_resources_snmg& clique, const std::string& filename) -> index, T, IdxT>; } // namespace cuvs::neighbors::mg diff --git a/cpp/src/neighbors/mg/generate_mg.py b/cpp/src/neighbors/mg/generate_mg.py index b9fdd0aa9..023f5baf3 100644 --- a/cpp/src/neighbors/mg/generate_mg.py +++ b/cpp/src/neighbors/mg/generate_mg.py @@ -53,27 +53,26 @@ flat_macro = """ #define CUVS_INST_MG_FLAT(T, IdxT) \\ - index, T, IdxT> build(const raft::device_resources& handle, \\ + index, T, IdxT> build(const raft::device_resources_snmg& clique, \\ const mg::index_params& index_params, \\ raft::host_matrix_view index_dataset) \\ { \\ - const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \\ index, T, IdxT> index(index_params.mode, clique.num_ranks_); \\ - cuvs::neighbors::mg::detail::build(handle, index, \\ + cuvs::neighbors::mg::detail::build(clique, index, \\ static_cast(&index_params), \\ index_dataset); \\ return index; \\ } \\ \\ - void extend(const raft::device_resources& handle, \\ + void extend(const raft::device_resources_snmg& clique, \\ index, T, IdxT>& index, \\ raft::host_matrix_view new_vectors, \\ std::optional> new_indices) \\ { \\ - cuvs::neighbors::mg::detail::extend(handle, index, new_vectors, new_indices); \\ + cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \\ } \\ \\ - void search(const raft::device_resources& handle, \\ + void search(const raft::device_resources_snmg& clique, \\ const index, T, IdxT>& index, \\ const mg::search_params& search_params, \\ raft::host_matrix_view queries, \\ @@ -81,60 +80,58 @@ raft::host_matrix_view distances, \\ int64_t n_rows_per_batch) \\ { \\ - cuvs::neighbors::mg::detail::search(handle, index, \\ + cuvs::neighbors::mg::detail::search(clique, index, \\ static_cast(&search_params), \\ queries, neighbors, distances, n_rows_per_batch); \\ } \\ \\ - void serialize(const raft::device_resources& handle, \\ + void serialize(const raft::device_resources_snmg& clique, \\ const index, T, IdxT>& index, \\ const std::string& filename) \\ { \\ - cuvs::neighbors::mg::detail::serialize(handle, index, filename); \\ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \\ } \\ \\ template<> \\ - index, T, IdxT> deserialize_flat(const raft::device_resources& handle, \\ + index, T, IdxT> deserialize_flat(const raft::device_resources_snmg& clique, \\ const std::string& filename) \\ { \\ - auto idx = index, T, IdxT>(handle, filename); \\ + auto idx = index, T, IdxT>(clique, filename); \\ return idx; \\ } \\ \\ template<> \\ - index, T, IdxT> distribute_flat(const raft::device_resources& handle, \\ + index, T, IdxT> distribute_flat(const raft::device_resources_snmg& clique, \\ const std::string& filename) \\ { \\ - const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \\ auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \\ - cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \\ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \\ return idx; \\ } """ pq_macro = """ #define CUVS_INST_MG_PQ(T, IdxT) \\ - index, T, IdxT> build(const raft::device_resources& handle, \\ + index, T, IdxT> build(const raft::device_resources_snmg& clique, \\ const mg::index_params& index_params, \\ raft::host_matrix_view index_dataset) \\ { \\ - const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \\ index, T, IdxT> index(index_params.mode, clique.num_ranks_); \\ - cuvs::neighbors::mg::detail::build(handle, index, \\ + cuvs::neighbors::mg::detail::build(clique, index, \\ static_cast(&index_params), \\ index_dataset); \\ return index; \\ } \\ \\ - void extend(const raft::device_resources& handle, \\ + void extend(const raft::device_resources_snmg& clique, \\ index, T, IdxT>& index, \\ raft::host_matrix_view new_vectors, \\ std::optional> new_indices) \\ { \\ - cuvs::neighbors::mg::detail::extend(handle, index, new_vectors, new_indices); \\ + cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \\ } \\ \\ - void search(const raft::device_resources& handle, \\ + void search(const raft::device_resources_snmg& clique, \\ const index, T, IdxT>& index, \\ const mg::search_params& search_params, \\ raft::host_matrix_view queries, \\ @@ -142,52 +139,50 @@ raft::host_matrix_view distances, \\ int64_t n_rows_per_batch) \\ { \\ - cuvs::neighbors::mg::detail::search(handle, index, \\ + cuvs::neighbors::mg::detail::search(clique, index, \\ static_cast(&search_params), \\ queries, neighbors, distances, n_rows_per_batch); \\ } \\ \\ - void serialize(const raft::device_resources& handle, \\ + void serialize(const raft::device_resources_snmg& clique, \\ const index, T, IdxT>& index, \\ const std::string& filename) \\ { \\ - cuvs::neighbors::mg::detail::serialize(handle, index, filename); \\ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \\ } \\ \\ template<> \\ - index, T, IdxT> deserialize_pq(const raft::device_resources& handle, \\ + index, T, IdxT> deserialize_pq(const raft::device_resources_snmg& clique, \\ const std::string& filename) \\ { \\ - auto idx = index, T, IdxT>(handle, filename); \\ + auto idx = index, T, IdxT>(clique, filename); \\ return idx; \\ } \\ \\ template<> \\ - index, T, IdxT> distribute_pq(const raft::device_resources& handle, \\ + index, T, IdxT> distribute_pq(const raft::device_resources_snmg& clique, \\ const std::string& filename) \\ { \\ - const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \\ auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \\ - cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \\ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \\ return idx; \\ } """ cagra_macro = """ #define CUVS_INST_MG_CAGRA(T, IdxT) \\ - index, T, IdxT> build(const raft::device_resources& handle, \\ + index, T, IdxT> build(const raft::device_resources_snmg& clique, \\ const mg::index_params& index_params, \\ raft::host_matrix_view index_dataset) \\ { \\ - const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \\ index, T, IdxT> index(index_params.mode, clique.num_ranks_); \\ - cuvs::neighbors::mg::detail::build(handle, index, \\ + cuvs::neighbors::mg::detail::build(clique, index, \\ static_cast(&index_params), \\ index_dataset); \\ return index; \\ } \\ \\ - void search(const raft::device_resources& handle, \\ + void search(const raft::device_resources_snmg& clique, \\ const index, T, IdxT>& index, \\ const mg::search_params& search_params, \\ raft::host_matrix_view queries, \\ @@ -195,33 +190,32 @@ raft::host_matrix_view distances, \\ int64_t n_rows_per_batch) \\ { \\ - cuvs::neighbors::mg::detail::search(handle, index, \\ + cuvs::neighbors::mg::detail::search(clique, index, \\ static_cast(&search_params), \\ queries, neighbors, distances, n_rows_per_batch); \\ } \\ \\ - void serialize(const raft::device_resources& handle, \\ + void serialize(const raft::device_resources_snmg& clique, \\ const index, T, IdxT>& index, \\ const std::string& filename) \\ { \\ - cuvs::neighbors::mg::detail::serialize(handle, index, filename); \\ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \\ } \\ \\ template<> \\ - index, T, IdxT> deserialize_cagra(const raft::device_resources& handle, \\ + index, T, IdxT> deserialize_cagra(const raft::device_resources_snmg& clique, \\ const std::string& filename) \\ { \\ - auto idx = index, T, IdxT>(handle, filename); \\ + auto idx = index, T, IdxT>(clique, filename); \\ return idx; \\ } \\ \\ template<> \\ - index, T, IdxT> distribute_cagra(const raft::device_resources& handle, \\ + index, T, IdxT> distribute_cagra(const raft::device_resources_snmg& clique, \\ const std::string& filename) \\ { \\ - const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \\ auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \\ - cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \\ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \\ return idx; \\ } """ diff --git a/cpp/src/neighbors/mg/mg.cuh b/cpp/src/neighbors/mg/mg.cuh index 3679e89da..0e113ef72 100644 --- a/cpp/src/neighbors/mg/mg.cuh +++ b/cpp/src/neighbors/mg/mg.cuh @@ -17,7 +17,6 @@ #pragma once #include "../detail/knn_merge_parts.cuh" -#include #include #include #include @@ -47,11 +46,10 @@ using namespace raft; // local index deserialization and distribution template -void deserialize_and_distribute(const raft::device_resources& handle, +void deserialize_and_distribute(const raft::device_resources_snmg& clique, index& index, const std::string& filename) { - const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); for (int rank = 0; rank < index.num_ranks_; rank++) { int dev_id = clique.device_ids_[rank]; const raft::device_resources& dev_res = clique.device_resources_[rank]; @@ -63,17 +61,16 @@ void deserialize_and_distribute(const raft::device_resources& handle, // MG index deserialization template -void deserialize(const raft::device_resources& handle, +void deserialize(const raft::device_resources_snmg& clique, index& index, const std::string& filename) { std::ifstream is(filename, std::ios::in | std::ios::binary); if (!is) { RAFT_FAIL("Cannot open file %s", filename.c_str()); } - const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); - - index.mode_ = (cuvs::neighbors::mg::distribution_mode)deserialize_scalar(handle, is); - index.num_ranks_ = deserialize_scalar(handle, is); + const auto& handle = clique.set_current_device_to_root_rank(); + index.mode_ = (cuvs::neighbors::mg::distribution_mode)deserialize_scalar(handle, is); + index.num_ranks_ = deserialize_scalar(handle, is); if (index.num_ranks_ != clique.num_ranks_) { RAFT_FAIL("Serialized index has %d ranks whereas NCCL clique has %d ranks", @@ -93,13 +90,11 @@ void deserialize(const raft::device_resources& handle, } template -void build(const raft::device_resources& handle, +void build(const raft::device_resources_snmg& clique, index& index, const cuvs::neighbors::index_params* index_params, raft::host_matrix_view index_dataset) { - const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); - if (index.mode_ == REPLICATED) { int64_t n_rows = index_dataset.extent(0); RAFT_LOG_INFO("REPLICATED BUILD: %d*%drows", index.num_ranks_, n_rows); @@ -140,13 +135,11 @@ void build(const raft::device_resources& handle, } template -void extend(const raft::device_resources& handle, +void extend(const raft::device_resources_snmg& clique, index& index, raft::host_matrix_view new_vectors, std::optional> new_indices) { - const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); - int64_t n_rows = new_vectors.extent(0); if (index.mode_ == REPLICATED) { RAFT_LOG_INFO("REPLICATED EXTEND: %d*%drows", index.num_ranks_, n_rows); @@ -191,7 +184,7 @@ void extend(const raft::device_resources& handle, } template -void sharded_search_with_direct_merge(const raft::core::nccl_clique& clique, +void sharded_search_with_direct_merge(const raft::device_resources_snmg& clique, const index& index, const cuvs::neighbors::search_params* search_params, raft::host_matrix_view queries, @@ -325,7 +318,7 @@ void sharded_search_with_direct_merge(const raft::core::nccl_clique& clique, } template -void sharded_search_with_tree_merge(const raft::core::nccl_clique& clique, +void sharded_search_with_tree_merge(const raft::device_resources_snmg& clique, const index& index, const cuvs::neighbors::search_params* search_params, raft::host_matrix_view queries, @@ -460,7 +453,7 @@ void sharded_search_with_tree_merge(const raft::core::nccl_clique& clique, } template -void run_search_batch(const raft::core::nccl_clique& clique, +void run_search_batch(const raft::device_resources_snmg& clique, const index& index, int rank, const cuvs::neighbors::search_params* search_params, @@ -501,7 +494,7 @@ void run_search_batch(const raft::core::nccl_clique& clique, } template -void search(const raft::device_resources& handle, +void search(const raft::device_resources_snmg& clique, const index& index, const cuvs::neighbors::search_params* search_params, raft::host_matrix_view queries, @@ -509,8 +502,6 @@ void search(const raft::device_resources& handle, raft::host_matrix_view distances, int64_t n_rows_per_batch) { - const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); - int64_t n_rows = queries.extent(0); int64_t n_cols = queries.extent(1); int64_t n_neighbors = neighbors.extent(1); @@ -642,15 +633,14 @@ void search(const raft::device_resources& handle, } template -void serialize(const raft::device_resources& handle, +void serialize(const raft::device_resources_snmg& clique, const index& index, const std::string& filename) { std::ofstream of(filename, std::ios::out | std::ios::binary); if (!of) { RAFT_FAIL("Cannot open file %s", filename.c_str()); } - const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); - + const auto& handle = clique.set_current_device_to_root_rank(); serialize_scalar(handle, of, (int)index.mode_); serialize_scalar(handle, of, index.num_ranks_); @@ -681,10 +671,10 @@ index::index(distribution_mode mode, int num_ranks_) } template -index::index(const raft::device_resources& handle, +index::index(const raft::device_resources_snmg& clique, const std::string& filename) : round_robin_counter_(std::make_shared>(0)) { - cuvs::neighbors::mg::detail::deserialize(handle, *this, filename); + cuvs::neighbors::mg::detail::deserialize(clique, *this, filename); } } // namespace cuvs::neighbors::mg diff --git a/cpp/src/neighbors/mg/mg_cagra_float_uint32_t.cu b/cpp/src/neighbors/mg/mg_cagra_float_uint32_t.cu index 0f6154395..c3ef3705e 100644 --- a/cpp/src/neighbors/mg/mg_cagra_float_uint32_t.cu +++ b/cpp/src/neighbors/mg/mg_cagra_float_uint32_t.cu @@ -29,21 +29,20 @@ namespace cuvs::neighbors::mg { #define CUVS_INST_MG_CAGRA(T, IdxT) \ index, T, IdxT> build( \ - const raft::device_resources& handle, \ + const raft::device_resources_snmg& clique, \ const mg::index_params& index_params, \ raft::host_matrix_view index_dataset) \ { \ - const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ cuvs::neighbors::mg::detail::build( \ - handle, \ + clique, \ index, \ static_cast(&index_params), \ index_dataset); \ return index; \ } \ \ - void search(const raft::device_resources& handle, \ + void search(const raft::device_resources_snmg& clique, \ const index, T, IdxT>& index, \ const mg::search_params& search_params, \ raft::host_matrix_view queries, \ @@ -52,7 +51,7 @@ namespace cuvs::neighbors::mg { int64_t n_rows_per_batch) \ { \ cuvs::neighbors::mg::detail::search( \ - handle, \ + clique, \ index, \ static_cast(&search_params), \ queries, \ @@ -61,28 +60,27 @@ namespace cuvs::neighbors::mg { n_rows_per_batch); \ } \ \ - void serialize(const raft::device_resources& handle, \ + void serialize(const raft::device_resources_snmg& clique, \ const index, T, IdxT>& index, \ const std::string& filename) \ { \ - cuvs::neighbors::mg::detail::serialize(handle, index, filename); \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ } \ \ template <> \ index, T, IdxT> deserialize_cagra( \ - const raft::device_resources& handle, const std::string& filename) \ + const raft::device_resources_snmg& clique, const std::string& filename) \ { \ - auto idx = index, T, IdxT>(handle, filename); \ + auto idx = index, T, IdxT>(clique, filename); \ return idx; \ } \ \ template <> \ index, T, IdxT> distribute_cagra( \ - const raft::device_resources& handle, const std::string& filename) \ + const raft::device_resources_snmg& clique, const std::string& filename) \ { \ - const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ return idx; \ } CUVS_INST_MG_CAGRA(float, uint32_t); diff --git a/cpp/src/neighbors/mg/mg_cagra_half_uint32_t.cu b/cpp/src/neighbors/mg/mg_cagra_half_uint32_t.cu index ad041155b..ea9ec672b 100644 --- a/cpp/src/neighbors/mg/mg_cagra_half_uint32_t.cu +++ b/cpp/src/neighbors/mg/mg_cagra_half_uint32_t.cu @@ -29,21 +29,20 @@ namespace cuvs::neighbors::mg { #define CUVS_INST_MG_CAGRA(T, IdxT) \ index, T, IdxT> build( \ - const raft::device_resources& handle, \ + const raft::device_resources_snmg& clique, \ const mg::index_params& index_params, \ raft::host_matrix_view index_dataset) \ { \ - const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ cuvs::neighbors::mg::detail::build( \ - handle, \ + clique, \ index, \ static_cast(&index_params), \ index_dataset); \ return index; \ } \ \ - void search(const raft::device_resources& handle, \ + void search(const raft::device_resources_snmg& clique, \ const index, T, IdxT>& index, \ const mg::search_params& search_params, \ raft::host_matrix_view queries, \ @@ -52,7 +51,7 @@ namespace cuvs::neighbors::mg { int64_t n_rows_per_batch) \ { \ cuvs::neighbors::mg::detail::search( \ - handle, \ + clique, \ index, \ static_cast(&search_params), \ queries, \ @@ -61,28 +60,27 @@ namespace cuvs::neighbors::mg { n_rows_per_batch); \ } \ \ - void serialize(const raft::device_resources& handle, \ + void serialize(const raft::device_resources_snmg& clique, \ const index, T, IdxT>& index, \ const std::string& filename) \ { \ - cuvs::neighbors::mg::detail::serialize(handle, index, filename); \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ } \ \ template <> \ index, T, IdxT> deserialize_cagra( \ - const raft::device_resources& handle, const std::string& filename) \ + const raft::device_resources_snmg& clique, const std::string& filename) \ { \ - auto idx = index, T, IdxT>(handle, filename); \ + auto idx = index, T, IdxT>(clique, filename); \ return idx; \ } \ \ template <> \ index, T, IdxT> distribute_cagra( \ - const raft::device_resources& handle, const std::string& filename) \ + const raft::device_resources_snmg& clique, const std::string& filename) \ { \ - const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ return idx; \ } CUVS_INST_MG_CAGRA(half, uint32_t); diff --git a/cpp/src/neighbors/mg/mg_cagra_int8_t_uint32_t.cu b/cpp/src/neighbors/mg/mg_cagra_int8_t_uint32_t.cu index 5099a0417..aeae0f2cc 100644 --- a/cpp/src/neighbors/mg/mg_cagra_int8_t_uint32_t.cu +++ b/cpp/src/neighbors/mg/mg_cagra_int8_t_uint32_t.cu @@ -29,21 +29,20 @@ namespace cuvs::neighbors::mg { #define CUVS_INST_MG_CAGRA(T, IdxT) \ index, T, IdxT> build( \ - const raft::device_resources& handle, \ + const raft::device_resources_snmg& clique, \ const mg::index_params& index_params, \ raft::host_matrix_view index_dataset) \ { \ - const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ cuvs::neighbors::mg::detail::build( \ - handle, \ + clique, \ index, \ static_cast(&index_params), \ index_dataset); \ return index; \ } \ \ - void search(const raft::device_resources& handle, \ + void search(const raft::device_resources_snmg& clique, \ const index, T, IdxT>& index, \ const mg::search_params& search_params, \ raft::host_matrix_view queries, \ @@ -52,7 +51,7 @@ namespace cuvs::neighbors::mg { int64_t n_rows_per_batch) \ { \ cuvs::neighbors::mg::detail::search( \ - handle, \ + clique, \ index, \ static_cast(&search_params), \ queries, \ @@ -61,28 +60,27 @@ namespace cuvs::neighbors::mg { n_rows_per_batch); \ } \ \ - void serialize(const raft::device_resources& handle, \ + void serialize(const raft::device_resources_snmg& clique, \ const index, T, IdxT>& index, \ const std::string& filename) \ { \ - cuvs::neighbors::mg::detail::serialize(handle, index, filename); \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ } \ \ template <> \ index, T, IdxT> deserialize_cagra( \ - const raft::device_resources& handle, const std::string& filename) \ + const raft::device_resources_snmg& clique, const std::string& filename) \ { \ - auto idx = index, T, IdxT>(handle, filename); \ + auto idx = index, T, IdxT>(clique, filename); \ return idx; \ } \ \ template <> \ index, T, IdxT> distribute_cagra( \ - const raft::device_resources& handle, const std::string& filename) \ + const raft::device_resources_snmg& clique, const std::string& filename) \ { \ - const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ return idx; \ } CUVS_INST_MG_CAGRA(int8_t, uint32_t); diff --git a/cpp/src/neighbors/mg/mg_cagra_uint8_t_uint32_t.cu b/cpp/src/neighbors/mg/mg_cagra_uint8_t_uint32_t.cu index df33190f9..22421d6f0 100644 --- a/cpp/src/neighbors/mg/mg_cagra_uint8_t_uint32_t.cu +++ b/cpp/src/neighbors/mg/mg_cagra_uint8_t_uint32_t.cu @@ -29,21 +29,20 @@ namespace cuvs::neighbors::mg { #define CUVS_INST_MG_CAGRA(T, IdxT) \ index, T, IdxT> build( \ - const raft::device_resources& handle, \ + const raft::device_resources_snmg& clique, \ const mg::index_params& index_params, \ raft::host_matrix_view index_dataset) \ { \ - const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ cuvs::neighbors::mg::detail::build( \ - handle, \ + clique, \ index, \ static_cast(&index_params), \ index_dataset); \ return index; \ } \ \ - void search(const raft::device_resources& handle, \ + void search(const raft::device_resources_snmg& clique, \ const index, T, IdxT>& index, \ const mg::search_params& search_params, \ raft::host_matrix_view queries, \ @@ -52,7 +51,7 @@ namespace cuvs::neighbors::mg { int64_t n_rows_per_batch) \ { \ cuvs::neighbors::mg::detail::search( \ - handle, \ + clique, \ index, \ static_cast(&search_params), \ queries, \ @@ -61,28 +60,27 @@ namespace cuvs::neighbors::mg { n_rows_per_batch); \ } \ \ - void serialize(const raft::device_resources& handle, \ + void serialize(const raft::device_resources_snmg& clique, \ const index, T, IdxT>& index, \ const std::string& filename) \ { \ - cuvs::neighbors::mg::detail::serialize(handle, index, filename); \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ } \ \ template <> \ index, T, IdxT> deserialize_cagra( \ - const raft::device_resources& handle, const std::string& filename) \ + const raft::device_resources_snmg& clique, const std::string& filename) \ { \ - auto idx = index, T, IdxT>(handle, filename); \ + auto idx = index, T, IdxT>(clique, filename); \ return idx; \ } \ \ template <> \ index, T, IdxT> distribute_cagra( \ - const raft::device_resources& handle, const std::string& filename) \ + const raft::device_resources_snmg& clique, const std::string& filename) \ { \ - const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ return idx; \ } CUVS_INST_MG_CAGRA(uint8_t, uint32_t); diff --git a/cpp/src/neighbors/mg/mg_flat_float_int64_t.cu b/cpp/src/neighbors/mg/mg_flat_float_int64_t.cu index 5fcf27301..423aa0284 100644 --- a/cpp/src/neighbors/mg/mg_flat_float_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_flat_float_int64_t.cu @@ -29,29 +29,28 @@ namespace cuvs::neighbors::mg { #define CUVS_INST_MG_FLAT(T, IdxT) \ index, T, IdxT> build( \ - const raft::device_resources& handle, \ + const raft::device_resources_snmg& clique, \ const mg::index_params& index_params, \ raft::host_matrix_view index_dataset) \ { \ - const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ cuvs::neighbors::mg::detail::build( \ - handle, \ + clique, \ index, \ static_cast(&index_params), \ index_dataset); \ return index; \ } \ \ - void extend(const raft::device_resources& handle, \ + void extend(const raft::device_resources_snmg& clique, \ index, T, IdxT>& index, \ raft::host_matrix_view new_vectors, \ std::optional> new_indices) \ { \ - cuvs::neighbors::mg::detail::extend(handle, index, new_vectors, new_indices); \ + cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ } \ \ - void search(const raft::device_resources& handle, \ + void search(const raft::device_resources_snmg& clique, \ const index, T, IdxT>& index, \ const mg::search_params& search_params, \ raft::host_matrix_view queries, \ @@ -60,7 +59,7 @@ namespace cuvs::neighbors::mg { int64_t n_rows_per_batch) \ { \ cuvs::neighbors::mg::detail::search( \ - handle, \ + clique, \ index, \ static_cast(&search_params), \ queries, \ @@ -69,28 +68,27 @@ namespace cuvs::neighbors::mg { n_rows_per_batch); \ } \ \ - void serialize(const raft::device_resources& handle, \ + void serialize(const raft::device_resources_snmg& clique, \ const index, T, IdxT>& index, \ const std::string& filename) \ { \ - cuvs::neighbors::mg::detail::serialize(handle, index, filename); \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ } \ \ template <> \ index, T, IdxT> deserialize_flat( \ - const raft::device_resources& handle, const std::string& filename) \ + const raft::device_resources_snmg& clique, const std::string& filename) \ { \ - auto idx = index, T, IdxT>(handle, filename); \ + auto idx = index, T, IdxT>(clique, filename); \ return idx; \ } \ \ template <> \ index, T, IdxT> distribute_flat( \ - const raft::device_resources& handle, const std::string& filename) \ + const raft::device_resources_snmg& clique, const std::string& filename) \ { \ - const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ return idx; \ } CUVS_INST_MG_FLAT(float, int64_t); diff --git a/cpp/src/neighbors/mg/mg_flat_int8_t_int64_t.cu b/cpp/src/neighbors/mg/mg_flat_int8_t_int64_t.cu index 75128ea0c..06bb7af26 100644 --- a/cpp/src/neighbors/mg/mg_flat_int8_t_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_flat_int8_t_int64_t.cu @@ -29,29 +29,28 @@ namespace cuvs::neighbors::mg { #define CUVS_INST_MG_FLAT(T, IdxT) \ index, T, IdxT> build( \ - const raft::device_resources& handle, \ + const raft::device_resources_snmg& clique, \ const mg::index_params& index_params, \ raft::host_matrix_view index_dataset) \ { \ - const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ cuvs::neighbors::mg::detail::build( \ - handle, \ + clique, \ index, \ static_cast(&index_params), \ index_dataset); \ return index; \ } \ \ - void extend(const raft::device_resources& handle, \ + void extend(const raft::device_resources_snmg& clique, \ index, T, IdxT>& index, \ raft::host_matrix_view new_vectors, \ std::optional> new_indices) \ { \ - cuvs::neighbors::mg::detail::extend(handle, index, new_vectors, new_indices); \ + cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ } \ \ - void search(const raft::device_resources& handle, \ + void search(const raft::device_resources_snmg& clique, \ const index, T, IdxT>& index, \ const mg::search_params& search_params, \ raft::host_matrix_view queries, \ @@ -60,7 +59,7 @@ namespace cuvs::neighbors::mg { int64_t n_rows_per_batch) \ { \ cuvs::neighbors::mg::detail::search( \ - handle, \ + clique, \ index, \ static_cast(&search_params), \ queries, \ @@ -69,28 +68,27 @@ namespace cuvs::neighbors::mg { n_rows_per_batch); \ } \ \ - void serialize(const raft::device_resources& handle, \ + void serialize(const raft::device_resources_snmg& clique, \ const index, T, IdxT>& index, \ const std::string& filename) \ { \ - cuvs::neighbors::mg::detail::serialize(handle, index, filename); \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ } \ \ template <> \ index, T, IdxT> deserialize_flat( \ - const raft::device_resources& handle, const std::string& filename) \ + const raft::device_resources_snmg& clique, const std::string& filename) \ { \ - auto idx = index, T, IdxT>(handle, filename); \ + auto idx = index, T, IdxT>(clique, filename); \ return idx; \ } \ \ template <> \ index, T, IdxT> distribute_flat( \ - const raft::device_resources& handle, const std::string& filename) \ + const raft::device_resources_snmg& clique, const std::string& filename) \ { \ - const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ return idx; \ } CUVS_INST_MG_FLAT(int8_t, int64_t); diff --git a/cpp/src/neighbors/mg/mg_flat_uint8_t_int64_t.cu b/cpp/src/neighbors/mg/mg_flat_uint8_t_int64_t.cu index cafdbdcbb..bbf7d96f8 100644 --- a/cpp/src/neighbors/mg/mg_flat_uint8_t_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_flat_uint8_t_int64_t.cu @@ -29,29 +29,28 @@ namespace cuvs::neighbors::mg { #define CUVS_INST_MG_FLAT(T, IdxT) \ index, T, IdxT> build( \ - const raft::device_resources& handle, \ + const raft::device_resources_snmg& clique, \ const mg::index_params& index_params, \ raft::host_matrix_view index_dataset) \ { \ - const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ cuvs::neighbors::mg::detail::build( \ - handle, \ + clique, \ index, \ static_cast(&index_params), \ index_dataset); \ return index; \ } \ \ - void extend(const raft::device_resources& handle, \ + void extend(const raft::device_resources_snmg& clique, \ index, T, IdxT>& index, \ raft::host_matrix_view new_vectors, \ std::optional> new_indices) \ { \ - cuvs::neighbors::mg::detail::extend(handle, index, new_vectors, new_indices); \ + cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ } \ \ - void search(const raft::device_resources& handle, \ + void search(const raft::device_resources_snmg& clique, \ const index, T, IdxT>& index, \ const mg::search_params& search_params, \ raft::host_matrix_view queries, \ @@ -60,7 +59,7 @@ namespace cuvs::neighbors::mg { int64_t n_rows_per_batch) \ { \ cuvs::neighbors::mg::detail::search( \ - handle, \ + clique, \ index, \ static_cast(&search_params), \ queries, \ @@ -69,28 +68,27 @@ namespace cuvs::neighbors::mg { n_rows_per_batch); \ } \ \ - void serialize(const raft::device_resources& handle, \ + void serialize(const raft::device_resources_snmg& clique, \ const index, T, IdxT>& index, \ const std::string& filename) \ { \ - cuvs::neighbors::mg::detail::serialize(handle, index, filename); \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ } \ \ template <> \ index, T, IdxT> deserialize_flat( \ - const raft::device_resources& handle, const std::string& filename) \ + const raft::device_resources_snmg& clique, const std::string& filename) \ { \ - auto idx = index, T, IdxT>(handle, filename); \ + auto idx = index, T, IdxT>(clique, filename); \ return idx; \ } \ \ template <> \ index, T, IdxT> distribute_flat( \ - const raft::device_resources& handle, const std::string& filename) \ + const raft::device_resources_snmg& clique, const std::string& filename) \ { \ - const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ return idx; \ } CUVS_INST_MG_FLAT(uint8_t, int64_t); diff --git a/cpp/src/neighbors/mg/mg_pq_float_int64_t.cu b/cpp/src/neighbors/mg/mg_pq_float_int64_t.cu index 8e7a71014..441a09e2f 100644 --- a/cpp/src/neighbors/mg/mg_pq_float_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_pq_float_int64_t.cu @@ -27,71 +27,69 @@ namespace cuvs::neighbors::mg { -#define CUVS_INST_MG_PQ(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources& handle, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ - index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::build( \ - handle, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void extend(const raft::device_resources& handle, \ - index, T, IdxT>& index, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices) \ - { \ - cuvs::neighbors::mg::detail::extend(handle, index, new_vectors, new_indices); \ - } \ - \ - void search(const raft::device_resources& handle, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - handle, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources& handle, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(handle, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_pq( \ - const raft::device_resources& handle, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(handle, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_pq(const raft::device_resources& handle, \ - const std::string& filename) \ - { \ - const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ - auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \ - return idx; \ +#define CUVS_INST_MG_PQ(T, IdxT) \ + index, T, IdxT> build( \ + const raft::device_resources_snmg& clique, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ + cuvs::neighbors::mg::detail::build( \ + clique, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void extend(const raft::device_resources_snmg& clique, \ + index, T, IdxT>& index, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices) \ + { \ + cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ + } \ + \ + void search(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances, \ + int64_t n_rows_per_batch) \ + { \ + cuvs::neighbors::mg::detail::search( \ + clique, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances, \ + n_rows_per_batch); \ + } \ + \ + void serialize(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ + } \ + \ + template <> \ + index, T, IdxT> deserialize_pq( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(clique, filename); \ + return idx; \ + } \ + \ + template <> \ + index, T, IdxT> distribute_pq( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ + return idx; \ } CUVS_INST_MG_PQ(float, int64_t); diff --git a/cpp/src/neighbors/mg/mg_pq_half_int64_t.cu b/cpp/src/neighbors/mg/mg_pq_half_int64_t.cu index 7fe19a899..bf6126fee 100644 --- a/cpp/src/neighbors/mg/mg_pq_half_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_pq_half_int64_t.cu @@ -27,71 +27,69 @@ namespace cuvs::neighbors::mg { -#define CUVS_INST_MG_PQ(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources& handle, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ - index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::build( \ - handle, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void extend(const raft::device_resources& handle, \ - index, T, IdxT>& index, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices) \ - { \ - cuvs::neighbors::mg::detail::extend(handle, index, new_vectors, new_indices); \ - } \ - \ - void search(const raft::device_resources& handle, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - handle, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources& handle, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(handle, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_pq( \ - const raft::device_resources& handle, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(handle, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_pq(const raft::device_resources& handle, \ - const std::string& filename) \ - { \ - const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ - auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \ - return idx; \ +#define CUVS_INST_MG_PQ(T, IdxT) \ + index, T, IdxT> build( \ + const raft::device_resources_snmg& clique, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ + cuvs::neighbors::mg::detail::build( \ + clique, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void extend(const raft::device_resources_snmg& clique, \ + index, T, IdxT>& index, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices) \ + { \ + cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ + } \ + \ + void search(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances, \ + int64_t n_rows_per_batch) \ + { \ + cuvs::neighbors::mg::detail::search( \ + clique, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances, \ + n_rows_per_batch); \ + } \ + \ + void serialize(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ + } \ + \ + template <> \ + index, T, IdxT> deserialize_pq( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(clique, filename); \ + return idx; \ + } \ + \ + template <> \ + index, T, IdxT> distribute_pq( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ + return idx; \ } CUVS_INST_MG_PQ(half, int64_t); diff --git a/cpp/src/neighbors/mg/mg_pq_int8_t_int64_t.cu b/cpp/src/neighbors/mg/mg_pq_int8_t_int64_t.cu index ca56c90ed..3921f810c 100644 --- a/cpp/src/neighbors/mg/mg_pq_int8_t_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_pq_int8_t_int64_t.cu @@ -27,71 +27,69 @@ namespace cuvs::neighbors::mg { -#define CUVS_INST_MG_PQ(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources& handle, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ - index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::build( \ - handle, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void extend(const raft::device_resources& handle, \ - index, T, IdxT>& index, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices) \ - { \ - cuvs::neighbors::mg::detail::extend(handle, index, new_vectors, new_indices); \ - } \ - \ - void search(const raft::device_resources& handle, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - handle, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources& handle, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(handle, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_pq( \ - const raft::device_resources& handle, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(handle, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_pq(const raft::device_resources& handle, \ - const std::string& filename) \ - { \ - const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ - auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \ - return idx; \ +#define CUVS_INST_MG_PQ(T, IdxT) \ + index, T, IdxT> build( \ + const raft::device_resources_snmg& clique, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ + cuvs::neighbors::mg::detail::build( \ + clique, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void extend(const raft::device_resources_snmg& clique, \ + index, T, IdxT>& index, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices) \ + { \ + cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ + } \ + \ + void search(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances, \ + int64_t n_rows_per_batch) \ + { \ + cuvs::neighbors::mg::detail::search( \ + clique, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances, \ + n_rows_per_batch); \ + } \ + \ + void serialize(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ + } \ + \ + template <> \ + index, T, IdxT> deserialize_pq( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(clique, filename); \ + return idx; \ + } \ + \ + template <> \ + index, T, IdxT> distribute_pq( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ + return idx; \ } CUVS_INST_MG_PQ(int8_t, int64_t); diff --git a/cpp/src/neighbors/mg/mg_pq_uint8_t_int64_t.cu b/cpp/src/neighbors/mg/mg_pq_uint8_t_int64_t.cu index ed40fa230..8f4683fd7 100644 --- a/cpp/src/neighbors/mg/mg_pq_uint8_t_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_pq_uint8_t_int64_t.cu @@ -27,71 +27,69 @@ namespace cuvs::neighbors::mg { -#define CUVS_INST_MG_PQ(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources& handle, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ - index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::build( \ - handle, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void extend(const raft::device_resources& handle, \ - index, T, IdxT>& index, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices) \ - { \ - cuvs::neighbors::mg::detail::extend(handle, index, new_vectors, new_indices); \ - } \ - \ - void search(const raft::device_resources& handle, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - handle, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources& handle, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(handle, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_pq( \ - const raft::device_resources& handle, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(handle, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_pq(const raft::device_resources& handle, \ - const std::string& filename) \ - { \ - const raft::core::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ - auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \ - return idx; \ +#define CUVS_INST_MG_PQ(T, IdxT) \ + index, T, IdxT> build( \ + const raft::device_resources_snmg& clique, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ + cuvs::neighbors::mg::detail::build( \ + clique, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void extend(const raft::device_resources_snmg& clique, \ + index, T, IdxT>& index, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices) \ + { \ + cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ + } \ + \ + void search(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances, \ + int64_t n_rows_per_batch) \ + { \ + cuvs::neighbors::mg::detail::search( \ + clique, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances, \ + n_rows_per_batch); \ + } \ + \ + void serialize(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ + } \ + \ + template <> \ + index, T, IdxT> deserialize_pq( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(clique, filename); \ + return idx; \ + } \ + \ + template <> \ + index, T, IdxT> distribute_pq( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ + return idx; \ } CUVS_INST_MG_PQ(uint8_t, int64_t); diff --git a/cpp/test/neighbors/mg.cuh b/cpp/test/neighbors/mg.cuh index 6b98e975a..eb97b583c 100644 --- a/cpp/test/neighbors/mg.cuh +++ b/cpp/test/neighbors/mg.cuh @@ -20,7 +20,7 @@ #include "naive_knn.cuh" #include -#include +#include namespace cuvs::neighbors::mg { @@ -47,7 +47,7 @@ class AnnMGTest : public ::testing::TestWithParam { public: AnnMGTest() : stream_(resource::get_cuda_stream(handle_)), - clique_(raft::resource::get_nccl_clique(handle_)), + clique_(), ps(::testing::TestWithParam::GetParam()), d_index_dataset(0, stream_), d_queries(0, stream_), @@ -69,7 +69,7 @@ class AnnMGTest : public ::testing::TestWithParam { { rmm::device_uvector distances_ref_dev(queries_size, stream_); rmm::device_uvector neighbors_ref_dev(queries_size, stream_); - cuvs::neighbors::naive_knn(handle_, + cuvs::neighbors::naive_knn(clique_, distances_ref_dev.data(), neighbors_ref_dev.data(), d_queries.data(), @@ -118,19 +118,19 @@ class AnnMGTest : public ::testing::TestWithParam { distances_snmg_ann.data(), ps.num_queries, ps.k); { - auto index = cuvs::neighbors::mg::build(handle_, index_params, index_dataset); - cuvs::neighbors::mg::extend(handle_, index, index_dataset, std::nullopt); - cuvs::neighbors::mg::serialize(handle_, index, "mg_ivf_flat_index"); + auto index = cuvs::neighbors::mg::build(clique_, index_params, index_dataset); + cuvs::neighbors::mg::extend(clique_, index, index_dataset, std::nullopt); + cuvs::neighbors::mg::serialize(clique_, index, "mg_ivf_flat_index"); } auto new_index = - cuvs::neighbors::mg::deserialize_flat(handle_, "mg_ivf_flat_index"); + cuvs::neighbors::mg::deserialize_flat(clique_, "mg_ivf_flat_index"); if (ps.m_mode == m_mode_t::MERGE_ON_ROOT_RANK) search_params.merge_mode = MERGE_ON_ROOT_RANK; else search_params.merge_mode = TREE_MERGE; cuvs::neighbors::mg::search( - handle_, new_index, search_params, queries, neighbors, distances, n_rows_per_search_batch); + clique_, new_index, search_params, queries, neighbors, distances, n_rows_per_search_batch); resource::sync_stream(handle_); double min_recall = static_cast(ps.nprobe) / static_cast(ps.nlist); @@ -177,19 +177,19 @@ class AnnMGTest : public ::testing::TestWithParam { distances_snmg_ann.data(), ps.num_queries, ps.k); { - auto index = cuvs::neighbors::mg::build(handle_, index_params, index_dataset); - cuvs::neighbors::mg::extend(handle_, index, index_dataset, std::nullopt); - cuvs::neighbors::mg::serialize(handle_, index, "mg_ivf_pq_index"); + auto index = cuvs::neighbors::mg::build(clique_, index_params, index_dataset); + cuvs::neighbors::mg::extend(clique_, index, index_dataset, std::nullopt); + cuvs::neighbors::mg::serialize(clique_, index, "mg_ivf_pq_index"); } auto new_index = - cuvs::neighbors::mg::deserialize_pq(handle_, "mg_ivf_pq_index"); + cuvs::neighbors::mg::deserialize_pq(clique_, "mg_ivf_pq_index"); if (ps.m_mode == m_mode_t::MERGE_ON_ROOT_RANK) search_params.merge_mode = MERGE_ON_ROOT_RANK; else search_params.merge_mode = TREE_MERGE; cuvs::neighbors::mg::search( - handle_, new_index, search_params, queries, neighbors, distances, n_rows_per_search_batch); + clique_, new_index, search_params, queries, neighbors, distances, n_rows_per_search_batch); resource::sync_stream(handle_); double min_recall = static_cast(ps.nprobe) / static_cast(ps.nlist); @@ -231,18 +231,18 @@ class AnnMGTest : public ::testing::TestWithParam { distances_snmg_ann.data(), ps.num_queries, ps.k); { - auto index = cuvs::neighbors::mg::build(handle_, index_params, index_dataset); - cuvs::neighbors::mg::serialize(handle_, index, "mg_cagra_index"); + auto index = cuvs::neighbors::mg::build(clique_, index_params, index_dataset); + cuvs::neighbors::mg::serialize(clique_, index, "mg_cagra_index"); } auto new_index = - cuvs::neighbors::mg::deserialize_cagra(handle_, "mg_cagra_index"); + cuvs::neighbors::mg::deserialize_cagra(clique_, "mg_cagra_index"); if (ps.m_mode == m_mode_t::MERGE_ON_ROOT_RANK) search_params.merge_mode = MERGE_ON_ROOT_RANK; else search_params.merge_mode = TREE_MERGE; cuvs::neighbors::mg::search( - handle_, new_index, search_params, queries, neighbors, distances, n_rows_per_search_batch); + clique_, new_index, search_params, queries, neighbors, distances, n_rows_per_search_batch); resource::sync_stream(handle_); double min_recall = static_cast(ps.nprobe) / static_cast(ps.nlist); @@ -274,8 +274,8 @@ class AnnMGTest : public ::testing::TestWithParam { { auto index_dataset = raft::make_device_matrix_view( d_index_dataset.data(), ps.num_db_vecs, ps.dim); - auto index = cuvs::neighbors::ivf_flat::build(handle_, index_params, index_dataset); - ivf_flat::serialize(handle_, "local_ivf_flat_index", index); + auto index = cuvs::neighbors::ivf_flat::build(clique_, index_params, index_dataset); + ivf_flat::serialize(clique_, "local_ivf_flat_index", index); } auto queries = raft::make_host_matrix_view( @@ -286,9 +286,9 @@ class AnnMGTest : public ::testing::TestWithParam { distances_snmg_ann.data(), ps.num_queries, ps.k); auto distributed_index = - cuvs::neighbors::mg::distribute_flat(handle_, "local_ivf_flat_index"); + cuvs::neighbors::mg::distribute_flat(clique_, "local_ivf_flat_index"); search_params.merge_mode = TREE_MERGE; - cuvs::neighbors::mg::search(handle_, + cuvs::neighbors::mg::search(clique_, distributed_index, search_params, queries, @@ -326,8 +326,8 @@ class AnnMGTest : public ::testing::TestWithParam { { auto index_dataset = raft::make_device_matrix_view( d_index_dataset.data(), ps.num_db_vecs, ps.dim); - auto index = cuvs::neighbors::ivf_pq::build(handle_, index_params, index_dataset); - ivf_pq::serialize(handle_, "local_ivf_pq_index", index); + auto index = cuvs::neighbors::ivf_pq::build(clique_, index_params, index_dataset); + ivf_pq::serialize(clique_, "local_ivf_pq_index", index); } auto queries = raft::make_host_matrix_view( @@ -338,9 +338,9 @@ class AnnMGTest : public ::testing::TestWithParam { distances_snmg_ann.data(), ps.num_queries, ps.k); auto distributed_index = - cuvs::neighbors::mg::distribute_pq(handle_, "local_ivf_pq_index"); + cuvs::neighbors::mg::distribute_pq(clique_, "local_ivf_pq_index"); search_params.merge_mode = TREE_MERGE; - cuvs::neighbors::mg::search(handle_, + cuvs::neighbors::mg::search(clique_, distributed_index, search_params, queries, @@ -373,8 +373,8 @@ class AnnMGTest : public ::testing::TestWithParam { { auto index_dataset = raft::make_device_matrix_view( d_index_dataset.data(), ps.num_db_vecs, ps.dim); - auto index = cuvs::neighbors::cagra::build(handle_, index_params, index_dataset); - cuvs::neighbors::cagra::serialize(handle_, "local_cagra_index", index); + auto index = cuvs::neighbors::cagra::build(clique_, index_params, index_dataset); + cuvs::neighbors::cagra::serialize(clique_, "local_cagra_index", index); } auto queries = raft::make_host_matrix_view( @@ -385,10 +385,10 @@ class AnnMGTest : public ::testing::TestWithParam { distances_snmg_ann.data(), ps.num_queries, ps.k); auto distributed_index = - cuvs::neighbors::mg::distribute_cagra(handle_, "local_cagra_index"); + cuvs::neighbors::mg::distribute_cagra(clique_, "local_cagra_index"); search_params.merge_mode = TREE_MERGE; - cuvs::neighbors::mg::search(handle_, + cuvs::neighbors::mg::search(clique_, distributed_index, search_params, queries, @@ -432,8 +432,8 @@ class AnnMGTest : public ::testing::TestWithParam { auto small_batch_query = raft::make_host_matrix_view( h_queries.data(), ps.num_queries, ps.dim); - auto index = cuvs::neighbors::mg::build(handle_, index_params, index_dataset); - cuvs::neighbors::mg::extend(handle_, index, index_dataset, std::nullopt); + auto index = cuvs::neighbors::mg::build(clique_, index_params, index_dataset); + cuvs::neighbors::mg::extend(clique_, index, index_dataset, std::nullopt); int n_parallel_searches = 16; std::vector searches_correctness(n_parallel_searches); @@ -448,7 +448,7 @@ class AnnMGTest : public ::testing::TestWithParam { load_balancer_neighbors_snmg_ann.data() + offset, ps.num_queries, ps.k); auto small_batch_distances = raft::make_host_matrix_view( load_balancer_distances_snmg_ann.data() + offset, ps.num_queries, ps.k); - cuvs::neighbors::mg::search(handle_, + cuvs::neighbors::mg::search(clique_, index, search_params, small_batch_query, @@ -496,8 +496,8 @@ class AnnMGTest : public ::testing::TestWithParam { auto small_batch_query = raft::make_host_matrix_view( h_queries.data(), ps.num_queries, ps.dim); - auto index = cuvs::neighbors::mg::build(handle_, index_params, index_dataset); - cuvs::neighbors::mg::extend(handle_, index, index_dataset, std::nullopt); + auto index = cuvs::neighbors::mg::build(clique_, index_params, index_dataset); + cuvs::neighbors::mg::extend(clique_, index, index_dataset, std::nullopt); int n_parallel_searches = 16; std::vector searches_correctness(n_parallel_searches); @@ -512,7 +512,7 @@ class AnnMGTest : public ::testing::TestWithParam { load_balancer_neighbors_snmg_ann.data() + offset, ps.num_queries, ps.k); auto small_batch_distances = raft::make_host_matrix_view( load_balancer_distances_snmg_ann.data() + offset, ps.num_queries, ps.k); - cuvs::neighbors::mg::search(handle_, + cuvs::neighbors::mg::search(clique_, index, search_params, small_batch_query, @@ -556,7 +556,7 @@ class AnnMGTest : public ::testing::TestWithParam { auto small_batch_query = raft::make_host_matrix_view( h_queries.data(), ps.num_queries, ps.dim); - auto index = cuvs::neighbors::mg::build(handle_, index_params, index_dataset); + auto index = cuvs::neighbors::mg::build(clique_, index_params, index_dataset); int n_parallel_searches = 16; std::vector searches_correctness(n_parallel_searches); @@ -571,7 +571,7 @@ class AnnMGTest : public ::testing::TestWithParam { load_balancer_neighbors_snmg_ann.data() + offset, ps.num_queries, ps.k); auto small_batch_distances = raft::make_host_matrix_view( load_balancer_distances_snmg_ann.data() + offset, ps.num_queries, ps.k); - cuvs::neighbors::mg::search(handle_, + cuvs::neighbors::mg::search(clique_, index, search_params, small_batch_query, @@ -610,12 +610,12 @@ class AnnMGTest : public ::testing::TestWithParam { raft::random::RngState r(1234ULL); if constexpr (std::is_same{}) { raft::random::uniform( - handle_, r, d_index_dataset.data(), d_index_dataset.size(), DataT(0.1), DataT(2.0)); - raft::random::uniform(handle_, r, d_queries.data(), d_queries.size(), DataT(0.1), DataT(2.0)); + clique_, r, d_index_dataset.data(), d_index_dataset.size(), DataT(0.1), DataT(2.0)); + raft::random::uniform(clique_, r, d_queries.data(), d_queries.size(), DataT(0.1), DataT(2.0)); } else { raft::random::uniformInt( - handle_, r, d_index_dataset.data(), d_index_dataset.size(), DataT(1), DataT(20)); - raft::random::uniformInt(handle_, r, d_queries.data(), d_queries.size(), DataT(1), DataT(20)); + clique_, r, d_index_dataset.data(), d_index_dataset.size(), DataT(1), DataT(20)); + raft::random::uniformInt(clique_, r, d_queries.data(), d_queries.size(), DataT(1), DataT(20)); } raft::copy(h_index_dataset.data(), @@ -632,7 +632,7 @@ class AnnMGTest : public ::testing::TestWithParam { private: raft::device_resources handle_; rmm::cuda_stream_view stream_; - raft::core::nccl_clique clique_; + raft::device_resources_snmg clique_; AnnMGInputs ps; std::vector h_index_dataset; std::vector h_queries; From e16b68e7bddf1738d4189c3b5586520deb341ecb Mon Sep 17 00:00:00 2001 From: viclafargue Date: Fri, 15 Nov 2024 15:25:27 +0100 Subject: [PATCH 3/8] improved device_resources_snmg --- .../ann/src/cuvs/cuvs_mg_cagra_wrapper.h | 2 + .../ann/src/cuvs/cuvs_mg_ivf_flat_wrapper.h | 2 + .../ann/src/cuvs/cuvs_mg_ivf_pq_wrapper.h | 2 + cpp/src/neighbors/mg/generate_mg.py | 12 +- cpp/src/neighbors/mg/mg.cuh | 96 ++++++------- .../neighbors/mg/mg_cagra_float_uint32_t.cu | 110 +++++++-------- .../neighbors/mg/mg_cagra_half_uint32_t.cu | 110 +++++++-------- .../neighbors/mg/mg_cagra_int8_t_uint32_t.cu | 110 +++++++-------- .../neighbors/mg/mg_cagra_uint8_t_uint32_t.cu | 110 +++++++-------- cpp/src/neighbors/mg/mg_flat_float_int64_t.cu | 126 +++++++++--------- .../neighbors/mg/mg_flat_int8_t_int64_t.cu | 126 +++++++++--------- .../neighbors/mg/mg_flat_uint8_t_int64_t.cu | 126 +++++++++--------- cpp/src/neighbors/mg/mg_pq_float_int64_t.cu | 126 +++++++++--------- cpp/src/neighbors/mg/mg_pq_half_int64_t.cu | 126 +++++++++--------- cpp/src/neighbors/mg/mg_pq_int8_t_int64_t.cu | 126 +++++++++--------- cpp/src/neighbors/mg/mg_pq_uint8_t_int64_t.cu | 126 +++++++++--------- cpp/test/neighbors/mg.cuh | 1 + 17 files changed, 712 insertions(+), 725 deletions(-) diff --git a/cpp/bench/ann/src/cuvs/cuvs_mg_cagra_wrapper.h b/cpp/bench/ann/src/cuvs/cuvs_mg_cagra_wrapper.h index f5a394482..27a0fd7ac 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_mg_cagra_wrapper.h +++ b/cpp/bench/ann/src/cuvs/cuvs_mg_cagra_wrapper.h @@ -45,6 +45,8 @@ class cuvs_mg_cagra : public algo, public algo_gpu { { index_params_.cagra_params.metric = parse_metric_type(metric); index_params_.ivf_pq_build_params->metric = parse_metric_type(metric); + + clique_.set_memory_pool(80); } void build(const T* dataset, size_t nrow) final; diff --git a/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_flat_wrapper.h b/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_flat_wrapper.h index 05e68b26b..5e811da33 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_flat_wrapper.h +++ b/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_flat_wrapper.h @@ -40,6 +40,8 @@ class cuvs_mg_ivf_flat : public algo, public algo_gpu { : algo(metric, dim), index_params_(param), clique_() { index_params_.metric = parse_metric_type(metric); + + clique_.set_memory_pool(80); } void build(const T* dataset, size_t nrow) final; diff --git a/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_pq_wrapper.h b/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_pq_wrapper.h index d430d27bf..c4a820cad 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_pq_wrapper.h +++ b/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_pq_wrapper.h @@ -40,6 +40,8 @@ class cuvs_mg_ivf_pq : public algo, public algo_gpu { : algo(metric, dim), index_params_(param), clique_() { index_params_.metric = parse_metric_type(metric); + + clique_.set_memory_pool(80); } void build(const T* dataset, size_t nrow) final; diff --git a/cpp/src/neighbors/mg/generate_mg.py b/cpp/src/neighbors/mg/generate_mg.py index 023f5baf3..26e81da16 100644 --- a/cpp/src/neighbors/mg/generate_mg.py +++ b/cpp/src/neighbors/mg/generate_mg.py @@ -57,7 +57,7 @@ const mg::index_params& index_params, \\ raft::host_matrix_view index_dataset) \\ { \\ - index, T, IdxT> index(index_params.mode, clique.num_ranks_); \\ + index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \\ cuvs::neighbors::mg::detail::build(clique, index, \\ static_cast(&index_params), \\ index_dataset); \\ @@ -104,7 +104,7 @@ index, T, IdxT> distribute_flat(const raft::device_resources_snmg& clique, \\ const std::string& filename) \\ { \\ - auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \\ + auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \\ cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \\ return idx; \\ } @@ -116,7 +116,7 @@ const mg::index_params& index_params, \\ raft::host_matrix_view index_dataset) \\ { \\ - index, T, IdxT> index(index_params.mode, clique.num_ranks_); \\ + index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \\ cuvs::neighbors::mg::detail::build(clique, index, \\ static_cast(&index_params), \\ index_dataset); \\ @@ -163,7 +163,7 @@ index, T, IdxT> distribute_pq(const raft::device_resources_snmg& clique, \\ const std::string& filename) \\ { \\ - auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \\ + auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \\ cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \\ return idx; \\ } @@ -175,7 +175,7 @@ const mg::index_params& index_params, \\ raft::host_matrix_view index_dataset) \\ { \\ - index, T, IdxT> index(index_params.mode, clique.num_ranks_); \\ + index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \\ cuvs::neighbors::mg::detail::build(clique, index, \\ static_cast(&index_params), \\ index_dataset); \\ @@ -214,7 +214,7 @@ index, T, IdxT> distribute_cagra(const raft::device_resources_snmg& clique, \\ const std::string& filename) \\ { \\ - auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \\ + auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \\ cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \\ return idx; \\ } diff --git a/cpp/src/neighbors/mg/mg.cuh b/cpp/src/neighbors/mg/mg.cuh index 0e113ef72..c6812b1e1 100644 --- a/cpp/src/neighbors/mg/mg.cuh +++ b/cpp/src/neighbors/mg/mg.cuh @@ -51,10 +51,8 @@ void deserialize_and_distribute(const raft::device_resources_snmg& clique, const std::string& filename) { for (int rank = 0; rank < index.num_ranks_; rank++) { - int dev_id = clique.device_ids_[rank]; - const raft::device_resources& dev_res = clique.device_resources_[rank]; - RAFT_CUDA_TRY(cudaSetDevice(dev_id)); - auto& ann_if = index.ann_interfaces_.emplace_back(); + const raft::device_resources& dev_res = clique.set_current_device_to_rank(rank); + auto& ann_if = index.ann_interfaces_.emplace_back(); cuvs::neighbors::deserialize(dev_res, ann_if, filename); } } @@ -72,17 +70,15 @@ void deserialize(const raft::device_resources_snmg& clique, index.mode_ = (cuvs::neighbors::mg::distribution_mode)deserialize_scalar(handle, is); index.num_ranks_ = deserialize_scalar(handle, is); - if (index.num_ranks_ != clique.num_ranks_) { + if (index.num_ranks_ != clique.get_num_ranks()) { RAFT_FAIL("Serialized index has %d ranks whereas NCCL clique has %d ranks", index.num_ranks_, - clique.num_ranks_); + clique.get_num_ranks()); } for (int rank = 0; rank < index.num_ranks_; rank++) { - int dev_id = clique.device_ids_[rank]; - const raft::device_resources& dev_res = clique.device_resources_[rank]; - RAFT_CUDA_TRY(cudaSetDevice(dev_id)); - auto& ann_if = index.ann_interfaces_.emplace_back(); + const raft::device_resources& dev_res = clique.set_current_device_to_rank(rank); + auto& ann_if = index.ann_interfaces_.emplace_back(); cuvs::neighbors::deserialize(dev_res, ann_if, is); } @@ -102,10 +98,8 @@ void build(const raft::device_resources_snmg& clique, index.ann_interfaces_.resize(index.num_ranks_); #pragma omp parallel for for (int rank = 0; rank < index.num_ranks_; rank++) { - int dev_id = clique.device_ids_[rank]; - const raft::device_resources& dev_res = clique.device_resources_[rank]; - RAFT_CUDA_TRY(cudaSetDevice(dev_id)); - auto& ann_if = index.ann_interfaces_[rank]; + const raft::device_resources& dev_res = clique.set_current_device_to_rank(rank); + auto& ann_if = index.ann_interfaces_[rank]; cuvs::neighbors::build(dev_res, ann_if, index_params, index_dataset); resource::sync_stream(dev_res); } @@ -119,13 +113,11 @@ void build(const raft::device_resources_snmg& clique, index.ann_interfaces_.resize(index.num_ranks_); #pragma omp parallel for for (int rank = 0; rank < index.num_ranks_; rank++) { - int dev_id = clique.device_ids_[rank]; - const raft::device_resources& dev_res = clique.device_resources_[rank]; - RAFT_CUDA_TRY(cudaSetDevice(dev_id)); - int64_t offset = rank * n_rows_per_shard; - int64_t n_rows_of_current_shard = std::min(n_rows_per_shard, n_rows - offset); - const T* partition_ptr = index_dataset.data_handle() + (offset * n_cols); - auto partition = raft::make_host_matrix_view( + const raft::device_resources& dev_res = clique.set_current_device_to_rank(rank); + int64_t offset = rank * n_rows_per_shard; + int64_t n_rows_of_current_shard = std::min(n_rows_per_shard, n_rows - offset); + const T* partition_ptr = index_dataset.data_handle() + (offset * n_cols); + auto partition = raft::make_host_matrix_view( partition_ptr, n_rows_of_current_shard, n_cols); auto& ann_if = index.ann_interfaces_[rank]; cuvs::neighbors::build(dev_res, ann_if, index_params, partition); @@ -146,10 +138,8 @@ void extend(const raft::device_resources_snmg& clique, #pragma omp parallel for for (int rank = 0; rank < index.num_ranks_; rank++) { - int dev_id = clique.device_ids_[rank]; - const raft::device_resources& dev_res = clique.device_resources_[rank]; - RAFT_CUDA_TRY(cudaSetDevice(dev_id)); - auto& ann_if = index.ann_interfaces_[rank]; + const raft::device_resources& dev_res = clique.set_current_device_to_rank(rank); + auto& ann_if = index.ann_interfaces_[rank]; cuvs::neighbors::extend(dev_res, ann_if, new_vectors, new_indices); resource::sync_stream(dev_res); } @@ -161,13 +151,11 @@ void extend(const raft::device_resources_snmg& clique, #pragma omp parallel for for (int rank = 0; rank < index.num_ranks_; rank++) { - int dev_id = clique.device_ids_[rank]; - const raft::device_resources& dev_res = clique.device_resources_[rank]; - RAFT_CUDA_TRY(cudaSetDevice(dev_id)); - int64_t offset = rank * n_rows_per_shard; - int64_t n_rows_of_current_shard = std::min(n_rows_per_shard, n_rows - offset); - const T* new_vectors_ptr = new_vectors.data_handle() + (offset * n_cols); - auto new_vectors_part = raft::make_host_matrix_view( + const raft::device_resources& dev_res = clique.set_current_device_to_rank(rank); + int64_t offset = rank * n_rows_per_shard; + int64_t n_rows_of_current_shard = std::min(n_rows_per_shard, n_rows - offset); + const T* new_vectors_ptr = new_vectors.data_handle() + (offset * n_cols); + auto new_vectors_part = raft::make_host_matrix_view( new_vectors_ptr, n_rows_of_current_shard, n_cols); std::optional> new_indices_part = std::nullopt; @@ -219,13 +207,11 @@ void sharded_search_with_direct_merge(const raft::device_resources_snmg& clique, check_omp_threads(requirements); // should use at least num_ranks_ threads to avoid NCCL hang #pragma omp parallel for num_threads(index.num_ranks_) for (int rank = 0; rank < index.num_ranks_; rank++) { - int dev_id = clique.device_ids_[rank]; - const raft::device_resources& dev_res = clique.device_resources_[rank]; + const raft::device_resources& dev_res = clique.set_current_device_to_rank(rank); auto& ann_if = index.ann_interfaces_[rank]; - RAFT_CUDA_TRY(cudaSetDevice(dev_id)); - if (rank == clique.root_rank_) { // root rank - uint64_t batch_offset = clique.root_rank_ * part_size; + if (rank == clique.get_root_rank()) { // root rank + uint64_t batch_offset = clique.get_root_rank() * part_size; auto d_neighbors = raft::make_device_matrix_view( in_neighbors.data_handle() + batch_offset, n_rows_of_current_batch, n_neighbors); auto d_distances = raft::make_device_matrix_view( @@ -236,20 +222,20 @@ void sharded_search_with_direct_merge(const raft::device_resources_snmg& clique, // wait for other ranks ncclGroupStart(); for (int from_rank = 0; from_rank < index.num_ranks_; from_rank++) { - if (from_rank == clique.root_rank_) continue; + if (from_rank == clique.get_root_rank()) continue; batch_offset = from_rank * part_size; ncclRecv(in_neighbors.data_handle() + batch_offset, part_size * sizeof(IdxT), ncclUint8, from_rank, - clique.nccl_comms_[rank], + clique.get_nccl_comm(rank), resource::get_cuda_stream(dev_res)); ncclRecv(in_distances.data_handle() + batch_offset, part_size * sizeof(float), ncclUint8, from_rank, - clique.nccl_comms_[rank], + clique.get_nccl_comm(rank), resource::get_cuda_stream(dev_res)); } ncclGroupEnd(); @@ -267,14 +253,14 @@ void sharded_search_with_direct_merge(const raft::device_resources_snmg& clique, ncclSend(d_neighbors.data_handle(), part_size * sizeof(IdxT), ncclUint8, - clique.root_rank_, - clique.nccl_comms_[rank], + clique.get_root_rank(), + clique.get_nccl_comm(rank), resource::get_cuda_stream(dev_res)); ncclSend(d_distances.data_handle(), part_size * sizeof(float), ncclUint8, - clique.root_rank_, - clique.nccl_comms_[rank], + clique.get_root_rank(), + clique.get_nccl_comm(rank), resource::get_cuda_stream(dev_res)); ncclGroupEnd(); resource::sync_stream(dev_res); @@ -342,10 +328,8 @@ void sharded_search_with_tree_merge(const raft::device_resources_snmg& clique, check_omp_threads(requirements); // should use at least num_ranks_ threads to avoid NCCL hang #pragma omp parallel for num_threads(index.num_ranks_) for (int rank = 0; rank < index.num_ranks_; rank++) { - int dev_id = clique.device_ids_[rank]; - const raft::device_resources& dev_res = clique.device_resources_[rank]; + const raft::device_resources& dev_res = clique.set_current_device_to_rank(rank); auto& ann_if = index.ann_interfaces_[rank]; - RAFT_CUDA_TRY(cudaSetDevice(dev_id)); int64_t part_size = n_rows_of_current_batch * n_neighbors; @@ -390,13 +374,13 @@ void sharded_search_with_tree_merge(const raft::device_resources_snmg& clique, part_size * sizeof(IdxT), ncclUint8, other_id, - clique.nccl_comms_[rank], + clique.get_nccl_comm(rank), resource::get_cuda_stream(dev_res)); ncclRecv(tmp_distances.data_handle() + part_size, part_size * sizeof(float), ncclUint8, other_id, - clique.nccl_comms_[rank], + clique.get_nccl_comm(rank), resource::get_cuda_stream(dev_res)); received_something = true; } @@ -407,13 +391,13 @@ void sharded_search_with_tree_merge(const raft::device_resources_snmg& clique, part_size * sizeof(IdxT), ncclUint8, other_id, - clique.nccl_comms_[rank], + clique.get_nccl_comm(rank), resource::get_cuda_stream(dev_res)); ncclSend(tmp_distances.data_handle(), part_size * sizeof(float), ncclUint8, other_id, - clique.nccl_comms_[rank], + clique.get_nccl_comm(rank), resource::get_cuda_stream(dev_res)); } ncclGroupEnd(); @@ -466,9 +450,7 @@ void run_search_batch(const raft::device_resources_snmg& clique, int64_t n_cols, int64_t n_neighbors) { - int dev_id = clique.device_ids_[rank]; - RAFT_CUDA_TRY(cudaSetDevice(dev_id)); - const raft::device_resources& dev_res = clique.device_resources_[rank]; + const raft::device_resources& dev_res = clique.set_current_device_to_rank(rank); auto& ann_if = index.ann_interfaces_[rank]; auto query_partition = raft::make_host_matrix_view( @@ -645,10 +627,8 @@ void serialize(const raft::device_resources_snmg& clique, serialize_scalar(handle, of, index.num_ranks_); for (int rank = 0; rank < index.num_ranks_; rank++) { - int dev_id = clique.device_ids_[rank]; - const raft::device_resources& dev_res = clique.device_resources_[rank]; - RAFT_CUDA_TRY(cudaSetDevice(dev_id)); - auto& ann_if = index.ann_interfaces_[rank]; + const raft::device_resources& dev_res = clique.set_current_device_to_rank(rank); + auto& ann_if = index.ann_interfaces_[rank]; cuvs::neighbors::serialize(dev_res, ann_if, of); } diff --git a/cpp/src/neighbors/mg/mg_cagra_float_uint32_t.cu b/cpp/src/neighbors/mg/mg_cagra_float_uint32_t.cu index c3ef3705e..e179a56e3 100644 --- a/cpp/src/neighbors/mg/mg_cagra_float_uint32_t.cu +++ b/cpp/src/neighbors/mg/mg_cagra_float_uint32_t.cu @@ -27,61 +27,61 @@ namespace cuvs::neighbors::mg { -#define CUVS_INST_MG_CAGRA(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources_snmg& clique, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::build( \ - clique, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void search(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - clique, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_cagra( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(clique, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_cagra( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ - return idx; \ +#define CUVS_INST_MG_CAGRA(T, IdxT) \ + index, T, IdxT> build( \ + const raft::device_resources_snmg& clique, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::build( \ + clique, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void search(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances, \ + int64_t n_rows_per_batch) \ + { \ + cuvs::neighbors::mg::detail::search( \ + clique, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances, \ + n_rows_per_batch); \ + } \ + \ + void serialize(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ + } \ + \ + template <> \ + index, T, IdxT> deserialize_cagra( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(clique, filename); \ + return idx; \ + } \ + \ + template <> \ + index, T, IdxT> distribute_cagra( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ + return idx; \ } CUVS_INST_MG_CAGRA(float, uint32_t); diff --git a/cpp/src/neighbors/mg/mg_cagra_half_uint32_t.cu b/cpp/src/neighbors/mg/mg_cagra_half_uint32_t.cu index ea9ec672b..3e369d9ac 100644 --- a/cpp/src/neighbors/mg/mg_cagra_half_uint32_t.cu +++ b/cpp/src/neighbors/mg/mg_cagra_half_uint32_t.cu @@ -27,61 +27,61 @@ namespace cuvs::neighbors::mg { -#define CUVS_INST_MG_CAGRA(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources_snmg& clique, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::build( \ - clique, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void search(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - clique, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_cagra( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(clique, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_cagra( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ - return idx; \ +#define CUVS_INST_MG_CAGRA(T, IdxT) \ + index, T, IdxT> build( \ + const raft::device_resources_snmg& clique, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::build( \ + clique, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void search(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances, \ + int64_t n_rows_per_batch) \ + { \ + cuvs::neighbors::mg::detail::search( \ + clique, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances, \ + n_rows_per_batch); \ + } \ + \ + void serialize(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ + } \ + \ + template <> \ + index, T, IdxT> deserialize_cagra( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(clique, filename); \ + return idx; \ + } \ + \ + template <> \ + index, T, IdxT> distribute_cagra( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ + return idx; \ } CUVS_INST_MG_CAGRA(half, uint32_t); diff --git a/cpp/src/neighbors/mg/mg_cagra_int8_t_uint32_t.cu b/cpp/src/neighbors/mg/mg_cagra_int8_t_uint32_t.cu index aeae0f2cc..5ebf223d1 100644 --- a/cpp/src/neighbors/mg/mg_cagra_int8_t_uint32_t.cu +++ b/cpp/src/neighbors/mg/mg_cagra_int8_t_uint32_t.cu @@ -27,61 +27,61 @@ namespace cuvs::neighbors::mg { -#define CUVS_INST_MG_CAGRA(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources_snmg& clique, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::build( \ - clique, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void search(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - clique, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_cagra( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(clique, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_cagra( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ - return idx; \ +#define CUVS_INST_MG_CAGRA(T, IdxT) \ + index, T, IdxT> build( \ + const raft::device_resources_snmg& clique, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::build( \ + clique, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void search(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances, \ + int64_t n_rows_per_batch) \ + { \ + cuvs::neighbors::mg::detail::search( \ + clique, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances, \ + n_rows_per_batch); \ + } \ + \ + void serialize(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ + } \ + \ + template <> \ + index, T, IdxT> deserialize_cagra( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(clique, filename); \ + return idx; \ + } \ + \ + template <> \ + index, T, IdxT> distribute_cagra( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ + return idx; \ } CUVS_INST_MG_CAGRA(int8_t, uint32_t); diff --git a/cpp/src/neighbors/mg/mg_cagra_uint8_t_uint32_t.cu b/cpp/src/neighbors/mg/mg_cagra_uint8_t_uint32_t.cu index 22421d6f0..923031b1c 100644 --- a/cpp/src/neighbors/mg/mg_cagra_uint8_t_uint32_t.cu +++ b/cpp/src/neighbors/mg/mg_cagra_uint8_t_uint32_t.cu @@ -27,61 +27,61 @@ namespace cuvs::neighbors::mg { -#define CUVS_INST_MG_CAGRA(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources_snmg& clique, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::build( \ - clique, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void search(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - clique, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_cagra( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(clique, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_cagra( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ - return idx; \ +#define CUVS_INST_MG_CAGRA(T, IdxT) \ + index, T, IdxT> build( \ + const raft::device_resources_snmg& clique, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::build( \ + clique, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void search(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances, \ + int64_t n_rows_per_batch) \ + { \ + cuvs::neighbors::mg::detail::search( \ + clique, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances, \ + n_rows_per_batch); \ + } \ + \ + void serialize(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ + } \ + \ + template <> \ + index, T, IdxT> deserialize_cagra( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(clique, filename); \ + return idx; \ + } \ + \ + template <> \ + index, T, IdxT> distribute_cagra( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ + return idx; \ } CUVS_INST_MG_CAGRA(uint8_t, uint32_t); diff --git a/cpp/src/neighbors/mg/mg_flat_float_int64_t.cu b/cpp/src/neighbors/mg/mg_flat_float_int64_t.cu index 423aa0284..f90f6fcfb 100644 --- a/cpp/src/neighbors/mg/mg_flat_float_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_flat_float_int64_t.cu @@ -27,69 +27,69 @@ namespace cuvs::neighbors::mg { -#define CUVS_INST_MG_FLAT(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources_snmg& clique, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::build( \ - clique, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void extend(const raft::device_resources_snmg& clique, \ - index, T, IdxT>& index, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices) \ - { \ - cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ - } \ - \ - void search(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - clique, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_flat( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(clique, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_flat( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ - return idx; \ +#define CUVS_INST_MG_FLAT(T, IdxT) \ + index, T, IdxT> build( \ + const raft::device_resources_snmg& clique, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::build( \ + clique, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void extend(const raft::device_resources_snmg& clique, \ + index, T, IdxT>& index, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices) \ + { \ + cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ + } \ + \ + void search(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances, \ + int64_t n_rows_per_batch) \ + { \ + cuvs::neighbors::mg::detail::search( \ + clique, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances, \ + n_rows_per_batch); \ + } \ + \ + void serialize(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ + } \ + \ + template <> \ + index, T, IdxT> deserialize_flat( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(clique, filename); \ + return idx; \ + } \ + \ + template <> \ + index, T, IdxT> distribute_flat( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ + return idx; \ } CUVS_INST_MG_FLAT(float, int64_t); diff --git a/cpp/src/neighbors/mg/mg_flat_int8_t_int64_t.cu b/cpp/src/neighbors/mg/mg_flat_int8_t_int64_t.cu index 06bb7af26..2eefad5d5 100644 --- a/cpp/src/neighbors/mg/mg_flat_int8_t_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_flat_int8_t_int64_t.cu @@ -27,69 +27,69 @@ namespace cuvs::neighbors::mg { -#define CUVS_INST_MG_FLAT(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources_snmg& clique, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::build( \ - clique, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void extend(const raft::device_resources_snmg& clique, \ - index, T, IdxT>& index, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices) \ - { \ - cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ - } \ - \ - void search(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - clique, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_flat( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(clique, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_flat( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ - return idx; \ +#define CUVS_INST_MG_FLAT(T, IdxT) \ + index, T, IdxT> build( \ + const raft::device_resources_snmg& clique, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::build( \ + clique, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void extend(const raft::device_resources_snmg& clique, \ + index, T, IdxT>& index, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices) \ + { \ + cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ + } \ + \ + void search(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances, \ + int64_t n_rows_per_batch) \ + { \ + cuvs::neighbors::mg::detail::search( \ + clique, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances, \ + n_rows_per_batch); \ + } \ + \ + void serialize(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ + } \ + \ + template <> \ + index, T, IdxT> deserialize_flat( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(clique, filename); \ + return idx; \ + } \ + \ + template <> \ + index, T, IdxT> distribute_flat( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ + return idx; \ } CUVS_INST_MG_FLAT(int8_t, int64_t); diff --git a/cpp/src/neighbors/mg/mg_flat_uint8_t_int64_t.cu b/cpp/src/neighbors/mg/mg_flat_uint8_t_int64_t.cu index bbf7d96f8..9684f19d8 100644 --- a/cpp/src/neighbors/mg/mg_flat_uint8_t_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_flat_uint8_t_int64_t.cu @@ -27,69 +27,69 @@ namespace cuvs::neighbors::mg { -#define CUVS_INST_MG_FLAT(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources_snmg& clique, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::build( \ - clique, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void extend(const raft::device_resources_snmg& clique, \ - index, T, IdxT>& index, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices) \ - { \ - cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ - } \ - \ - void search(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - clique, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_flat( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(clique, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_flat( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ - return idx; \ +#define CUVS_INST_MG_FLAT(T, IdxT) \ + index, T, IdxT> build( \ + const raft::device_resources_snmg& clique, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::build( \ + clique, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void extend(const raft::device_resources_snmg& clique, \ + index, T, IdxT>& index, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices) \ + { \ + cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ + } \ + \ + void search(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances, \ + int64_t n_rows_per_batch) \ + { \ + cuvs::neighbors::mg::detail::search( \ + clique, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances, \ + n_rows_per_batch); \ + } \ + \ + void serialize(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ + } \ + \ + template <> \ + index, T, IdxT> deserialize_flat( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(clique, filename); \ + return idx; \ + } \ + \ + template <> \ + index, T, IdxT> distribute_flat( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ + return idx; \ } CUVS_INST_MG_FLAT(uint8_t, int64_t); diff --git a/cpp/src/neighbors/mg/mg_pq_float_int64_t.cu b/cpp/src/neighbors/mg/mg_pq_float_int64_t.cu index 441a09e2f..c71133ac4 100644 --- a/cpp/src/neighbors/mg/mg_pq_float_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_pq_float_int64_t.cu @@ -27,69 +27,69 @@ namespace cuvs::neighbors::mg { -#define CUVS_INST_MG_PQ(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources_snmg& clique, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::build( \ - clique, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void extend(const raft::device_resources_snmg& clique, \ - index, T, IdxT>& index, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices) \ - { \ - cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ - } \ - \ - void search(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - clique, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_pq( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(clique, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_pq( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ - return idx; \ +#define CUVS_INST_MG_PQ(T, IdxT) \ + index, T, IdxT> build( \ + const raft::device_resources_snmg& clique, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::build( \ + clique, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void extend(const raft::device_resources_snmg& clique, \ + index, T, IdxT>& index, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices) \ + { \ + cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ + } \ + \ + void search(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances, \ + int64_t n_rows_per_batch) \ + { \ + cuvs::neighbors::mg::detail::search( \ + clique, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances, \ + n_rows_per_batch); \ + } \ + \ + void serialize(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ + } \ + \ + template <> \ + index, T, IdxT> deserialize_pq( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(clique, filename); \ + return idx; \ + } \ + \ + template <> \ + index, T, IdxT> distribute_pq( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ + return idx; \ } CUVS_INST_MG_PQ(float, int64_t); diff --git a/cpp/src/neighbors/mg/mg_pq_half_int64_t.cu b/cpp/src/neighbors/mg/mg_pq_half_int64_t.cu index bf6126fee..df148620f 100644 --- a/cpp/src/neighbors/mg/mg_pq_half_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_pq_half_int64_t.cu @@ -27,69 +27,69 @@ namespace cuvs::neighbors::mg { -#define CUVS_INST_MG_PQ(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources_snmg& clique, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::build( \ - clique, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void extend(const raft::device_resources_snmg& clique, \ - index, T, IdxT>& index, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices) \ - { \ - cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ - } \ - \ - void search(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - clique, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_pq( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(clique, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_pq( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ - return idx; \ +#define CUVS_INST_MG_PQ(T, IdxT) \ + index, T, IdxT> build( \ + const raft::device_resources_snmg& clique, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::build( \ + clique, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void extend(const raft::device_resources_snmg& clique, \ + index, T, IdxT>& index, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices) \ + { \ + cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ + } \ + \ + void search(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances, \ + int64_t n_rows_per_batch) \ + { \ + cuvs::neighbors::mg::detail::search( \ + clique, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances, \ + n_rows_per_batch); \ + } \ + \ + void serialize(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ + } \ + \ + template <> \ + index, T, IdxT> deserialize_pq( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(clique, filename); \ + return idx; \ + } \ + \ + template <> \ + index, T, IdxT> distribute_pq( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ + return idx; \ } CUVS_INST_MG_PQ(half, int64_t); diff --git a/cpp/src/neighbors/mg/mg_pq_int8_t_int64_t.cu b/cpp/src/neighbors/mg/mg_pq_int8_t_int64_t.cu index 3921f810c..afe5faa41 100644 --- a/cpp/src/neighbors/mg/mg_pq_int8_t_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_pq_int8_t_int64_t.cu @@ -27,69 +27,69 @@ namespace cuvs::neighbors::mg { -#define CUVS_INST_MG_PQ(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources_snmg& clique, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::build( \ - clique, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void extend(const raft::device_resources_snmg& clique, \ - index, T, IdxT>& index, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices) \ - { \ - cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ - } \ - \ - void search(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - clique, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_pq( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(clique, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_pq( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ - return idx; \ +#define CUVS_INST_MG_PQ(T, IdxT) \ + index, T, IdxT> build( \ + const raft::device_resources_snmg& clique, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::build( \ + clique, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void extend(const raft::device_resources_snmg& clique, \ + index, T, IdxT>& index, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices) \ + { \ + cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ + } \ + \ + void search(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances, \ + int64_t n_rows_per_batch) \ + { \ + cuvs::neighbors::mg::detail::search( \ + clique, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances, \ + n_rows_per_batch); \ + } \ + \ + void serialize(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ + } \ + \ + template <> \ + index, T, IdxT> deserialize_pq( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(clique, filename); \ + return idx; \ + } \ + \ + template <> \ + index, T, IdxT> distribute_pq( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ + return idx; \ } CUVS_INST_MG_PQ(int8_t, int64_t); diff --git a/cpp/src/neighbors/mg/mg_pq_uint8_t_int64_t.cu b/cpp/src/neighbors/mg/mg_pq_uint8_t_int64_t.cu index 8f4683fd7..c725d2139 100644 --- a/cpp/src/neighbors/mg/mg_pq_uint8_t_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_pq_uint8_t_int64_t.cu @@ -27,69 +27,69 @@ namespace cuvs::neighbors::mg { -#define CUVS_INST_MG_PQ(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources_snmg& clique, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::build( \ - clique, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void extend(const raft::device_resources_snmg& clique, \ - index, T, IdxT>& index, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices) \ - { \ - cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ - } \ - \ - void search(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - clique, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_pq( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(clique, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_pq( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ - return idx; \ +#define CUVS_INST_MG_PQ(T, IdxT) \ + index, T, IdxT> build( \ + const raft::device_resources_snmg& clique, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::build( \ + clique, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void extend(const raft::device_resources_snmg& clique, \ + index, T, IdxT>& index, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices) \ + { \ + cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ + } \ + \ + void search(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances, \ + int64_t n_rows_per_batch) \ + { \ + cuvs::neighbors::mg::detail::search( \ + clique, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances, \ + n_rows_per_batch); \ + } \ + \ + void serialize(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ + } \ + \ + template <> \ + index, T, IdxT> deserialize_pq( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(clique, filename); \ + return idx; \ + } \ + \ + template <> \ + index, T, IdxT> distribute_pq( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ + return idx; \ } CUVS_INST_MG_PQ(uint8_t, int64_t); diff --git a/cpp/test/neighbors/mg.cuh b/cpp/test/neighbors/mg.cuh index eb97b583c..f634765c9 100644 --- a/cpp/test/neighbors/mg.cuh +++ b/cpp/test/neighbors/mg.cuh @@ -54,6 +54,7 @@ class AnnMGTest : public ::testing::TestWithParam { h_index_dataset(0), h_queries(0) { + clique_.set_memory_pool(80); } void testAnnMG() From 96e69fc60db6c3848d77857eda9ff3c3a671ec00 Mon Sep 17 00:00:00 2001 From: viclafargue Date: Fri, 15 Nov 2024 17:59:00 +0100 Subject: [PATCH 4/8] switch from RAFT_LOG_INFO to RAFT_LOG_DEBUG for mg logs --- cpp/src/neighbors/mg/mg.cuh | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/cpp/src/neighbors/mg/mg.cuh b/cpp/src/neighbors/mg/mg.cuh index 54ac32cc3..14ffbce93 100644 --- a/cpp/src/neighbors/mg/mg.cuh +++ b/cpp/src/neighbors/mg/mg.cuh @@ -95,7 +95,7 @@ void build(const raft::device_resources_snmg& clique, { if (index.mode_ == REPLICATED) { int64_t n_rows = index_dataset.extent(0); - RAFT_LOG_INFO("REPLICATED BUILD: %d*%drows", index.num_ranks_, n_rows); + RAFT_LOG_DEBUG("REPLICATED BUILD: %d*%drows", index.num_ranks_, n_rows); index.ann_interfaces_.resize(index.num_ranks_); #pragma omp parallel for @@ -110,7 +110,7 @@ void build(const raft::device_resources_snmg& clique, int64_t n_cols = index_dataset.extent(1); int64_t n_rows_per_shard = raft::ceildiv(n_rows, (int64_t)index.num_ranks_); - RAFT_LOG_INFO("SHARDED BUILD: %d*%drows", index.num_ranks_, n_rows_per_shard); + RAFT_LOG_DEBUG("SHARDED BUILD: %d*%drows", index.num_ranks_, n_rows_per_shard); index.ann_interfaces_.resize(index.num_ranks_); #pragma omp parallel for @@ -136,7 +136,7 @@ void extend(const raft::device_resources_snmg& clique, { int64_t n_rows = new_vectors.extent(0); if (index.mode_ == REPLICATED) { - RAFT_LOG_INFO("REPLICATED EXTEND: %d*%drows", index.num_ranks_, n_rows); + RAFT_LOG_DEBUG("REPLICATED EXTEND: %d*%drows", index.num_ranks_, n_rows); #pragma omp parallel for for (int rank = 0; rank < index.num_ranks_; rank++) { @@ -149,7 +149,7 @@ void extend(const raft::device_resources_snmg& clique, int64_t n_cols = new_vectors.extent(1); int64_t n_rows_per_shard = raft::ceildiv(n_rows, (int64_t)index.num_ranks_); - RAFT_LOG_INFO("SHARDED EXTEND: %d*%drows", index.num_ranks_, n_rows_per_shard); + RAFT_LOG_DEBUG("SHARDED EXTEND: %d*%drows", index.num_ranks_, n_rows_per_shard); #pragma omp parallel for for (int rank = 0; rank < index.num_ranks_; rank++) { @@ -515,7 +515,7 @@ void search(const raft::device_resources_snmg& clique, int64_t n_batches = raft::ceildiv(n_rows, (int64_t)n_rows_per_batch); if (n_batches <= 1) n_rows_per_batch = n_rows; - RAFT_LOG_INFO( + RAFT_LOG_DEBUG( "REPLICATED SEARCH IN LOAD BALANCER MODE: %d*%drows", n_batches, n_rows_per_batch); #pragma omp parallel for @@ -540,7 +540,7 @@ void search(const raft::device_resources_snmg& clique, n_neighbors); } } else if (search_mode == ROUND_ROBIN) { - RAFT_LOG_INFO("REPLICATED SEARCH IN ROUND ROBIN MODE: %d*%drows", 1, n_rows); + RAFT_LOG_DEBUG("REPLICATED SEARCH IN ROUND ROBIN MODE: %d*%drows", 1, n_rows); ASSERT(n_rows <= n_rows_per_batch, "In round-robin mode, n_rows must lower or equal to n_rows_per_batch"); @@ -584,9 +584,9 @@ void search(const raft::device_resources_snmg& clique, if (n_batches <= 1) n_rows_per_batch = n_rows; if (merge_mode == MERGE_ON_ROOT_RANK) { - RAFT_LOG_INFO("SHARDED SEARCH WITH MERGE_ON_ROOT_RANK MERGE MODE: %d*%drows", - n_batches, - n_rows_per_batch); + RAFT_LOG_DEBUG("SHARDED SEARCH WITH MERGE_ON_ROOT_RANK MERGE MODE: %d*%drows", + n_batches, + n_rows_per_batch); sharded_search_with_direct_merge(clique, index, search_params, @@ -599,7 +599,7 @@ void search(const raft::device_resources_snmg& clique, n_neighbors, n_batches); } else if (merge_mode == TREE_MERGE) { - RAFT_LOG_INFO( + RAFT_LOG_DEBUG( "SHARDED SEARCH WITH TREE_MERGE MERGE MODE %d*%drows", n_batches, n_rows_per_batch); sharded_search_with_tree_merge(clique, index, From 657bf9e36c96bb578e514a605b2510a5a3b95f09 Mon Sep 17 00:00:00 2001 From: viclafargue Date: Mon, 25 Nov 2024 15:05:06 +0100 Subject: [PATCH 5/8] clique as device_resource --- .../ann/src/cuvs/cuvs_mg_cagra_wrapper.h | 3 +-- .../ann/src/cuvs/cuvs_mg_ivf_flat_wrapper.h | 3 +-- .../ann/src/cuvs/cuvs_mg_ivf_pq_wrapper.h | 3 +-- cpp/test/neighbors/mg.cuh | 24 +++++++++---------- 4 files changed, 14 insertions(+), 19 deletions(-) diff --git a/cpp/bench/ann/src/cuvs/cuvs_mg_cagra_wrapper.h b/cpp/bench/ann/src/cuvs/cuvs_mg_cagra_wrapper.h index 27a0fd7ac..6a6580f4f 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_mg_cagra_wrapper.h +++ b/cpp/bench/ann/src/cuvs/cuvs_mg_cagra_wrapper.h @@ -68,7 +68,7 @@ class cuvs_mg_cagra : public algo, public algo_gpu { [[nodiscard]] auto get_sync_stream() const noexcept -> cudaStream_t override { - auto stream = raft::resource::get_cuda_stream(handle_); + auto stream = raft::resource::get_cuda_stream(clique_); return stream; } @@ -86,7 +86,6 @@ class cuvs_mg_cagra : public algo, public algo_gpu { std::unique_ptr> copy() override; private: - raft::device_resources handle_; raft::device_resources_snmg clique_; float refine_ratio_; build_param index_params_; diff --git a/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_flat_wrapper.h b/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_flat_wrapper.h index 5e811da33..a2b91bc0a 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_flat_wrapper.h +++ b/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_flat_wrapper.h @@ -62,7 +62,7 @@ class cuvs_mg_ivf_flat : public algo, public algo_gpu { [[nodiscard]] auto get_sync_stream() const noexcept -> cudaStream_t override { - auto stream = raft::resource::get_cuda_stream(handle_); + auto stream = raft::resource::get_cuda_stream(clique_); return stream; } @@ -73,7 +73,6 @@ class cuvs_mg_ivf_flat : public algo, public algo_gpu { std::unique_ptr> copy() override; private: - raft::device_resources handle_; raft::device_resources_snmg clique_; build_param index_params_; cuvs::neighbors::mg::search_params search_params_; diff --git a/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_pq_wrapper.h b/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_pq_wrapper.h index c4a820cad..c2ce61cd8 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_pq_wrapper.h +++ b/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_pq_wrapper.h @@ -62,7 +62,7 @@ class cuvs_mg_ivf_pq : public algo, public algo_gpu { [[nodiscard]] auto get_sync_stream() const noexcept -> cudaStream_t override { - auto stream = raft::resource::get_cuda_stream(handle_); + auto stream = raft::resource::get_cuda_stream(clique_); return stream; } @@ -73,7 +73,6 @@ class cuvs_mg_ivf_pq : public algo, public algo_gpu { std::unique_ptr> copy() override; private: - raft::device_resources handle_; raft::device_resources_snmg clique_; build_param index_params_; cuvs::neighbors::mg::search_params search_params_; diff --git a/cpp/test/neighbors/mg.cuh b/cpp/test/neighbors/mg.cuh index f634765c9..853dc8c0e 100644 --- a/cpp/test/neighbors/mg.cuh +++ b/cpp/test/neighbors/mg.cuh @@ -46,8 +46,7 @@ template class AnnMGTest : public ::testing::TestWithParam { public: AnnMGTest() - : stream_(resource::get_cuda_stream(handle_)), - clique_(), + : stream_(resource::get_cuda_stream(clique_)), ps(::testing::TestWithParam::GetParam()), d_index_dataset(0, stream_), d_queries(0, stream_), @@ -82,7 +81,7 @@ class AnnMGTest : public ::testing::TestWithParam { ps.metric); update_host(distances_ref.data(), distances_ref_dev.data(), queries_size, stream_); update_host(neighbors_ref.data(), neighbors_ref_dev.data(), queries_size, stream_); - resource::sync_stream(handle_); + resource::sync_stream(clique_); } int64_t n_rows_per_search_batch = 3000; // [3000, 3000, 1000] == 7000 rows @@ -132,7 +131,7 @@ class AnnMGTest : public ::testing::TestWithParam { search_params.merge_mode = TREE_MERGE; cuvs::neighbors::mg::search( clique_, new_index, search_params, queries, neighbors, distances, n_rows_per_search_batch); - resource::sync_stream(handle_); + resource::sync_stream(clique_); double min_recall = static_cast(ps.nprobe) / static_cast(ps.nlist); ASSERT_TRUE(eval_neighbours(neighbors_ref, @@ -191,7 +190,7 @@ class AnnMGTest : public ::testing::TestWithParam { search_params.merge_mode = TREE_MERGE; cuvs::neighbors::mg::search( clique_, new_index, search_params, queries, neighbors, distances, n_rows_per_search_batch); - resource::sync_stream(handle_); + resource::sync_stream(clique_); double min_recall = static_cast(ps.nprobe) / static_cast(ps.nlist); ASSERT_TRUE(eval_neighbours(neighbors_ref, @@ -244,7 +243,7 @@ class AnnMGTest : public ::testing::TestWithParam { search_params.merge_mode = TREE_MERGE; cuvs::neighbors::mg::search( clique_, new_index, search_params, queries, neighbors, distances, n_rows_per_search_batch); - resource::sync_stream(handle_); + resource::sync_stream(clique_); double min_recall = static_cast(ps.nprobe) / static_cast(ps.nlist); ASSERT_TRUE(eval_neighbours(neighbors_ref_32bits, @@ -297,7 +296,7 @@ class AnnMGTest : public ::testing::TestWithParam { distances, n_rows_per_search_batch); - resource::sync_stream(handle_); + resource::sync_stream(clique_); double min_recall = static_cast(ps.nprobe) / static_cast(ps.nlist); ASSERT_TRUE(eval_neighbours(neighbors_ref, @@ -349,7 +348,7 @@ class AnnMGTest : public ::testing::TestWithParam { distances, n_rows_per_search_batch); - resource::sync_stream(handle_); + resource::sync_stream(clique_); double min_recall = static_cast(ps.nprobe) / static_cast(ps.nlist); ASSERT_TRUE(eval_neighbours(neighbors_ref, @@ -397,7 +396,7 @@ class AnnMGTest : public ::testing::TestWithParam { distances, n_rows_per_search_batch); - resource::sync_stream(handle_); + resource::sync_stream(clique_); double min_recall = static_cast(ps.nprobe) / static_cast(ps.nlist); ASSERT_TRUE(eval_neighbours(neighbors_ref_32bits, @@ -622,16 +621,15 @@ class AnnMGTest : public ::testing::TestWithParam { raft::copy(h_index_dataset.data(), d_index_dataset.data(), d_index_dataset.size(), - resource::get_cuda_stream(handle_)); + resource::get_cuda_stream(clique_)); raft::copy( - h_queries.data(), d_queries.data(), d_queries.size(), resource::get_cuda_stream(handle_)); - resource::sync_stream(handle_); + h_queries.data(), d_queries.data(), d_queries.size(), resource::get_cuda_stream(clique_)); + resource::sync_stream(clique_); } void TearDown() override {} private: - raft::device_resources handle_; rmm::cuda_stream_view stream_; raft::device_resources_snmg clique_; AnnMGInputs ps; From 1fdccd4e5c46625e8f1f467052d7b2ab4b5556e7 Mon Sep 17 00:00:00 2001 From: viclafargue Date: Tue, 26 Nov 2024 14:31:46 +0100 Subject: [PATCH 6/8] updating MG tests --- cpp/test/neighbors/mg.cuh | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/cpp/test/neighbors/mg.cuh b/cpp/test/neighbors/mg.cuh index 853dc8c0e..b4131acdb 100644 --- a/cpp/test/neighbors/mg.cuh +++ b/cpp/test/neighbors/mg.cuh @@ -46,10 +46,10 @@ template class AnnMGTest : public ::testing::TestWithParam { public: AnnMGTest() - : stream_(resource::get_cuda_stream(clique_)), + : clique_(), ps(::testing::TestWithParam::GetParam()), - d_index_dataset(0, stream_), - d_queries(0, stream_), + d_index_dataset(0, resource::get_cuda_stream(clique_)), + d_queries(0, resource::get_cuda_stream(clique_)), h_index_dataset(0), h_queries(0) { @@ -67,8 +67,9 @@ class AnnMGTest : public ::testing::TestWithParam { std::vector neighbors_snmg_ann_32bits(queries_size); { - rmm::device_uvector distances_ref_dev(queries_size, stream_); - rmm::device_uvector neighbors_ref_dev(queries_size, stream_); + rmm::device_uvector distances_ref_dev(queries_size, resource::get_cuda_stream(clique_)); + rmm::device_uvector neighbors_ref_dev(queries_size, + resource::get_cuda_stream(clique_)); cuvs::neighbors::naive_knn(clique_, distances_ref_dev.data(), neighbors_ref_dev.data(), @@ -79,8 +80,14 @@ class AnnMGTest : public ::testing::TestWithParam { ps.dim, ps.k, ps.metric); - update_host(distances_ref.data(), distances_ref_dev.data(), queries_size, stream_); - update_host(neighbors_ref.data(), neighbors_ref_dev.data(), queries_size, stream_); + update_host(distances_ref.data(), + distances_ref_dev.data(), + queries_size, + resource::get_cuda_stream(clique_)); + update_host(neighbors_ref.data(), + neighbors_ref_dev.data(), + queries_size, + resource::get_cuda_stream(clique_)); resource::sync_stream(clique_); } @@ -602,8 +609,8 @@ class AnnMGTest : public ::testing::TestWithParam { void SetUp() override { - d_index_dataset.resize(ps.num_db_vecs * ps.dim, stream_); - d_queries.resize(ps.num_queries * ps.dim, stream_); + d_index_dataset.resize(ps.num_db_vecs * ps.dim, resource::get_cuda_stream(clique_)); + d_queries.resize(ps.num_queries * ps.dim, resource::get_cuda_stream(clique_)); h_index_dataset.resize(ps.num_db_vecs * ps.dim); h_queries.resize(ps.num_queries * ps.dim); @@ -630,7 +637,6 @@ class AnnMGTest : public ::testing::TestWithParam { void TearDown() override {} private: - rmm::cuda_stream_view stream_; raft::device_resources_snmg clique_; AnnMGInputs ps; std::vector h_index_dataset; From 45a41faf2f1551ec2eccd5bfecf1b87c403a5ae8 Mon Sep 17 00:00:00 2001 From: viclafargue Date: Mon, 20 Jan 2025 17:53:17 +0000 Subject: [PATCH 7/8] API unification (removal of mg namespace) --- cpp/CMakeLists.txt | 1 + .../ann/src/cuvs/cuvs_mg_cagra_wrapper.h | 8 +- .../ann/src/cuvs/cuvs_mg_ivf_flat_wrapper.h | 8 +- .../ann/src/cuvs/cuvs_mg_ivf_pq_wrapper.h | 8 +- cpp/include/cuvs/neighbors/cagra.hpp | 462 ++++++ cpp/include/cuvs/neighbors/ivf_flat.hpp | 361 +++++ cpp/include/cuvs/neighbors/ivf_pq.hpp | 451 ++++++ cpp/include/cuvs/neighbors/mg.hpp | 1256 +---------------- cpp/src/neighbors/mg/generate_mg.py | 351 ++--- cpp/src/neighbors/mg/mg.cuh | 3 + .../neighbors/mg/mg_cagra_float_uint32_t.cu | 122 +- .../neighbors/mg/mg_cagra_half_uint32_t.cu | 122 +- .../neighbors/mg/mg_cagra_int8_t_uint32_t.cu | 122 +- .../neighbors/mg/mg_cagra_uint8_t_uint32_t.cu | 122 +- cpp/src/neighbors/mg/mg_flat_float_int64_t.cu | 138 +- .../neighbors/mg/mg_flat_int8_t_int64_t.cu | 138 +- .../neighbors/mg/mg_flat_uint8_t_int64_t.cu | 138 +- cpp/src/neighbors/mg/mg_pq_float_int64_t.cu | 138 +- cpp/src/neighbors/mg/mg_pq_half_int64_t.cu | 138 +- cpp/src/neighbors/mg/mg_pq_int8_t_int64_t.cu | 138 +- cpp/src/neighbors/mg/mg_pq_uint8_t_int64_t.cu | 138 +- cpp/test/neighbors/mg.cuh | 132 +- 22 files changed, 2277 insertions(+), 2218 deletions(-) diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 78862cb33..59eab62ef 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -580,6 +580,7 @@ if(BUILD_SHARED_LIBS) if(BUILD_MG_ALGOS) target_compile_definitions(cuvs PUBLIC CUVS_BUILD_MG_ALGOS) target_compile_definitions(cuvs_objs PUBLIC CUVS_BUILD_MG_ALGOS) + target_compile_definitions(cuvs-cagra-search PUBLIC CUVS_BUILD_MG_ALGOS) endif() if(BUILD_CAGRA_HNSWLIB) diff --git a/cpp/bench/ann/src/cuvs/cuvs_mg_cagra_wrapper.h b/cpp/bench/ann/src/cuvs/cuvs_mg_cagra_wrapper.h index 6a6580f4f..6762287e1 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_mg_cagra_wrapper.h +++ b/cpp/bench/ann/src/cuvs/cuvs_mg_cagra_wrapper.h @@ -104,7 +104,7 @@ void cuvs_mg_cagra::build(const T* dataset, size_t nrow) auto dataset_view = raft::make_host_matrix_view(dataset, nrow, dim_); - auto idx = cuvs::neighbors::mg::build(clique_, build_params, dataset_view); + auto idx = cuvs::neighbors::cagra::build(clique_, build_params, dataset_view); index_ = std::make_shared, T, IdxT>>( std::move(idx)); @@ -131,7 +131,7 @@ void cuvs_mg_cagra::set_search_dataset(const T* dataset, size_t nrow) template void cuvs_mg_cagra::save(const std::string& file) const { - cuvs::neighbors::mg::serialize(clique_, *index_, file); + cuvs::neighbors::cagra::serialize(clique_, *index_, file); } template @@ -139,7 +139,7 @@ void cuvs_mg_cagra::load(const std::string& file) { index_ = std::make_shared, T, IdxT>>( - std::move(cuvs::neighbors::mg::deserialize_cagra(clique_, file))); + std::move(cuvs::neighbors::cagra::deserialize(clique_, file))); } template @@ -162,7 +162,7 @@ void cuvs_mg_cagra::search_base( auto distances_view = raft::make_host_matrix_view(distances, batch_size, k); - cuvs::neighbors::mg::search( + cuvs::neighbors::cagra::search( clique_, *index_, search_params_, queries_view, neighbors_view, distances_view); } diff --git a/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_flat_wrapper.h b/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_flat_wrapper.h index a2b91bc0a..de854b508 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_flat_wrapper.h +++ b/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_flat_wrapper.h @@ -85,7 +85,7 @@ void cuvs_mg_ivf_flat::build(const T* dataset, size_t nrow) { auto dataset_view = raft::make_host_matrix_view(dataset, IdxT(nrow), IdxT(dim_)); - auto idx = cuvs::neighbors::mg::build(clique_, index_params_, dataset_view); + auto idx = cuvs::neighbors::ivf_flat::build(clique_, index_params_, dataset_view); index_ = std::make_shared< cuvs::neighbors::mg::index, T, IdxT>>(std::move(idx)); } @@ -105,7 +105,7 @@ void cuvs_mg_ivf_flat::set_search_param(const search_param_base& param) template void cuvs_mg_ivf_flat::save(const std::string& file) const { - cuvs::neighbors::mg::serialize(clique_, *index_, file); + cuvs::neighbors::ivf_flat::serialize(clique_, *index_, file); } template @@ -113,7 +113,7 @@ void cuvs_mg_ivf_flat::load(const std::string& file) { index_ = std::make_shared< cuvs::neighbors::mg::index, T, IdxT>>( - std::move(cuvs::neighbors::mg::deserialize_flat(clique_, file))); + std::move(cuvs::neighbors::ivf_flat::deserialize(clique_, file))); } template @@ -133,7 +133,7 @@ void cuvs_mg_ivf_flat::search( auto distances_view = raft::make_host_matrix_view( distances, IdxT(batch_size), IdxT(k)); - cuvs::neighbors::mg::search( + cuvs::neighbors::ivf_flat::search( clique_, *index_, search_params_, queries_view, neighbors_view, distances_view); } diff --git a/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_pq_wrapper.h b/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_pq_wrapper.h index c2ce61cd8..9e6e20c98 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_pq_wrapper.h +++ b/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_pq_wrapper.h @@ -84,7 +84,7 @@ void cuvs_mg_ivf_pq::build(const T* dataset, size_t nrow) { auto dataset_view = raft::make_host_matrix_view(dataset, IdxT(nrow), IdxT(dim_)); - auto idx = cuvs::neighbors::mg::build(clique_, index_params_, dataset_view); + auto idx = cuvs::neighbors::ivf_pq::build(clique_, index_params_, dataset_view); index_ = std::make_shared, T, IdxT>>( std::move(idx)); @@ -104,7 +104,7 @@ void cuvs_mg_ivf_pq::set_search_param(const search_param_base& param) template void cuvs_mg_ivf_pq::save(const std::string& file) const { - cuvs::neighbors::mg::serialize(clique_, *index_, file); + cuvs::neighbors::ivf_pq::serialize(clique_, *index_, file); } template @@ -112,7 +112,7 @@ void cuvs_mg_ivf_pq::load(const std::string& file) { index_ = std::make_shared, T, IdxT>>( - std::move(cuvs::neighbors::mg::deserialize_pq(clique_, file))); + std::move(cuvs::neighbors::ivf_pq::deserialize(clique_, file))); } template @@ -132,7 +132,7 @@ void cuvs_mg_ivf_pq::search( auto distances_view = raft::make_host_matrix_view( distances, IdxT(batch_size), IdxT(k)); - cuvs::neighbors::mg::search( + cuvs::neighbors::ivf_pq::search( clique_, *index_, search_params_, queries_view, neighbors_view, distances_view); } diff --git a/cpp/include/cuvs/neighbors/cagra.hpp b/cpp/include/cuvs/neighbors/cagra.hpp index a4684ce26..f62fe7aea 100644 --- a/cpp/include/cuvs/neighbors/cagra.hpp +++ b/cpp/include/cuvs/neighbors/cagra.hpp @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -1752,4 +1753,465 @@ void serialize_to_hnswlib(raft::resources const& handle, * @} */ +/// \defgroup mg_cpp_index_build ANN MG index build + +/// \ingroup mg_cpp_index_build +/** + * @brief Builds a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index_params configure the index building + * @param[in] index_dataset a row-major matrix on host [n_rows, dim] + * + * @return the constructed CAGRA MG index + */ +auto build(const raft::device_resources_snmg& clique, + const cuvs::neighbors::mg::index_params& index_params, + raft::host_matrix_view index_dataset) + -> cuvs::neighbors::mg::index, float, uint32_t>; + +/// \ingroup mg_cpp_index_build +/** + * @brief Builds a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index_params configure the index building + * @param[in] index_dataset a row-major matrix on host [n_rows, dim] + * + * @return the constructed CAGRA MG index + */ +auto build(const raft::device_resources_snmg& clique, + const cuvs::neighbors::mg::index_params& index_params, + raft::host_matrix_view index_dataset) + -> cuvs::neighbors::mg::index, half, uint32_t>; + +/// \ingroup mg_cpp_index_build +/** + * @brief Builds a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index_params configure the index building + * @param[in] index_dataset a row-major matrix on host [n_rows, dim] + * + * @return the constructed CAGRA MG index + */ +auto build(const raft::device_resources_snmg& clique, + const cuvs::neighbors::mg::index_params& index_params, + raft::host_matrix_view index_dataset) + -> cuvs::neighbors::mg::index, int8_t, uint32_t>; + +/// \ingroup mg_cpp_index_build +/** + * @brief Builds a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index_params configure the index building + * @param[in] index_dataset a row-major matrix on host [n_rows, dim] + * + * @return the constructed CAGRA MG index + */ +auto build(const raft::device_resources_snmg& clique, + const cuvs::neighbors::mg::index_params& index_params, + raft::host_matrix_view index_dataset) + -> cuvs::neighbors::mg::index, uint8_t, uint32_t>; + +/// \defgroup mg_cpp_index_extend ANN MG index extend + +/// \ingroup mg_cpp_index_extend +/** + * @brief Extends a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); + * cuvs::neighbors::cagra::extend(clique, index, new_vectors, std::nullopt); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] new_vectors a row-major matrix on host [n_rows, dim] + * @param[in] new_indices optional vector on host [n_rows], + * `std::nullopt` means default continuous range `[0...n_rows)` + * + */ +void extend(const raft::device_resources_snmg& clique, + cuvs::neighbors::mg::index, float, uint32_t>& index, + raft::host_matrix_view new_vectors, + std::optional> new_indices); + +/// \ingroup mg_cpp_index_extend +/** + * @brief Extends a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); + * cuvs::neighbors::cagra::extend(clique, index, new_vectors, std::nullopt); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] new_vectors a row-major matrix on host [n_rows, dim] + * @param[in] new_indices optional vector on host [n_rows], + * `std::nullopt` means default continuous range `[0...n_rows)` + * + */ +void extend(const raft::device_resources_snmg& clique, + cuvs::neighbors::mg::index, half, uint32_t>& index, + raft::host_matrix_view new_vectors, + std::optional> new_indices); + +/// \ingroup mg_cpp_index_extend +/** + * @brief Extends a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); + * cuvs::neighbors::cagra::extend(clique, index, new_vectors, std::nullopt); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] new_vectors a row-major matrix on host [n_rows, dim] + * @param[in] new_indices optional vector on host [n_rows], + * `std::nullopt` means default continuous range `[0...n_rows)` + * + */ +void extend(const raft::device_resources_snmg& clique, + cuvs::neighbors::mg::index, int8_t, uint32_t>& index, + raft::host_matrix_view new_vectors, + std::optional> new_indices); + +/// \ingroup mg_cpp_index_extend +/** + * @brief Extends a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); + * cuvs::neighbors::cagra::extend(clique, index, new_vectors, std::nullopt); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] new_vectors a row-major matrix on host [n_rows, dim] + * @param[in] new_indices optional vector on host [n_rows], + * `std::nullopt` means default continuous range `[0...n_rows)` + * + */ +void extend(const raft::device_resources_snmg& clique, + cuvs::neighbors::mg::index, uint8_t, uint32_t>& index, + raft::host_matrix_view new_vectors, + std::optional> new_indices); + +/// \defgroup mg_cpp_index_search ANN MG index search + +/// \ingroup mg_cpp_index_search +/** + * @brief Searches a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); + * cuvs::neighbors::mg::search_params search_params; + * cuvs::neighbors::cagra::search(clique, index, search_params, queries, neighbors, + * distances); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] search_params configure the index search + * @param[in] queries a row-major matrix on host [n_rows, dim] + * @param[out] neighbors a row-major matrix on host [n_rows, n_neighbors] + * @param[out] distances a row-major matrix on host [n_rows, n_neighbors] + * @param[in] n_rows_per_batch (optional) search batch size + * + */ +void search(const raft::device_resources_snmg& clique, + const cuvs::neighbors::mg::index, float, uint32_t>& index, + const cuvs::neighbors::mg::search_params& search_params, + raft::host_matrix_view queries, + raft::host_matrix_view neighbors, + raft::host_matrix_view distances, + int64_t n_rows_per_batch = DEFAULT_SEARCH_BATCH_SIZE); + +/// \ingroup mg_cpp_index_search +/** + * @brief Searches a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); + * cuvs::neighbors::mg::search_params search_params; + * cuvs::neighbors::cagra::search(clique, index, search_params, queries, neighbors, + * distances); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] search_params configure the index search + * @param[in] queries a row-major matrix on host [n_rows, dim] + * @param[out] neighbors a row-major matrix on host [n_rows, n_neighbors] + * @param[out] distances a row-major matrix on host [n_rows, n_neighbors] + * @param[in] n_rows_per_batch (optional) search batch size + * + */ +void search(const raft::device_resources_snmg& clique, + const cuvs::neighbors::mg::index, half, uint32_t>& index, + const cuvs::neighbors::mg::search_params& search_params, + raft::host_matrix_view queries, + raft::host_matrix_view neighbors, + raft::host_matrix_view distances, + int64_t n_rows_per_batch = DEFAULT_SEARCH_BATCH_SIZE); + +/// \ingroup mg_cpp_index_search +/** + * @brief Searches a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); + * cuvs::neighbors::mg::search_params search_params; + * cuvs::neighbors::cagra::search(clique, index, search_params, queries, neighbors, + * distances); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] search_params configure the index search + * @param[in] queries a row-major matrix on host [n_rows, dim] + * @param[out] neighbors a row-major matrix on host [n_rows, n_neighbors] + * @param[out] distances a row-major matrix on host [n_rows, n_neighbors] + * @param[in] n_rows_per_batch (optional) search batch size + * + */ +void search( + const raft::device_resources_snmg& clique, + const cuvs::neighbors::mg::index, int8_t, uint32_t>& index, + const cuvs::neighbors::mg::search_params& search_params, + raft::host_matrix_view queries, + raft::host_matrix_view neighbors, + raft::host_matrix_view distances, + int64_t n_rows_per_batch = DEFAULT_SEARCH_BATCH_SIZE); + +/// \ingroup mg_cpp_index_search +/** + * @brief Searches a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); + * cuvs::neighbors::mg::search_params search_params; + * cuvs::neighbors::cagra::search(clique, index, search_params, queries, neighbors, + * distances); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] search_params configure the index search + * @param[in] queries a row-major matrix on host [n_rows, dim] + * @param[out] neighbors a row-major matrix on host [n_rows, n_neighbors] + * @param[out] distances a row-major matrix on host [n_rows, n_neighbors] + * @param[in] n_rows_per_batch (optional) search batch size + * + */ +void search( + const raft::device_resources_snmg& clique, + const cuvs::neighbors::mg::index, uint8_t, uint32_t>& index, + const cuvs::neighbors::mg::search_params& search_params, + raft::host_matrix_view queries, + raft::host_matrix_view neighbors, + raft::host_matrix_view distances, + int64_t n_rows_per_batch = DEFAULT_SEARCH_BATCH_SIZE); + +/// \defgroup mg_cpp_serialize ANN MG index serialization + +/// \ingroup mg_cpp_serialize +/** + * @brief Serializes a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); + * const std::string filename = "mg_index.cuvs"; + * cuvs::neighbors::cagra::serialize(clique, index, filename); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] filename path to the file to be serialized + * + */ +void serialize( + const raft::device_resources_snmg& clique, + const cuvs::neighbors::mg::index, float, uint32_t>& index, + const std::string& filename); + +/// \ingroup mg_cpp_serialize +/** + * @brief Serializes a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); + * const std::string filename = "mg_index.cuvs"; + * cuvs::neighbors::cagra::serialize(clique, index, filename); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] filename path to the file to be serialized + * + */ +void serialize( + const raft::device_resources_snmg& clique, + const cuvs::neighbors::mg::index, half, uint32_t>& index, + const std::string& filename); + +/// \ingroup mg_cpp_serialize +/** + * @brief Serializes a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); + * const std::string filename = "mg_index.cuvs"; + * cuvs::neighbors::cagra::serialize(clique, index, filename); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] filename path to the file to be serialized + * + */ +void serialize( + const raft::device_resources_snmg& clique, + const cuvs::neighbors::mg::index, int8_t, uint32_t>& index, + const std::string& filename); + +/// \ingroup mg_cpp_serialize +/** + * @brief Serializes a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); + * const std::string filename = "mg_index.cuvs"; + * cuvs::neighbors::cagra::serialize(clique, index, filename); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] filename path to the file to be serialized + * + */ +void serialize( + const raft::device_resources_snmg& clique, + const cuvs::neighbors::mg::index, uint8_t, uint32_t>& index, + const std::string& filename); + +/// \defgroup mg_cpp_deserialize ANN MG index deserialization + +/// \ingroup mg_cpp_deserialize +/** + * @brief Deserializes a CAGRA multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); + * const std::string filename = "mg_index.cuvs"; + * cuvs::neighbors::cagra::serialize(clique, index, filename); + * auto new_index = cuvs::neighbors::cagra::deserialize(clique, filename); + * + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] filename path to the file to be deserialized + * + */ +template +auto deserialize(const raft::device_resources_snmg& clique, const std::string& filename) + -> cuvs::neighbors::mg::index, T, IdxT>; + +/// \defgroup mg_cpp_distribute ANN MG local index distribution + +/// \ingroup mg_cpp_distribute +/** + * @brief Replicates a locally built and serialized CAGRA index to all GPUs to form a distributed + * multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::cagra::index_params index_params; + * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); + * const std::string filename = "local_index.cuvs"; + * cuvs::neighbors::cagra::serialize(clique, filename, index); + * auto new_index = cuvs::neighbors::cagra::distribute(clique, filename); + * + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] filename path to the file to be deserialized : a local index + * + */ +template +auto distribute(const raft::device_resources_snmg& clique, const std::string& filename) + -> cuvs::neighbors::mg::index, T, IdxT>; + } // namespace cuvs::neighbors::cagra diff --git a/cpp/include/cuvs/neighbors/ivf_flat.hpp b/cpp/include/cuvs/neighbors/ivf_flat.hpp index e017946d9..a9cb02de6 100644 --- a/cpp/include/cuvs/neighbors/ivf_flat.hpp +++ b/cpp/include/cuvs/neighbors/ivf_flat.hpp @@ -19,6 +19,7 @@ #include "common.hpp" #include #include +#include #include #include @@ -1598,6 +1599,366 @@ void deserialize(raft::resources const& handle, * @} */ +/// \defgroup mg_cpp_index_build ANN MG index build + +/// \ingroup mg_cpp_index_build +/** + * @brief Builds a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index_params configure the index building + * @param[in] index_dataset a row-major matrix on host [n_rows, dim] + * + * @return the constructed IVF-Flat MG index + */ +auto build(const raft::device_resources_snmg& clique, + const cuvs::neighbors::mg::index_params& index_params, + raft::host_matrix_view index_dataset) + -> cuvs::neighbors::mg::index, float, int64_t>; + +/// \ingroup mg_cpp_index_build +/** + * @brief Builds a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index_params configure the index building + * @param[in] index_dataset a row-major matrix on host [n_rows, dim] + * + * @return the constructed IVF-Flat MG index + */ +auto build(const raft::device_resources_snmg& clique, + const cuvs::neighbors::mg::index_params& index_params, + raft::host_matrix_view index_dataset) + -> cuvs::neighbors::mg::index, int8_t, int64_t>; + +/// \ingroup mg_cpp_index_build +/** + * @brief Builds a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index_params configure the index building + * @param[in] index_dataset a row-major matrix on host [n_rows, dim] + * + * @return the constructed IVF-Flat MG index + */ +auto build(const raft::device_resources_snmg& clique, + const cuvs::neighbors::mg::index_params& index_params, + raft::host_matrix_view index_dataset) + -> cuvs::neighbors::mg::index, uint8_t, int64_t>; + +/// \defgroup mg_cpp_index_extend ANN MG index extend + +/// \ingroup mg_cpp_index_extend +/** + * @brief Extends a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); + * cuvs::neighbors::ivf_flat::extend(clique, index, new_vectors, std::nullopt); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] new_vectors a row-major matrix on host [n_rows, dim] + * @param[in] new_indices optional vector on host [n_rows], + * `std::nullopt` means default continuous range `[0...n_rows)` + * + */ +void extend(const raft::device_resources_snmg& clique, + cuvs::neighbors::mg::index, float, int64_t>& index, + raft::host_matrix_view new_vectors, + std::optional> new_indices); + +/// \ingroup mg_cpp_index_extend +/** + * @brief Extends a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); + * cuvs::neighbors::ivf_flat::extend(clique, index, new_vectors, std::nullopt); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] new_vectors a row-major matrix on host [n_rows, dim] + * @param[in] new_indices optional vector on host [n_rows], + * `std::nullopt` means default continuous range `[0...n_rows)` + * + */ +void extend(const raft::device_resources_snmg& clique, + cuvs::neighbors::mg::index, int8_t, int64_t>& index, + raft::host_matrix_view new_vectors, + std::optional> new_indices); + +/// \ingroup mg_cpp_index_extend +/** + * @brief Extends a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); + * cuvs::neighbors::ivf_flat::extend(clique, index, new_vectors, std::nullopt); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] new_vectors a row-major matrix on host [n_rows, dim] + * @param[in] new_indices optional vector on host [n_rows], + * `std::nullopt` means default continuous range `[0...n_rows)` + * + */ +void extend(const raft::device_resources_snmg& clique, + cuvs::neighbors::mg::index, uint8_t, int64_t>& index, + raft::host_matrix_view new_vectors, + std::optional> new_indices); + +/// \defgroup mg_cpp_index_search ANN MG index search + +/// \ingroup mg_cpp_index_search +/** + * @brief Searches a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); + * cuvs::neighbors::mg::search_params search_params; + * cuvs::neighbors::ivf_flat::search(clique, index, search_params, queries, neighbors, + * distances); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] search_params configure the index search + * @param[in] queries a row-major matrix on host [n_rows, dim] + * @param[out] neighbors a row-major matrix on host [n_rows, n_neighbors] + * @param[out] distances a row-major matrix on host [n_rows, n_neighbors] + * @param[in] n_rows_per_batch (optional) search batch size + * + */ +void search( + const raft::device_resources_snmg& clique, + const cuvs::neighbors::mg::index, float, int64_t>& index, + const cuvs::neighbors::mg::search_params& search_params, + raft::host_matrix_view queries, + raft::host_matrix_view neighbors, + raft::host_matrix_view distances, + int64_t n_rows_per_batch = DEFAULT_SEARCH_BATCH_SIZE); + +/// \ingroup mg_cpp_index_search +/** + * @brief Searches a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); + * cuvs::neighbors::mg::search_params search_params; + * cuvs::neighbors::ivf_flat::search(clique, index, search_params, queries, neighbors, + * distances); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] search_params configure the index search + * @param[in] queries a row-major matrix on host [n_rows, dim] + * @param[out] neighbors a row-major matrix on host [n_rows, n_neighbors] + * @param[out] distances a row-major matrix on host [n_rows, n_neighbors] + * @param[in] n_rows_per_batch (optional) search batch size + * + */ +void search( + const raft::device_resources_snmg& clique, + const cuvs::neighbors::mg::index, int8_t, int64_t>& index, + const cuvs::neighbors::mg::search_params& search_params, + raft::host_matrix_view queries, + raft::host_matrix_view neighbors, + raft::host_matrix_view distances, + int64_t n_rows_per_batch = DEFAULT_SEARCH_BATCH_SIZE); + +/// \ingroup mg_cpp_index_search +/** + * @brief Searches a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); + * cuvs::neighbors::mg::search_params search_params; + * cuvs::neighbors::ivf_flat::search(clique, index, search_params, queries, neighbors, + * distances); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] search_params configure the index search + * @param[in] queries a row-major matrix on host [n_rows, dim] + * @param[out] neighbors a row-major matrix on host [n_rows, n_neighbors] + * @param[out] distances a row-major matrix on host [n_rows, n_neighbors] + * @param[in] n_rows_per_batch (optional) search batch size + * + */ +void search( + const raft::device_resources_snmg& clique, + const cuvs::neighbors::mg::index, uint8_t, int64_t>& index, + const cuvs::neighbors::mg::search_params& search_params, + raft::host_matrix_view queries, + raft::host_matrix_view neighbors, + raft::host_matrix_view distances, + int64_t n_rows_per_batch = DEFAULT_SEARCH_BATCH_SIZE); + +/// \defgroup mg_cpp_serialize ANN MG index serialization + +/// \ingroup mg_cpp_serialize +/** + * @brief Serializes a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); + * const std::string filename = "mg_index.cuvs"; + * cuvs::neighbors::ivf_flat::serialize(clique, index, filename); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] filename path to the file to be serialized + * + */ +void serialize( + const raft::device_resources_snmg& clique, + const cuvs::neighbors::mg::index, float, int64_t>& index, + const std::string& filename); + +/// \ingroup mg_cpp_serialize +/** + * @brief Serializes a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); + * const std::string filename = "mg_index.cuvs"; + * cuvs::neighbors::ivf_flat::serialize(clique, index, filename); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] filename path to the file to be serialized + * + */ +void serialize( + const raft::device_resources_snmg& clique, + const cuvs::neighbors::mg::index, int8_t, int64_t>& index, + const std::string& filename); + +/// \ingroup mg_cpp_serialize +/** + * @brief Serializes a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); + * const std::string filename = "mg_index.cuvs"; + * cuvs::neighbors::ivf_flat::serialize(clique, index, filename); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] filename path to the file to be serialized + * + */ +void serialize( + const raft::device_resources_snmg& clique, + const cuvs::neighbors::mg::index, uint8_t, int64_t>& index, + const std::string& filename); + +/// \ingroup mg_cpp_deserialize +/** + * @brief Deserializes an IVF-Flat multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); + * const std::string filename = "mg_index.cuvs"; + * cuvs::neighbors::ivf_flat::serialize(clique, index, filename); + * auto new_index = cuvs::neighbors::ivf_flat::deserialize(clique, filename); + * + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] filename path to the file to be deserialized + * + */ +template +auto deserialize(const raft::device_resources_snmg& clique, const std::string& filename) + -> cuvs::neighbors::mg::index, T, IdxT>; + +/// \defgroup mg_cpp_distribute ANN MG local index distribution + +/// \ingroup mg_cpp_distribute +/** + * @brief Replicates a locally built and serialized IVF-Flat index to all GPUs to form a distributed + * multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::ivf_flat::index_params index_params; + * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); + * const std::string filename = "local_index.cuvs"; + * cuvs::neighbors::ivf_flat::serialize(clique, filename, index); + * auto new_index = cuvs::neighbors::ivf_flat::distribute(clique, filename); + * + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] filename path to the file to be deserialized : a local index + * + */ +template +auto distribute(const raft::device_resources_snmg& clique, const std::string& filename) + -> cuvs::neighbors::mg::index, T, IdxT>; + namespace helpers { /** diff --git a/cpp/include/cuvs/neighbors/ivf_pq.hpp b/cpp/include/cuvs/neighbors/ivf_pq.hpp index d85753b7f..1be9a1924 100644 --- a/cpp/include/cuvs/neighbors/ivf_pq.hpp +++ b/cpp/include/cuvs/neighbors/ivf_pq.hpp @@ -19,6 +19,7 @@ #include #include +#include #include #include @@ -1728,6 +1729,456 @@ void deserialize(raft::resources const& handle, * @} */ +/// \defgroup mg_cpp_index_build ANN MG index build + +/// \ingroup mg_cpp_index_build +/** + * @brief Builds a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index_params configure the index building + * @param[in] index_dataset a row-major matrix on host [n_rows, dim] + * + * @return the constructed IVF-PQ MG index + */ +auto build(const raft::device_resources_snmg& clique, + const cuvs::neighbors::mg::index_params& index_params, + raft::host_matrix_view index_dataset) + -> cuvs::neighbors::mg::index, float, int64_t>; + +/// \ingroup mg_cpp_index_build +/** + * @brief Builds a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index_params configure the index building + * @param[in] index_dataset a row-major matrix on host [n_rows, dim] + * + * @return the constructed IVF-PQ MG index + */ +auto build(const raft::device_resources_snmg& clique, + const cuvs::neighbors::mg::index_params& index_params, + raft::host_matrix_view index_dataset) + -> cuvs::neighbors::mg::index, half, int64_t>; + +/// \ingroup mg_cpp_index_build +/** + * @brief Builds a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index_params configure the index building + * @param[in] index_dataset a row-major matrix on host [n_rows, dim] + * + * @return the constructed IVF-PQ MG index + */ +auto build(const raft::device_resources_snmg& clique, + const cuvs::neighbors::mg::index_params& index_params, + raft::host_matrix_view index_dataset) + -> cuvs::neighbors::mg::index, int8_t, int64_t>; + +/// \ingroup mg_cpp_index_build +/** + * @brief Builds a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index_params configure the index building + * @param[in] index_dataset a row-major matrix on host [n_rows, dim] + * + * @return the constructed IVF-PQ MG index + */ +auto build(const raft::device_resources_snmg& clique, + const cuvs::neighbors::mg::index_params& index_params, + raft::host_matrix_view index_dataset) + -> cuvs::neighbors::mg::index, uint8_t, int64_t>; + +/// \defgroup mg_cpp_index_extend ANN MG index extend + +/// \ingroup mg_cpp_index_extend +/** + * @brief Extends a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); + * cuvs::neighbors::ivf_pq::extend(clique, index, new_vectors, std::nullopt); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] new_vectors a row-major matrix on host [n_rows, dim] + * @param[in] new_indices optional vector on host [n_rows], + * `std::nullopt` means default continuous range `[0...n_rows)` + * + */ +void extend(const raft::device_resources_snmg& clique, + cuvs::neighbors::mg::index, float, int64_t>& index, + raft::host_matrix_view new_vectors, + std::optional> new_indices); + +/// \ingroup mg_cpp_index_extend +/** + * @brief Extends a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); + * cuvs::neighbors::ivf_pq::extend(clique, index, new_vectors, std::nullopt); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] new_vectors a row-major matrix on host [n_rows, dim] + * @param[in] new_indices optional vector on host [n_rows], + * `std::nullopt` means default continuous range `[0...n_rows)` + * + */ +void extend(const raft::device_resources_snmg& clique, + cuvs::neighbors::mg::index, half, int64_t>& index, + raft::host_matrix_view new_vectors, + std::optional> new_indices); + +/// \ingroup mg_cpp_index_extend +/** + * @brief Extends a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); + * cuvs::neighbors::ivf_pq::extend(clique, index, new_vectors, std::nullopt); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] new_vectors a row-major matrix on host [n_rows, dim] + * @param[in] new_indices optional vector on host [n_rows], + * `std::nullopt` means default continuous range `[0...n_rows)` + * + */ +void extend(const raft::device_resources_snmg& clique, + cuvs::neighbors::mg::index, int8_t, int64_t>& index, + raft::host_matrix_view new_vectors, + std::optional> new_indices); + +/// \ingroup mg_cpp_index_extend +/** + * @brief Extends a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); + * cuvs::neighbors::ivf_pq::extend(clique, index, new_vectors, std::nullopt); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] new_vectors a row-major matrix on host [n_rows, dim] + * @param[in] new_indices optional vector on host [n_rows], + * `std::nullopt` means default continuous range `[0...n_rows)` + * + */ +void extend(const raft::device_resources_snmg& clique, + cuvs::neighbors::mg::index, uint8_t, int64_t>& index, + raft::host_matrix_view new_vectors, + std::optional> new_indices); + +/// \defgroup mg_cpp_index_search ANN MG index search + +/// \ingroup mg_cpp_index_search +/** + * @brief Searches a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); + * cuvs::neighbors::mg::search_params search_params; + * cuvs::neighbors::ivf_pq::search(clique, index, search_params, queries, neighbors, + * distances); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] search_params configure the index search + * @param[in] queries a row-major matrix on host [n_rows, dim] + * @param[out] neighbors a row-major matrix on host [n_rows, n_neighbors] + * @param[out] distances a row-major matrix on host [n_rows, n_neighbors] + * @param[in] n_rows_per_batch (optional) search batch size + * + */ +void search(const raft::device_resources_snmg& clique, + const cuvs::neighbors::mg::index, float, int64_t>& index, + const cuvs::neighbors::mg::search_params& search_params, + raft::host_matrix_view queries, + raft::host_matrix_view neighbors, + raft::host_matrix_view distances, + int64_t n_rows_per_batch = DEFAULT_SEARCH_BATCH_SIZE); + +/// \ingroup mg_cpp_index_search +/** + * @brief Searches a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); + * cuvs::neighbors::mg::search_params search_params; + * cuvs::neighbors::ivf_pq::search(clique, index, search_params, queries, neighbors, + * distances); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] search_params configure the index search + * @param[in] queries a row-major matrix on host [n_rows, dim] + * @param[out] neighbors a row-major matrix on host [n_rows, n_neighbors] + * @param[out] distances a row-major matrix on host [n_rows, n_neighbors] + * @param[in] n_rows_per_batch (optional) search batch size + * + */ +void search(const raft::device_resources_snmg& clique, + const cuvs::neighbors::mg::index, half, int64_t>& index, + const cuvs::neighbors::mg::search_params& search_params, + raft::host_matrix_view queries, + raft::host_matrix_view neighbors, + raft::host_matrix_view distances, + int64_t n_rows_per_batch = DEFAULT_SEARCH_BATCH_SIZE); + +/// \ingroup mg_cpp_index_search +/** + * @brief Searches a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); + * cuvs::neighbors::mg::search_params search_params; + * cuvs::neighbors::ivf_pq::search(clique, index, search_params, queries, neighbors, + * distances); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] search_params configure the index search + * @param[in] queries a row-major matrix on host [n_rows, dim] + * @param[out] neighbors a row-major matrix on host [n_rows, n_neighbors] + * @param[out] distances a row-major matrix on host [n_rows, n_neighbors] + * @param[in] n_rows_per_batch (optional) search batch size + * + */ +void search(const raft::device_resources_snmg& clique, + const cuvs::neighbors::mg::index, int8_t, int64_t>& index, + const cuvs::neighbors::mg::search_params& search_params, + raft::host_matrix_view queries, + raft::host_matrix_view neighbors, + raft::host_matrix_view distances, + int64_t n_rows_per_batch = DEFAULT_SEARCH_BATCH_SIZE); + +/// \ingroup mg_cpp_index_search +/** + * @brief Searches a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); + * cuvs::neighbors::mg::search_params search_params; + * cuvs::neighbors::ivf_pq::search(clique, index, search_params, queries, neighbors, + * distances); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] search_params configure the index search + * @param[in] queries a row-major matrix on host [n_rows, dim] + * @param[out] neighbors a row-major matrix on host [n_rows, n_neighbors] + * @param[out] distances a row-major matrix on host [n_rows, n_neighbors] + * @param[in] n_rows_per_batch (optional) search batch size + * + */ +void search(const raft::device_resources_snmg& clique, + const cuvs::neighbors::mg::index, uint8_t, int64_t>& index, + const cuvs::neighbors::mg::search_params& search_params, + raft::host_matrix_view queries, + raft::host_matrix_view neighbors, + raft::host_matrix_view distances, + int64_t n_rows_per_batch = DEFAULT_SEARCH_BATCH_SIZE); + +/// \defgroup mg_cpp_serialize ANN MG index serialization + +/// \ingroup mg_cpp_serialize +/** + * @brief Serializes a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); + * const std::string filename = "mg_index.cuvs"; + * cuvs::neighbors::ivf_pq::serialize(clique, index, filename); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] filename path to the file to be serialized + * + */ +void serialize(const raft::device_resources_snmg& clique, + const cuvs::neighbors::mg::index, float, int64_t>& index, + const std::string& filename); + +/// \ingroup mg_cpp_serialize +/** + * @brief Serializes a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); + * const std::string filename = "mg_index.cuvs"; + * cuvs::neighbors::ivf_pq::serialize(clique, index, filename); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] filename path to the file to be serialized + * + */ +void serialize(const raft::device_resources_snmg& clique, + const cuvs::neighbors::mg::index, half, int64_t>& index, + const std::string& filename); + +/// \ingroup mg_cpp_serialize +/** + * @brief Serializes a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); + * const std::string filename = "mg_index.cuvs"; + * cuvs::neighbors::ivf_pq::serialize(clique, index, filename); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] filename path to the file to be serialized + * + */ +void serialize(const raft::device_resources_snmg& clique, + const cuvs::neighbors::mg::index, int8_t, int64_t>& index, + const std::string& filename); + +/// \ingroup mg_cpp_serialize +/** + * @brief Serializes a multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); + * const std::string filename = "mg_index.cuvs"; + * cuvs::neighbors::ivf_pq::serialize(clique, index, filename); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] index the pre-built index + * @param[in] filename path to the file to be serialized + * + */ +void serialize(const raft::device_resources_snmg& clique, + const cuvs::neighbors::mg::index, uint8_t, int64_t>& index, + const std::string& filename); + +/// \ingroup mg_cpp_deserialize +/** + * @brief Deserializes an IVF-PQ multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::mg::index_params index_params; + * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); + * const std::string filename = "mg_index.cuvs"; + * cuvs::neighbors::ivf_pq::serialize(clique, index, filename); + * auto new_index = cuvs::neighbors::ivf_pq::deserialize(clique, filename); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] filename path to the file to be deserialized + * + */ +template +auto deserialize(const raft::device_resources_snmg& clique, const std::string& filename) + -> cuvs::neighbors::mg::index, T, IdxT>; + +/// \defgroup mg_cpp_distribute ANN MG local index distribution + +/// \ingroup mg_cpp_distribute +/** + * @brief Replicates a locally built and serialized IVF-PQ index to all GPUs to form a distributed + * multi-GPU index + * + * Usage example: + * @code{.cpp} + * raft::device_resources_snmg clique; + * cuvs::neighbors::ivf_pq::index_params index_params; + * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); + * const std::string filename = "local_index.cuvs"; + * cuvs::neighbors::ivf_pq::serialize(clique, filename, index); + * auto new_index = cuvs::neighbors::ivf_pq::distribute(clique, filename); + * @endcode + * + * @param[in] clique a `raft::device_resources_snmg` object specifying the NCCL clique configuration + * @param[in] filename path to the file to be deserialized : a local index + * + */ +template +auto distribute(const raft::device_resources_snmg& clique, const std::string& filename) + -> cuvs::neighbors::mg::index, T, IdxT>; namespace helpers { /** * @defgroup ivf_pq_cpp_helpers IVF-PQ helper methods diff --git a/cpp/include/cuvs/neighbors/mg.hpp b/cpp/include/cuvs/neighbors/mg.hpp index 86572adeb..d2be22951 100644 --- a/cpp/include/cuvs/neighbors/mg.hpp +++ b/cpp/include/cuvs/neighbors/mg.hpp @@ -18,16 +18,8 @@ #ifdef CUVS_BUILD_MG_ALGOS -#include -#include - -#include -#include - -#include #include -#include -#include +#include #define DEFAULT_SEARCH_BATCH_SIZE 1 << 20 @@ -92,12 +84,6 @@ struct search_params : public Upstream { cuvs::neighbors::mg::sharded_merge_mode merge_mode = TREE_MERGE; }; -} // namespace cuvs::neighbors::mg - -namespace cuvs::neighbors::mg { - -using namespace raft; - template struct index { index(distribution_mode mode, int num_ranks_); @@ -116,1246 +102,6 @@ struct index { std::shared_ptr> round_robin_counter_; }; -/// \defgroup mg_cpp_index_build ANN MG index build - -/// \ingroup mg_cpp_index_build -/** - * @brief Builds a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * @endcode - * - * @param[in] clique - * @param[in] index_params configure the index building - * @param[in] index_dataset a row-major matrix on host [n_rows, dim] - * - * @return the constructed IVF-Flat MG index - */ -auto build(const raft::device_resources_snmg& clique, - const mg::index_params& index_params, - raft::host_matrix_view index_dataset) - -> index, float, int64_t>; - -/// \ingroup mg_cpp_index_build -/** - * @brief Builds a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * @endcode - * - * @param[in] clique - * @param[in] index_params configure the index building - * @param[in] index_dataset a row-major matrix on host [n_rows, dim] - * - * @return the constructed IVF-Flat MG index - */ -auto build(const raft::device_resources_snmg& clique, - const mg::index_params& index_params, - raft::host_matrix_view index_dataset) - -> index, int8_t, int64_t>; - -/// \ingroup mg_cpp_index_build -/** - * @brief Builds a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * @endcode - * - * @param[in] clique - * @param[in] index_params configure the index building - * @param[in] index_dataset a row-major matrix on host [n_rows, dim] - * - * @return the constructed IVF-Flat MG index - */ -auto build(const raft::device_resources_snmg& clique, - const mg::index_params& index_params, - raft::host_matrix_view index_dataset) - -> index, uint8_t, int64_t>; - -/// \ingroup mg_cpp_index_build -/** - * @brief Builds a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * @endcode - * - * @param[in] clique - * @param[in] index_params configure the index building - * @param[in] index_dataset a row-major matrix on host [n_rows, dim] - * - * @return the constructed IVF-PQ MG index - */ -auto build(const raft::device_resources_snmg& clique, - const mg::index_params& index_params, - raft::host_matrix_view index_dataset) - -> index, float, int64_t>; - -/// \ingroup mg_cpp_index_build -/** - * @brief Builds a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * @endcode - * - * @param[in] clique - * @param[in] index_params configure the index building - * @param[in] index_dataset a row-major matrix on host [n_rows, dim] - * - * @return the constructed IVF-PQ MG index - */ -auto build(const raft::device_resources_snmg& clique, - const mg::index_params& index_params, - raft::host_matrix_view index_dataset) - -> index, half, int64_t>; - -/// \ingroup mg_cpp_index_build -/** - * @brief Builds a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * @endcode - * - * @param[in] clique - * @param[in] index_params configure the index building - * @param[in] index_dataset a row-major matrix on host [n_rows, dim] - * - * @return the constructed IVF-PQ MG index - */ -auto build(const raft::device_resources_snmg& clique, - const mg::index_params& index_params, - raft::host_matrix_view index_dataset) - -> index, int8_t, int64_t>; - -/// \ingroup mg_cpp_index_build -/** - * @brief Builds a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * @endcode - * - * @param[in] clique - * @param[in] index_params configure the index building - * @param[in] index_dataset a row-major matrix on host [n_rows, dim] - * - * @return the constructed IVF-PQ MG index - */ -auto build(const raft::device_resources_snmg& clique, - const mg::index_params& index_params, - raft::host_matrix_view index_dataset) - -> index, uint8_t, int64_t>; - -/// \ingroup mg_cpp_index_build -/** - * @brief Builds a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * @endcode - * - * @param[in] clique - * @param[in] index_params configure the index building - * @param[in] index_dataset a row-major matrix on host [n_rows, dim] - * - * @return the constructed CAGRA MG index - */ -auto build(const raft::device_resources_snmg& clique, - const mg::index_params& index_params, - raft::host_matrix_view index_dataset) - -> index, float, uint32_t>; - -/// \ingroup mg_cpp_index_build -/** - * @brief Builds a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * @endcode - * - * @param[in] clique - * @param[in] index_params configure the index building - * @param[in] index_dataset a row-major matrix on host [n_rows, dim] - * - * @return the constructed CAGRA MG index - */ -auto build(const raft::device_resources_snmg& clique, - const mg::index_params& index_params, - raft::host_matrix_view index_dataset) - -> index, half, uint32_t>; - -/// \ingroup mg_cpp_index_build -/** - * @brief Builds a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * @endcode - * - * @param[in] clique - * @param[in] index_params configure the index building - * @param[in] index_dataset a row-major matrix on host [n_rows, dim] - * - * @return the constructed CAGRA MG index - */ -auto build(const raft::device_resources_snmg& clique, - const mg::index_params& index_params, - raft::host_matrix_view index_dataset) - -> index, int8_t, uint32_t>; - -/// \ingroup mg_cpp_index_build -/** - * @brief Builds a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * @endcode - * - * @param[in] clique - * @param[in] index_params configure the index building - * @param[in] index_dataset a row-major matrix on host [n_rows, dim] - * - * @return the constructed CAGRA MG index - */ -auto build(const raft::device_resources_snmg& clique, - const mg::index_params& index_params, - raft::host_matrix_view index_dataset) - -> index, uint8_t, uint32_t>; - -/// \defgroup mg_cpp_index_extend ANN MG index extend - -/// \ingroup mg_cpp_index_extend -/** - * @brief Extends a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * cuvs::neighbors::mg::extend(clique, index, new_vectors, std::nullopt); - * @endcode - * - * @param[in] clique - * @param[in] index the pre-built index - * @param[in] new_vectors a row-major matrix on host [n_rows, dim] - * @param[in] new_indices optional vector on host [n_rows], - * `std::nullopt` means default continuous range `[0...n_rows)` - * - */ -void extend(const raft::device_resources_snmg& clique, - index, float, int64_t>& index, - raft::host_matrix_view new_vectors, - std::optional> new_indices); - -/// \ingroup mg_cpp_index_extend -/** - * @brief Extends a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * cuvs::neighbors::mg::extend(clique, index, new_vectors, std::nullopt); - * @endcode - * - * @param[in] clique - * @param[in] index the pre-built index - * @param[in] new_vectors a row-major matrix on host [n_rows, dim] - * @param[in] new_indices optional vector on host [n_rows], - * `std::nullopt` means default continuous range `[0...n_rows)` - * - */ -void extend(const raft::device_resources_snmg& clique, - index, int8_t, int64_t>& index, - raft::host_matrix_view new_vectors, - std::optional> new_indices); - -/// \ingroup mg_cpp_index_extend -/** - * @brief Extends a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * cuvs::neighbors::mg::extend(clique, index, new_vectors, std::nullopt); - * @endcode - * - * @param[in] clique - * @param[in] index the pre-built index - * @param[in] new_vectors a row-major matrix on host [n_rows, dim] - * @param[in] new_indices optional vector on host [n_rows], - * `std::nullopt` means default continuous range `[0...n_rows)` - * - */ -void extend(const raft::device_resources_snmg& clique, - index, uint8_t, int64_t>& index, - raft::host_matrix_view new_vectors, - std::optional> new_indices); - -/// \ingroup mg_cpp_index_extend -/** - * @brief Extends a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * cuvs::neighbors::mg::extend(clique, index, new_vectors, std::nullopt); - * @endcode - * - * @param[in] clique - * @param[in] index the pre-built index - * @param[in] new_vectors a row-major matrix on host [n_rows, dim] - * @param[in] new_indices optional vector on host [n_rows], - * `std::nullopt` means default continuous range `[0...n_rows)` - * - */ -void extend(const raft::device_resources_snmg& clique, - index, float, int64_t>& index, - raft::host_matrix_view new_vectors, - std::optional> new_indices); - -/// \ingroup mg_cpp_index_extend -/** - * @brief Extends a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * cuvs::neighbors::mg::extend(clique, index, new_vectors, std::nullopt); - * @endcode - * - * @param[in] clique - * @param[in] index the pre-built index - * @param[in] new_vectors a row-major matrix on host [n_rows, dim] - * @param[in] new_indices optional vector on host [n_rows], - * `std::nullopt` means default continuous range `[0...n_rows)` - * - */ -void extend(const raft::device_resources_snmg& clique, - index, half, int64_t>& index, - raft::host_matrix_view new_vectors, - std::optional> new_indices); - -/// \ingroup mg_cpp_index_extend -/** - * @brief Extends a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * cuvs::neighbors::mg::extend(clique, index, new_vectors, std::nullopt); - * @endcode - * - * @param[in] clique - * @param[in] index the pre-built index - * @param[in] new_vectors a row-major matrix on host [n_rows, dim] - * @param[in] new_indices optional vector on host [n_rows], - * `std::nullopt` means default continuous range `[0...n_rows)` - * - */ -void extend(const raft::device_resources_snmg& clique, - index, int8_t, int64_t>& index, - raft::host_matrix_view new_vectors, - std::optional> new_indices); - -/// \ingroup mg_cpp_index_extend -/** - * @brief Extends a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * cuvs::neighbors::mg::extend(clique, index, new_vectors, std::nullopt); - * @endcode - * - * @param[in] clique - * @param[in] index the pre-built index - * @param[in] new_vectors a row-major matrix on host [n_rows, dim] - * @param[in] new_indices optional vector on host [n_rows], - * `std::nullopt` means default continuous range `[0...n_rows)` - * - */ -void extend(const raft::device_resources_snmg& clique, - index, uint8_t, int64_t>& index, - raft::host_matrix_view new_vectors, - std::optional> new_indices); - -/// \ingroup mg_cpp_index_extend -/** - * @brief Extends a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * cuvs::neighbors::mg::extend(clique, index, new_vectors, std::nullopt); - * @endcode - * - * @param[in] clique - * @param[in] index the pre-built index - * @param[in] new_vectors a row-major matrix on host [n_rows, dim] - * @param[in] new_indices optional vector on host [n_rows], - * `std::nullopt` means default continuous range `[0...n_rows)` - * - */ -void extend(const raft::device_resources_snmg& clique, - index, float, uint32_t>& index, - raft::host_matrix_view new_vectors, - std::optional> new_indices); - -/// \ingroup mg_cpp_index_extend -/** - * @brief Extends a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * cuvs::neighbors::mg::extend(clique, index, new_vectors, std::nullopt); - * @endcode - * - * @param[in] clique - * @param[in] index the pre-built index - * @param[in] new_vectors a row-major matrix on host [n_rows, dim] - * @param[in] new_indices optional vector on host [n_rows], - * `std::nullopt` means default continuous range `[0...n_rows)` - * - */ -void extend(const raft::device_resources_snmg& clique, - index, half, uint32_t>& index, - raft::host_matrix_view new_vectors, - std::optional> new_indices); - -/// \ingroup mg_cpp_index_extend -/** - * @brief Extends a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * cuvs::neighbors::mg::extend(clique, index, new_vectors, std::nullopt); - * @endcode - * - * @param[in] clique - * @param[in] index the pre-built index - * @param[in] new_vectors a row-major matrix on host [n_rows, dim] - * @param[in] new_indices optional vector on host [n_rows], - * `std::nullopt` means default continuous range `[0...n_rows)` - * - */ -void extend(const raft::device_resources_snmg& clique, - index, int8_t, uint32_t>& index, - raft::host_matrix_view new_vectors, - std::optional> new_indices); - -/// \ingroup mg_cpp_index_extend -/** - * @brief Extends a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * cuvs::neighbors::mg::extend(clique, index, new_vectors, std::nullopt); - * @endcode - * - * @param[in] clique - * @param[in] index the pre-built index - * @param[in] new_vectors a row-major matrix on host [n_rows, dim] - * @param[in] new_indices optional vector on host [n_rows], - * `std::nullopt` means default continuous range `[0...n_rows)` - * - */ -void extend(const raft::device_resources_snmg& clique, - index, uint8_t, uint32_t>& index, - raft::host_matrix_view new_vectors, - std::optional> new_indices); - -/// \defgroup mg_cpp_index_search ANN MG index search - -/// \ingroup mg_cpp_index_search -/** - * @brief Searches a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * cuvs::neighbors::mg::search_params search_params; - * cuvs::neighbors::mg::search(clique, index, search_params, queries, neighbors, - * distances); - * @endcode - * - * @param[in] clique - * @param[in] index the pre-built index - * @param[in] search_params configure the index search - * @param[in] queries a row-major matrix on host [n_rows, dim] - * @param[out] neighbors a row-major matrix on host [n_rows, n_neighbors] - * @param[out] distances a row-major matrix on host [n_rows, n_neighbors] - * @param[in] n_rows_per_batch (optional) search batch size - * - */ -void search(const raft::device_resources_snmg& clique, - const index, float, int64_t>& index, - const mg::search_params& search_params, - raft::host_matrix_view queries, - raft::host_matrix_view neighbors, - raft::host_matrix_view distances, - int64_t n_rows_per_batch = DEFAULT_SEARCH_BATCH_SIZE); - -/// \ingroup mg_cpp_index_search -/** - * @brief Searches a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * cuvs::neighbors::mg::search_params search_params; - * cuvs::neighbors::mg::search(clique, index, search_params, queries, neighbors, - * distances); - * @endcode - * - * @param[in] clique - * @param[in] index the pre-built index - * @param[in] search_params configure the index search - * @param[in] queries a row-major matrix on host [n_rows, dim] - * @param[out] neighbors a row-major matrix on host [n_rows, n_neighbors] - * @param[out] distances a row-major matrix on host [n_rows, n_neighbors] - * @param[in] n_rows_per_batch (optional) search batch size - * - */ -void search(const raft::device_resources_snmg& clique, - const index, int8_t, int64_t>& index, - const mg::search_params& search_params, - raft::host_matrix_view queries, - raft::host_matrix_view neighbors, - raft::host_matrix_view distances, - int64_t n_rows_per_batch = DEFAULT_SEARCH_BATCH_SIZE); - -/// \ingroup mg_cpp_index_search -/** - * @brief Searches a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * cuvs::neighbors::mg::search_params search_params; - * cuvs::neighbors::mg::search(clique, index, search_params, queries, neighbors, - * distances); - * @endcode - * - * @param[in] clique - * @param[in] index the pre-built index - * @param[in] search_params configure the index search - * @param[in] queries a row-major matrix on host [n_rows, dim] - * @param[out] neighbors a row-major matrix on host [n_rows, n_neighbors] - * @param[out] distances a row-major matrix on host [n_rows, n_neighbors] - * @param[in] n_rows_per_batch (optional) search batch size - * - */ -void search(const raft::device_resources_snmg& clique, - const index, uint8_t, int64_t>& index, - const mg::search_params& search_params, - raft::host_matrix_view queries, - raft::host_matrix_view neighbors, - raft::host_matrix_view distances, - int64_t n_rows_per_batch = DEFAULT_SEARCH_BATCH_SIZE); - -/// \ingroup mg_cpp_index_search -/** - * @brief Searches a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * cuvs::neighbors::mg::search_params search_params; - * cuvs::neighbors::mg::search(clique, index, search_params, queries, neighbors, - * distances); - * @endcode - * - * @param[in] clique - * @param[in] index the pre-built index - * @param[in] search_params configure the index search - * @param[in] queries a row-major matrix on host [n_rows, dim] - * @param[out] neighbors a row-major matrix on host [n_rows, n_neighbors] - * @param[out] distances a row-major matrix on host [n_rows, n_neighbors] - * @param[in] n_rows_per_batch (optional) search batch size - * - */ -void search(const raft::device_resources_snmg& clique, - const index, float, int64_t>& index, - const mg::search_params& search_params, - raft::host_matrix_view queries, - raft::host_matrix_view neighbors, - raft::host_matrix_view distances, - int64_t n_rows_per_batch = DEFAULT_SEARCH_BATCH_SIZE); - -/// \ingroup mg_cpp_index_search -/** - * @brief Searches a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * cuvs::neighbors::mg::search_params search_params; - * cuvs::neighbors::mg::search(clique, index, search_params, queries, neighbors, - * distances); - * @endcode - * - * @param[in] clique - * @param[in] index the pre-built index - * @param[in] search_params configure the index search - * @param[in] queries a row-major matrix on host [n_rows, dim] - * @param[out] neighbors a row-major matrix on host [n_rows, n_neighbors] - * @param[out] distances a row-major matrix on host [n_rows, n_neighbors] - * @param[in] n_rows_per_batch (optional) search batch size - * - */ -void search(const raft::device_resources_snmg& clique, - const index, half, int64_t>& index, - const mg::search_params& search_params, - raft::host_matrix_view queries, - raft::host_matrix_view neighbors, - raft::host_matrix_view distances, - int64_t n_rows_per_batch = DEFAULT_SEARCH_BATCH_SIZE); - -/// \ingroup mg_cpp_index_search -/** - * @brief Searches a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * cuvs::neighbors::mg::search_params search_params; - * cuvs::neighbors::mg::search(clique, index, search_params, queries, neighbors, - * distances); - * @endcode - * - * @param[in] clique - * @param[in] index the pre-built index - * @param[in] search_params configure the index search - * @param[in] queries a row-major matrix on host [n_rows, dim] - * @param[out] neighbors a row-major matrix on host [n_rows, n_neighbors] - * @param[out] distances a row-major matrix on host [n_rows, n_neighbors] - * @param[in] n_rows_per_batch (optional) search batch size - * - */ -void search(const raft::device_resources_snmg& clique, - const index, int8_t, int64_t>& index, - const mg::search_params& search_params, - raft::host_matrix_view queries, - raft::host_matrix_view neighbors, - raft::host_matrix_view distances, - int64_t n_rows_per_batch = DEFAULT_SEARCH_BATCH_SIZE); - -/// \ingroup mg_cpp_index_search -/** - * @brief Searches a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * cuvs::neighbors::mg::search_params search_params; - * cuvs::neighbors::mg::search(clique, index, search_params, queries, neighbors, - * distances); - * @endcode - * - * @param[in] clique - * @param[in] index the pre-built index - * @param[in] search_params configure the index search - * @param[in] queries a row-major matrix on host [n_rows, dim] - * @param[out] neighbors a row-major matrix on host [n_rows, n_neighbors] - * @param[out] distances a row-major matrix on host [n_rows, n_neighbors] - * @param[in] n_rows_per_batch (optional) search batch size - * - */ -void search(const raft::device_resources_snmg& clique, - const index, uint8_t, int64_t>& index, - const mg::search_params& search_params, - raft::host_matrix_view queries, - raft::host_matrix_view neighbors, - raft::host_matrix_view distances, - int64_t n_rows_per_batch = DEFAULT_SEARCH_BATCH_SIZE); - -/// \ingroup mg_cpp_index_search -/** - * @brief Searches a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * cuvs::neighbors::mg::search_params search_params; - * cuvs::neighbors::mg::search(clique, index, search_params, queries, neighbors, - * distances); - * @endcode - * - * @param[in] clique - * @param[in] index the pre-built index - * @param[in] search_params configure the index search - * @param[in] queries a row-major matrix on host [n_rows, dim] - * @param[out] neighbors a row-major matrix on host [n_rows, n_neighbors] - * @param[out] distances a row-major matrix on host [n_rows, n_neighbors] - * @param[in] n_rows_per_batch (optional) search batch size - * - */ -void search(const raft::device_resources_snmg& clique, - const index, float, uint32_t>& index, - const mg::search_params& search_params, - raft::host_matrix_view queries, - raft::host_matrix_view neighbors, - raft::host_matrix_view distances, - int64_t n_rows_per_batch = DEFAULT_SEARCH_BATCH_SIZE); - -/// \ingroup mg_cpp_index_search -/** - * @brief Searches a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * cuvs::neighbors::mg::search_params search_params; - * cuvs::neighbors::mg::search(clique, index, search_params, queries, neighbors, - * distances); - * @endcode - * - * @param[in] clique - * @param[in] index the pre-built index - * @param[in] search_params configure the index search - * @param[in] queries a row-major matrix on host [n_rows, dim] - * @param[out] neighbors a row-major matrix on host [n_rows, n_neighbors] - * @param[out] distances a row-major matrix on host [n_rows, n_neighbors] - * @param[in] n_rows_per_batch (optional) search batch size - * - */ -void search(const raft::device_resources_snmg& clique, - const index, half, uint32_t>& index, - const mg::search_params& search_params, - raft::host_matrix_view queries, - raft::host_matrix_view neighbors, - raft::host_matrix_view distances, - int64_t n_rows_per_batch = DEFAULT_SEARCH_BATCH_SIZE); - -/// \ingroup mg_cpp_index_search -/** - * @brief Searches a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * cuvs::neighbors::mg::search_params search_params; - * cuvs::neighbors::mg::search(clique, index, search_params, queries, neighbors, - * distances); - * @endcode - * - * @param[in] clique - * @param[in] index the pre-built index - * @param[in] search_params configure the index search - * @param[in] queries a row-major matrix on host [n_rows, dim] - * @param[out] neighbors a row-major matrix on host [n_rows, n_neighbors] - * @param[out] distances a row-major matrix on host [n_rows, n_neighbors] - * @param[in] n_rows_per_batch (optional) search batch size - * - */ -void search(const raft::device_resources_snmg& clique, - const index, int8_t, uint32_t>& index, - const mg::search_params& search_params, - raft::host_matrix_view queries, - raft::host_matrix_view neighbors, - raft::host_matrix_view distances, - int64_t n_rows_per_batch = DEFAULT_SEARCH_BATCH_SIZE); - -/// \ingroup mg_cpp_index_search -/** - * @brief Searches a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * cuvs::neighbors::mg::search_params search_params; - * cuvs::neighbors::mg::search(clique, index, search_params, queries, neighbors, - * distances); - * @endcode - * - * @param[in] clique - * @param[in] index the pre-built index - * @param[in] search_params configure the index search - * @param[in] queries a row-major matrix on host [n_rows, dim] - * @param[out] neighbors a row-major matrix on host [n_rows, n_neighbors] - * @param[out] distances a row-major matrix on host [n_rows, n_neighbors] - * @param[in] n_rows_per_batch (optional) search batch size - * - */ -void search(const raft::device_resources_snmg& clique, - const index, uint8_t, uint32_t>& index, - const mg::search_params& search_params, - raft::host_matrix_view queries, - raft::host_matrix_view neighbors, - raft::host_matrix_view distances, - int64_t n_rows_per_batch = DEFAULT_SEARCH_BATCH_SIZE); - -/// \defgroup mg_cpp_serialize ANN MG index serialization - -/// \ingroup mg_cpp_serialize -/** - * @brief Serializes a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(clique, index, filename); - * @endcode - * - * @param[in] clique - * @param[in] index the pre-built index - * @param[in] filename path to the file to be serialized - * - */ -void serialize(const raft::device_resources_snmg& clique, - const index, float, int64_t>& index, - const std::string& filename); - -/// \ingroup mg_cpp_serialize -/** - * @brief Serializes a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(clique, index, filename); - * @endcode - * - * @param[in] clique - * @param[in] index the pre-built index - * @param[in] filename path to the file to be serialized - * - */ -void serialize(const raft::device_resources_snmg& clique, - const index, int8_t, int64_t>& index, - const std::string& filename); - -/// \ingroup mg_cpp_serialize -/** - * @brief Serializes a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(clique, index, filename); - * @endcode - * - * @param[in] clique - * @param[in] index the pre-built index - * @param[in] filename path to the file to be serialized - * - */ -void serialize(const raft::device_resources_snmg& clique, - const index, uint8_t, int64_t>& index, - const std::string& filename); - -/// \ingroup mg_cpp_serialize -/** - * @brief Serializes a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(clique, index, filename); - * @endcode - * - * @param[in] clique - * @param[in] index the pre-built index - * @param[in] filename path to the file to be serialized - * - */ -void serialize(const raft::device_resources_snmg& clique, - const index, float, int64_t>& index, - const std::string& filename); - -/// \ingroup mg_cpp_serialize -/** - * @brief Serializes a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(clique, index, filename); - * @endcode - * - * @param[in] clique - * @param[in] index the pre-built index - * @param[in] filename path to the file to be serialized - * - */ -void serialize(const raft::device_resources_snmg& clique, - const index, half, int64_t>& index, - const std::string& filename); - -/// \ingroup mg_cpp_serialize -/** - * @brief Serializes a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(clique, index, filename); - * @endcode - * - * @param[in] clique - * @param[in] index the pre-built index - * @param[in] filename path to the file to be serialized - * - */ -void serialize(const raft::device_resources_snmg& clique, - const index, int8_t, int64_t>& index, - const std::string& filename); - -/// \ingroup mg_cpp_serialize -/** - * @brief Serializes a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(clique, index, filename); - * @endcode - * - * @param[in] clique - * @param[in] index the pre-built index - * @param[in] filename path to the file to be serialized - * - */ -void serialize(const raft::device_resources_snmg& clique, - const index, uint8_t, int64_t>& index, - const std::string& filename); - -/// \ingroup mg_cpp_serialize -/** - * @brief Serializes a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(clique, index, filename); - * @endcode - * - * @param[in] clique - * @param[in] index the pre-built index - * @param[in] filename path to the file to be serialized - * - */ -void serialize(const raft::device_resources_snmg& clique, - const index, float, uint32_t>& index, - const std::string& filename); - -/// \ingroup mg_cpp_serialize -/** - * @brief Serializes a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(clique, index, filename); - * @endcode - * - * @param[in] clique - * @param[in] index the pre-built index - * @param[in] filename path to the file to be serialized - * - */ -void serialize(const raft::device_resources_snmg& clique, - const index, half, uint32_t>& index, - const std::string& filename); - -/// \ingroup mg_cpp_serialize -/** - * @brief Serializes a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(clique, index, filename); - * @endcode - * - * @param[in] clique - * @param[in] index the pre-built index - * @param[in] filename path to the file to be serialized - * - */ -void serialize(const raft::device_resources_snmg& clique, - const index, int8_t, uint32_t>& index, - const std::string& filename); - -/// \ingroup mg_cpp_serialize -/** - * @brief Serializes a multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(clique, index, filename); - * @endcode - * - * @param[in] clique - * @param[in] index the pre-built index - * @param[in] filename path to the file to be serialized - * - */ -void serialize(const raft::device_resources_snmg& clique, - const index, uint8_t, uint32_t>& index, - const std::string& filename); - -/// \defgroup mg_cpp_deserialize ANN MG index deserialization - -/// \ingroup mg_cpp_deserialize -/** - * @brief Deserializes an IVF-Flat multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(clique, index, filename); - * auto new_index = cuvs::neighbors::mg::deserialize_flat(clique, filename); - * - * @endcode - * - * @param[in] clique - * @param[in] filename path to the file to be deserialized - * - */ -template -auto deserialize_flat(const raft::device_resources_snmg& clique, const std::string& filename) - -> index, T, IdxT>; - -/// \ingroup mg_cpp_deserialize -/** - * @brief Deserializes an IVF-PQ multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(clique, index, filename); - * auto new_index = cuvs::neighbors::mg::deserialize_pq(clique, filename); - * @endcode - * - * @param[in] clique - * @param[in] filename path to the file to be deserialized - * - */ -template -auto deserialize_pq(const raft::device_resources_snmg& clique, const std::string& filename) - -> index, T, IdxT>; - -/// \ingroup mg_cpp_deserialize -/** - * @brief Deserializes a CAGRA multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); - * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(clique, index, filename); - * auto new_index = cuvs::neighbors::mg::deserialize_cagra(clique, filename); - * - * @endcode - * - * @param[in] clique - * @param[in] filename path to the file to be deserialized - * - */ -template -auto deserialize_cagra(const raft::device_resources_snmg& clique, const std::string& filename) - -> index, T, IdxT>; - -/// \defgroup mg_cpp_distribute ANN MG local index distribution - -/// \ingroup mg_cpp_distribute -/** - * @brief Replicates a locally built and serialized IVF-Flat index to all GPUs to form a distributed - * multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::ivf_flat::index_params index_params; - * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); - * const std::string filename = "local_index.cuvs"; - * cuvs::neighbors::ivf_flat::serialize(clique, filename, index); - * auto new_index = cuvs::neighbors::mg::distribute_flat(clique, filename); - * - * @endcode - * - * @param[in] clique - * @param[in] filename path to the file to be deserialized : a local index - * - */ -template -auto distribute_flat(const raft::device_resources_snmg& clique, const std::string& filename) - -> index, T, IdxT>; - -/// \ingroup mg_cpp_distribute -/** - * @brief Replicates a locally built and serialized IVF-PQ index to all GPUs to form a distributed - * multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::ivf_pq::index_params index_params; - * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); - * const std::string filename = "local_index.cuvs"; - * cuvs::neighbors::ivf_pq::serialize(clique, filename, index); - * auto new_index = cuvs::neighbors::mg::distribute_pq(clique, filename); - * @endcode - * - * @param[in] clique - * @param[in] filename path to the file to be deserialized : a local index - * - */ -template -auto distribute_pq(const raft::device_resources_snmg& clique, const std::string& filename) - -> index, T, IdxT>; - -/// \ingroup mg_cpp_distribute -/** - * @brief Replicates a locally built and serialized CAGRA index to all GPUs to form a distributed - * multi-GPU index - * - * Usage example: - * @code{.cpp} - * raft::device_resources_snmg clique; - * cuvs::neighbors::cagra::index_params index_params; - * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); - * const std::string filename = "local_index.cuvs"; - * cuvs::neighbors::cagra::serialize(clique, filename, index); - * auto new_index = cuvs::neighbors::mg::distribute_cagra(clique, filename); - * - * @endcode - * - * @param[in] clique - * @param[in] filename path to the file to be deserialized : a local index - * - */ -template -auto distribute_cagra(const raft::device_resources_snmg& clique, const std::string& filename) - -> index, T, IdxT>; - } // namespace cuvs::neighbors::mg #else diff --git a/cpp/src/neighbors/mg/generate_mg.py b/cpp/src/neighbors/mg/generate_mg.py index 26e81da16..d7089e56c 100644 --- a/cpp/src/neighbors/mg/generate_mg.py +++ b/cpp/src/neighbors/mg/generate_mg.py @@ -43,181 +43,194 @@ #include "mg.cuh" """ -namespace_macro = """ -namespace cuvs::neighbors::mg { -""" - -footer = """ -} // namespace cuvs::neighbors::mg -""" - flat_macro = """ -#define CUVS_INST_MG_FLAT(T, IdxT) \\ - index, T, IdxT> build(const raft::device_resources_snmg& clique, \\ - const mg::index_params& index_params, \\ - raft::host_matrix_view index_dataset) \\ - { \\ - index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \\ - cuvs::neighbors::mg::detail::build(clique, index, \\ - static_cast(&index_params), \\ - index_dataset); \\ - return index; \\ - } \\ - \\ - void extend(const raft::device_resources_snmg& clique, \\ - index, T, IdxT>& index, \\ - raft::host_matrix_view new_vectors, \\ - std::optional> new_indices) \\ - { \\ - cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \\ - } \\ - \\ - void search(const raft::device_resources_snmg& clique, \\ - const index, T, IdxT>& index, \\ - const mg::search_params& search_params, \\ - raft::host_matrix_view queries, \\ - raft::host_matrix_view neighbors, \\ - raft::host_matrix_view distances, \\ - int64_t n_rows_per_batch) \\ - { \\ - cuvs::neighbors::mg::detail::search(clique, index, \\ - static_cast(&search_params), \\ - queries, neighbors, distances, n_rows_per_batch); \\ - } \\ - \\ - void serialize(const raft::device_resources_snmg& clique, \\ - const index, T, IdxT>& index, \\ - const std::string& filename) \\ - { \\ - cuvs::neighbors::mg::detail::serialize(clique, index, filename); \\ - } \\ - \\ - template<> \\ - index, T, IdxT> deserialize_flat(const raft::device_resources_snmg& clique, \\ - const std::string& filename) \\ - { \\ - auto idx = index, T, IdxT>(clique, filename); \\ - return idx; \\ - } \\ - \\ - template<> \\ - index, T, IdxT> distribute_flat(const raft::device_resources_snmg& clique, \\ - const std::string& filename) \\ - { \\ - auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \\ - cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \\ - return idx; \\ - } +#define CUVS_INST_MG_FLAT(T, IdxT) \\ +namespace cuvs::neighbors::ivf_flat { \\ + using namespace cuvs::neighbors::mg; \\ + \\ + cuvs::neighbors::mg::index, T, IdxT> build( \\ + const raft::device_resources_snmg& clique, \\ + const mg::index_params& index_params, \\ + raft::host_matrix_view index_dataset) \\ + { \\ + cuvs::neighbors::mg::index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \\ + cuvs::neighbors::mg::detail::build(clique, index, \\ + static_cast(&index_params), \\ + index_dataset); \\ + return index; \\ + } \\ + \\ + void extend(const raft::device_resources_snmg& clique, \\ + cuvs::neighbors::mg::index, T, IdxT>& index, \\ + raft::host_matrix_view new_vectors, \\ + std::optional> new_indices) \\ + { \\ + cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \\ + } \\ + \\ + void search(const raft::device_resources_snmg& clique, \\ + const cuvs::neighbors::mg::index, T, IdxT>& index, \\ + const mg::search_params& search_params, \\ + raft::host_matrix_view queries, \\ + raft::host_matrix_view neighbors, \\ + raft::host_matrix_view distances, \\ + int64_t n_rows_per_batch) \\ + { \\ + cuvs::neighbors::mg::detail::search(clique, index, \\ + static_cast(&search_params), \\ + queries, neighbors, distances, n_rows_per_batch); \\ + } \\ + \\ + void serialize(const raft::device_resources_snmg& clique, \\ + const cuvs::neighbors::mg::index, T, IdxT>& index, \\ + const std::string& filename) \\ + { \\ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \\ + } \\ + \\ + template<> \\ + cuvs::neighbors::mg::index, T, IdxT> deserialize( \\ + const raft::device_resources_snmg& clique, \\ + const std::string& filename) \\ + { \\ + auto idx = cuvs::neighbors::mg::index, T, IdxT>(clique, filename); \\ + return idx; \\ + } \\ + \\ + template<> \\ + cuvs::neighbors::mg::index, T, IdxT> distribute( \\ + const raft::device_resources_snmg& clique, \\ + const std::string& filename) \\ + { \\ + auto idx = cuvs::neighbors::mg::index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \\ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \\ + return idx; \\ + } \\ +} // namespace cuvs::neighbors::ivf_flat """ pq_macro = """ -#define CUVS_INST_MG_PQ(T, IdxT) \\ - index, T, IdxT> build(const raft::device_resources_snmg& clique, \\ - const mg::index_params& index_params, \\ - raft::host_matrix_view index_dataset) \\ - { \\ - index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \\ - cuvs::neighbors::mg::detail::build(clique, index, \\ - static_cast(&index_params), \\ - index_dataset); \\ - return index; \\ - } \\ - \\ - void extend(const raft::device_resources_snmg& clique, \\ - index, T, IdxT>& index, \\ - raft::host_matrix_view new_vectors, \\ - std::optional> new_indices) \\ - { \\ - cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \\ - } \\ - \\ - void search(const raft::device_resources_snmg& clique, \\ - const index, T, IdxT>& index, \\ - const mg::search_params& search_params, \\ - raft::host_matrix_view queries, \\ - raft::host_matrix_view neighbors, \\ - raft::host_matrix_view distances, \\ - int64_t n_rows_per_batch) \\ - { \\ - cuvs::neighbors::mg::detail::search(clique, index, \\ - static_cast(&search_params), \\ - queries, neighbors, distances, n_rows_per_batch); \\ - } \\ - \\ - void serialize(const raft::device_resources_snmg& clique, \\ - const index, T, IdxT>& index, \\ - const std::string& filename) \\ - { \\ - cuvs::neighbors::mg::detail::serialize(clique, index, filename); \\ - } \\ - \\ - template<> \\ - index, T, IdxT> deserialize_pq(const raft::device_resources_snmg& clique, \\ - const std::string& filename) \\ - { \\ - auto idx = index, T, IdxT>(clique, filename); \\ - return idx; \\ - } \\ - \\ - template<> \\ - index, T, IdxT> distribute_pq(const raft::device_resources_snmg& clique, \\ - const std::string& filename) \\ - { \\ - auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \\ - cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \\ - return idx; \\ - } +#define CUVS_INST_MG_PQ(T, IdxT) \\ +namespace cuvs::neighbors::ivf_pq { \\ + using namespace cuvs::neighbors::mg; \\ + \\ + cuvs::neighbors::mg::index, T, IdxT> build( \\ + const raft::device_resources_snmg& clique, \\ + const mg::index_params& index_params, \\ + raft::host_matrix_view index_dataset) \\ + { \\ + cuvs::neighbors::mg::index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \\ + cuvs::neighbors::mg::detail::build(clique, index, \\ + static_cast(&index_params), \\ + index_dataset); \\ + return index; \\ + } \\ + \\ + void extend(const raft::device_resources_snmg& clique, \\ + cuvs::neighbors::mg::index, T, IdxT>& index, \\ + raft::host_matrix_view new_vectors, \\ + std::optional> new_indices) \\ + { \\ + cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \\ + } \\ + \\ + void search(const raft::device_resources_snmg& clique, \\ + const cuvs::neighbors::mg::index, T, IdxT>& index, \\ + const mg::search_params& search_params, \\ + raft::host_matrix_view queries, \\ + raft::host_matrix_view neighbors, \\ + raft::host_matrix_view distances, \\ + int64_t n_rows_per_batch) \\ + { \\ + cuvs::neighbors::mg::detail::search(clique, index, \\ + static_cast(&search_params), \\ + queries, neighbors, distances, n_rows_per_batch); \\ + } \\ + \\ + void serialize(const raft::device_resources_snmg& clique, \\ + const cuvs::neighbors::mg::index, T, IdxT>& index, \\ + const std::string& filename) \\ + { \\ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \\ + } \\ + \\ + template<> \\ + cuvs::neighbors::mg::index, T, IdxT> deserialize( \\ + const raft::device_resources_snmg& clique, \\ + const std::string& filename) \\ + { \\ + auto idx = cuvs::neighbors::mg::index, T, IdxT>(clique, filename); \\ + return idx; \\ + } \\ + \\ + template<> \\ + cuvs::neighbors::mg::index, T, IdxT> distribute( \\ + const raft::device_resources_snmg& clique, \\ + const std::string& filename) \\ + { \\ + auto idx = cuvs::neighbors::mg::index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \\ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \\ + return idx; \\ + } \\ +} // namespace cuvs::neighbors::ivf_pq """ cagra_macro = """ -#define CUVS_INST_MG_CAGRA(T, IdxT) \\ - index, T, IdxT> build(const raft::device_resources_snmg& clique, \\ - const mg::index_params& index_params, \\ - raft::host_matrix_view index_dataset) \\ - { \\ - index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \\ - cuvs::neighbors::mg::detail::build(clique, index, \\ - static_cast(&index_params), \\ - index_dataset); \\ - return index; \\ - } \\ - \\ - void search(const raft::device_resources_snmg& clique, \\ - const index, T, IdxT>& index, \\ - const mg::search_params& search_params, \\ - raft::host_matrix_view queries, \\ - raft::host_matrix_view neighbors, \\ - raft::host_matrix_view distances, \\ - int64_t n_rows_per_batch) \\ - { \\ - cuvs::neighbors::mg::detail::search(clique, index, \\ - static_cast(&search_params), \\ - queries, neighbors, distances, n_rows_per_batch); \\ - } \\ - \\ - void serialize(const raft::device_resources_snmg& clique, \\ - const index, T, IdxT>& index, \\ - const std::string& filename) \\ - { \\ - cuvs::neighbors::mg::detail::serialize(clique, index, filename); \\ - } \\ - \\ - template<> \\ - index, T, IdxT> deserialize_cagra(const raft::device_resources_snmg& clique, \\ - const std::string& filename) \\ - { \\ - auto idx = index, T, IdxT>(clique, filename); \\ - return idx; \\ - } \\ - \\ - template<> \\ - index, T, IdxT> distribute_cagra(const raft::device_resources_snmg& clique, \\ - const std::string& filename) \\ - { \\ - auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \\ - cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \\ - return idx; \\ - } +#define CUVS_INST_MG_CAGRA(T, IdxT) \\ +namespace cuvs::neighbors::cagra { \\ + using namespace cuvs::neighbors::mg; \\ + \\ + cuvs::neighbors::mg::index, T, IdxT> build( \\ + const raft::device_resources_snmg& clique, \\ + const mg::index_params& index_params, \\ + raft::host_matrix_view index_dataset) \\ + { \\ + cuvs::neighbors::mg::index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \\ + cuvs::neighbors::mg::detail::build(clique, index, \\ + static_cast(&index_params), \\ + index_dataset); \\ + return index; \\ + } \\ + \\ + void search(const raft::device_resources_snmg& clique, \\ + const cuvs::neighbors::mg::index, T, IdxT>& index, \\ + const mg::search_params& search_params, \\ + raft::host_matrix_view queries, \\ + raft::host_matrix_view neighbors, \\ + raft::host_matrix_view distances, \\ + int64_t n_rows_per_batch) \\ + { \\ + cuvs::neighbors::mg::detail::search(clique, index, \\ + static_cast(&search_params), \\ + queries, neighbors, distances, n_rows_per_batch); \\ + } \\ + \\ + void serialize(const raft::device_resources_snmg& clique, \\ + const cuvs::neighbors::mg::index, T, IdxT>& index, \\ + const std::string& filename) \\ + { \\ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \\ + } \\ + \\ + template<> \\ + cuvs::neighbors::mg::index, T, IdxT> deserialize( \\ + const raft::device_resources_snmg& clique, \\ + const std::string& filename) \\ + { \\ + auto idx = cuvs::neighbors::mg::index, T, IdxT>(clique, filename); \\ + return idx; \\ + } \\ + \\ + template<> \\ + cuvs::neighbors::mg::index, T, IdxT> distribute( \\ + const raft::device_resources_snmg& clique, \\ + const std::string& filename) \\ + { \\ + auto idx = cuvs::neighbors::mg::index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \\ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \\ + return idx; \\ + } \\ +} // namespace cuvs::neighbors::cagra """ flat_macros = dict ( @@ -271,10 +284,8 @@ with open(path, "w") as f: f.write(header) f.write(macro['include']) - f.write(namespace_macro) f.write(macro["definition"]) f.write(f"{macro['name']}({T}, {IdxT});\n\n") f.write(f"#undef {macro['name']}\n") - f.write(footer) print(f"src/neighbors/mg/{path}") diff --git a/cpp/src/neighbors/mg/mg.cuh b/cpp/src/neighbors/mg/mg.cuh index 14ffbce93..f36ef8fa3 100644 --- a/cpp/src/neighbors/mg/mg.cuh +++ b/cpp/src/neighbors/mg/mg.cuh @@ -21,7 +21,10 @@ #include #include +#include #include +#include +#include #include #include diff --git a/cpp/src/neighbors/mg/mg_cagra_float_uint32_t.cu b/cpp/src/neighbors/mg/mg_cagra_float_uint32_t.cu index e179a56e3..fbba29b25 100644 --- a/cpp/src/neighbors/mg/mg_cagra_float_uint32_t.cu +++ b/cpp/src/neighbors/mg/mg_cagra_float_uint32_t.cu @@ -25,66 +25,68 @@ #include "mg.cuh" -namespace cuvs::neighbors::mg { - -#define CUVS_INST_MG_CAGRA(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources_snmg& clique, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \ - cuvs::neighbors::mg::detail::build( \ - clique, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void search(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - clique, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_cagra( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(clique, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_cagra( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ - return idx; \ - } +#define CUVS_INST_MG_CAGRA(T, IdxT) \ + namespace cuvs::neighbors::cagra { \ + using namespace cuvs::neighbors::mg; \ + \ + cuvs::neighbors::mg::index, T, IdxT> build( \ + const raft::device_resources_snmg& clique, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + cuvs::neighbors::mg::index, T, IdxT> index(index_params.mode, \ + clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::build( \ + clique, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void search(const raft::device_resources_snmg& clique, \ + const cuvs::neighbors::mg::index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances, \ + int64_t n_rows_per_batch) \ + { \ + cuvs::neighbors::mg::detail::search( \ + clique, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances, \ + n_rows_per_batch); \ + } \ + \ + void serialize(const raft::device_resources_snmg& clique, \ + const cuvs::neighbors::mg::index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ + } \ + \ + template <> \ + cuvs::neighbors::mg::index, T, IdxT> deserialize( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg::index, T, IdxT>(clique, filename); \ + return idx; \ + } \ + \ + template <> \ + cuvs::neighbors::mg::index, T, IdxT> distribute( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg::index, T, IdxT>(REPLICATED, \ + clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ + return idx; \ + } \ + } // namespace cuvs::neighbors::cagra CUVS_INST_MG_CAGRA(float, uint32_t); #undef CUVS_INST_MG_CAGRA - -} // namespace cuvs::neighbors::mg diff --git a/cpp/src/neighbors/mg/mg_cagra_half_uint32_t.cu b/cpp/src/neighbors/mg/mg_cagra_half_uint32_t.cu index 3e369d9ac..4633cc77c 100644 --- a/cpp/src/neighbors/mg/mg_cagra_half_uint32_t.cu +++ b/cpp/src/neighbors/mg/mg_cagra_half_uint32_t.cu @@ -25,66 +25,68 @@ #include "mg.cuh" -namespace cuvs::neighbors::mg { - -#define CUVS_INST_MG_CAGRA(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources_snmg& clique, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \ - cuvs::neighbors::mg::detail::build( \ - clique, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void search(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - clique, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_cagra( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(clique, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_cagra( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ - return idx; \ - } +#define CUVS_INST_MG_CAGRA(T, IdxT) \ + namespace cuvs::neighbors::cagra { \ + using namespace cuvs::neighbors::mg; \ + \ + cuvs::neighbors::mg::index, T, IdxT> build( \ + const raft::device_resources_snmg& clique, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + cuvs::neighbors::mg::index, T, IdxT> index(index_params.mode, \ + clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::build( \ + clique, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void search(const raft::device_resources_snmg& clique, \ + const cuvs::neighbors::mg::index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances, \ + int64_t n_rows_per_batch) \ + { \ + cuvs::neighbors::mg::detail::search( \ + clique, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances, \ + n_rows_per_batch); \ + } \ + \ + void serialize(const raft::device_resources_snmg& clique, \ + const cuvs::neighbors::mg::index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ + } \ + \ + template <> \ + cuvs::neighbors::mg::index, T, IdxT> deserialize( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg::index, T, IdxT>(clique, filename); \ + return idx; \ + } \ + \ + template <> \ + cuvs::neighbors::mg::index, T, IdxT> distribute( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg::index, T, IdxT>(REPLICATED, \ + clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ + return idx; \ + } \ + } // namespace cuvs::neighbors::cagra CUVS_INST_MG_CAGRA(half, uint32_t); #undef CUVS_INST_MG_CAGRA - -} // namespace cuvs::neighbors::mg diff --git a/cpp/src/neighbors/mg/mg_cagra_int8_t_uint32_t.cu b/cpp/src/neighbors/mg/mg_cagra_int8_t_uint32_t.cu index 5ebf223d1..4c15e09f5 100644 --- a/cpp/src/neighbors/mg/mg_cagra_int8_t_uint32_t.cu +++ b/cpp/src/neighbors/mg/mg_cagra_int8_t_uint32_t.cu @@ -25,66 +25,68 @@ #include "mg.cuh" -namespace cuvs::neighbors::mg { - -#define CUVS_INST_MG_CAGRA(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources_snmg& clique, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \ - cuvs::neighbors::mg::detail::build( \ - clique, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void search(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - clique, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_cagra( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(clique, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_cagra( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ - return idx; \ - } +#define CUVS_INST_MG_CAGRA(T, IdxT) \ + namespace cuvs::neighbors::cagra { \ + using namespace cuvs::neighbors::mg; \ + \ + cuvs::neighbors::mg::index, T, IdxT> build( \ + const raft::device_resources_snmg& clique, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + cuvs::neighbors::mg::index, T, IdxT> index(index_params.mode, \ + clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::build( \ + clique, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void search(const raft::device_resources_snmg& clique, \ + const cuvs::neighbors::mg::index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances, \ + int64_t n_rows_per_batch) \ + { \ + cuvs::neighbors::mg::detail::search( \ + clique, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances, \ + n_rows_per_batch); \ + } \ + \ + void serialize(const raft::device_resources_snmg& clique, \ + const cuvs::neighbors::mg::index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ + } \ + \ + template <> \ + cuvs::neighbors::mg::index, T, IdxT> deserialize( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg::index, T, IdxT>(clique, filename); \ + return idx; \ + } \ + \ + template <> \ + cuvs::neighbors::mg::index, T, IdxT> distribute( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg::index, T, IdxT>(REPLICATED, \ + clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ + return idx; \ + } \ + } // namespace cuvs::neighbors::cagra CUVS_INST_MG_CAGRA(int8_t, uint32_t); #undef CUVS_INST_MG_CAGRA - -} // namespace cuvs::neighbors::mg diff --git a/cpp/src/neighbors/mg/mg_cagra_uint8_t_uint32_t.cu b/cpp/src/neighbors/mg/mg_cagra_uint8_t_uint32_t.cu index 923031b1c..8c585c299 100644 --- a/cpp/src/neighbors/mg/mg_cagra_uint8_t_uint32_t.cu +++ b/cpp/src/neighbors/mg/mg_cagra_uint8_t_uint32_t.cu @@ -25,66 +25,68 @@ #include "mg.cuh" -namespace cuvs::neighbors::mg { - -#define CUVS_INST_MG_CAGRA(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources_snmg& clique, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \ - cuvs::neighbors::mg::detail::build( \ - clique, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void search(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - clique, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_cagra( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(clique, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_cagra( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ - return idx; \ - } +#define CUVS_INST_MG_CAGRA(T, IdxT) \ + namespace cuvs::neighbors::cagra { \ + using namespace cuvs::neighbors::mg; \ + \ + cuvs::neighbors::mg::index, T, IdxT> build( \ + const raft::device_resources_snmg& clique, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + cuvs::neighbors::mg::index, T, IdxT> index(index_params.mode, \ + clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::build( \ + clique, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void search(const raft::device_resources_snmg& clique, \ + const cuvs::neighbors::mg::index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances, \ + int64_t n_rows_per_batch) \ + { \ + cuvs::neighbors::mg::detail::search( \ + clique, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances, \ + n_rows_per_batch); \ + } \ + \ + void serialize(const raft::device_resources_snmg& clique, \ + const cuvs::neighbors::mg::index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ + } \ + \ + template <> \ + cuvs::neighbors::mg::index, T, IdxT> deserialize( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg::index, T, IdxT>(clique, filename); \ + return idx; \ + } \ + \ + template <> \ + cuvs::neighbors::mg::index, T, IdxT> distribute( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg::index, T, IdxT>(REPLICATED, \ + clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ + return idx; \ + } \ + } // namespace cuvs::neighbors::cagra CUVS_INST_MG_CAGRA(uint8_t, uint32_t); #undef CUVS_INST_MG_CAGRA - -} // namespace cuvs::neighbors::mg diff --git a/cpp/src/neighbors/mg/mg_flat_float_int64_t.cu b/cpp/src/neighbors/mg/mg_flat_float_int64_t.cu index f90f6fcfb..ef84cff6c 100644 --- a/cpp/src/neighbors/mg/mg_flat_float_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_flat_float_int64_t.cu @@ -25,74 +25,76 @@ #include "mg.cuh" -namespace cuvs::neighbors::mg { - -#define CUVS_INST_MG_FLAT(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources_snmg& clique, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \ - cuvs::neighbors::mg::detail::build( \ - clique, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void extend(const raft::device_resources_snmg& clique, \ - index, T, IdxT>& index, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices) \ - { \ - cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ - } \ - \ - void search(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - clique, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_flat( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(clique, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_flat( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ - return idx; \ - } +#define CUVS_INST_MG_FLAT(T, IdxT) \ + namespace cuvs::neighbors::ivf_flat { \ + using namespace cuvs::neighbors::mg; \ + \ + cuvs::neighbors::mg::index, T, IdxT> build( \ + const raft::device_resources_snmg& clique, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + cuvs::neighbors::mg::index, T, IdxT> index(index_params.mode, \ + clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::build( \ + clique, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void extend(const raft::device_resources_snmg& clique, \ + cuvs::neighbors::mg::index, T, IdxT>& index, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices) \ + { \ + cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ + } \ + \ + void search(const raft::device_resources_snmg& clique, \ + const cuvs::neighbors::mg::index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances, \ + int64_t n_rows_per_batch) \ + { \ + cuvs::neighbors::mg::detail::search( \ + clique, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances, \ + n_rows_per_batch); \ + } \ + \ + void serialize(const raft::device_resources_snmg& clique, \ + const cuvs::neighbors::mg::index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ + } \ + \ + template <> \ + cuvs::neighbors::mg::index, T, IdxT> deserialize( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg::index, T, IdxT>(clique, filename); \ + return idx; \ + } \ + \ + template <> \ + cuvs::neighbors::mg::index, T, IdxT> distribute( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg::index, T, IdxT>( \ + REPLICATED, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ + return idx; \ + } \ + } // namespace cuvs::neighbors::ivf_flat CUVS_INST_MG_FLAT(float, int64_t); #undef CUVS_INST_MG_FLAT - -} // namespace cuvs::neighbors::mg diff --git a/cpp/src/neighbors/mg/mg_flat_int8_t_int64_t.cu b/cpp/src/neighbors/mg/mg_flat_int8_t_int64_t.cu index 2eefad5d5..6e6daace7 100644 --- a/cpp/src/neighbors/mg/mg_flat_int8_t_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_flat_int8_t_int64_t.cu @@ -25,74 +25,76 @@ #include "mg.cuh" -namespace cuvs::neighbors::mg { - -#define CUVS_INST_MG_FLAT(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources_snmg& clique, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \ - cuvs::neighbors::mg::detail::build( \ - clique, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void extend(const raft::device_resources_snmg& clique, \ - index, T, IdxT>& index, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices) \ - { \ - cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ - } \ - \ - void search(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - clique, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_flat( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(clique, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_flat( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ - return idx; \ - } +#define CUVS_INST_MG_FLAT(T, IdxT) \ + namespace cuvs::neighbors::ivf_flat { \ + using namespace cuvs::neighbors::mg; \ + \ + cuvs::neighbors::mg::index, T, IdxT> build( \ + const raft::device_resources_snmg& clique, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + cuvs::neighbors::mg::index, T, IdxT> index(index_params.mode, \ + clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::build( \ + clique, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void extend(const raft::device_resources_snmg& clique, \ + cuvs::neighbors::mg::index, T, IdxT>& index, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices) \ + { \ + cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ + } \ + \ + void search(const raft::device_resources_snmg& clique, \ + const cuvs::neighbors::mg::index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances, \ + int64_t n_rows_per_batch) \ + { \ + cuvs::neighbors::mg::detail::search( \ + clique, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances, \ + n_rows_per_batch); \ + } \ + \ + void serialize(const raft::device_resources_snmg& clique, \ + const cuvs::neighbors::mg::index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ + } \ + \ + template <> \ + cuvs::neighbors::mg::index, T, IdxT> deserialize( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg::index, T, IdxT>(clique, filename); \ + return idx; \ + } \ + \ + template <> \ + cuvs::neighbors::mg::index, T, IdxT> distribute( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg::index, T, IdxT>( \ + REPLICATED, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ + return idx; \ + } \ + } // namespace cuvs::neighbors::ivf_flat CUVS_INST_MG_FLAT(int8_t, int64_t); #undef CUVS_INST_MG_FLAT - -} // namespace cuvs::neighbors::mg diff --git a/cpp/src/neighbors/mg/mg_flat_uint8_t_int64_t.cu b/cpp/src/neighbors/mg/mg_flat_uint8_t_int64_t.cu index 9684f19d8..ab0f4fb10 100644 --- a/cpp/src/neighbors/mg/mg_flat_uint8_t_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_flat_uint8_t_int64_t.cu @@ -25,74 +25,76 @@ #include "mg.cuh" -namespace cuvs::neighbors::mg { - -#define CUVS_INST_MG_FLAT(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources_snmg& clique, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \ - cuvs::neighbors::mg::detail::build( \ - clique, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void extend(const raft::device_resources_snmg& clique, \ - index, T, IdxT>& index, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices) \ - { \ - cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ - } \ - \ - void search(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - clique, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_flat( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(clique, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_flat( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ - return idx; \ - } +#define CUVS_INST_MG_FLAT(T, IdxT) \ + namespace cuvs::neighbors::ivf_flat { \ + using namespace cuvs::neighbors::mg; \ + \ + cuvs::neighbors::mg::index, T, IdxT> build( \ + const raft::device_resources_snmg& clique, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + cuvs::neighbors::mg::index, T, IdxT> index(index_params.mode, \ + clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::build( \ + clique, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void extend(const raft::device_resources_snmg& clique, \ + cuvs::neighbors::mg::index, T, IdxT>& index, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices) \ + { \ + cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ + } \ + \ + void search(const raft::device_resources_snmg& clique, \ + const cuvs::neighbors::mg::index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances, \ + int64_t n_rows_per_batch) \ + { \ + cuvs::neighbors::mg::detail::search( \ + clique, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances, \ + n_rows_per_batch); \ + } \ + \ + void serialize(const raft::device_resources_snmg& clique, \ + const cuvs::neighbors::mg::index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ + } \ + \ + template <> \ + cuvs::neighbors::mg::index, T, IdxT> deserialize( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg::index, T, IdxT>(clique, filename); \ + return idx; \ + } \ + \ + template <> \ + cuvs::neighbors::mg::index, T, IdxT> distribute( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg::index, T, IdxT>( \ + REPLICATED, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ + return idx; \ + } \ + } // namespace cuvs::neighbors::ivf_flat CUVS_INST_MG_FLAT(uint8_t, int64_t); #undef CUVS_INST_MG_FLAT - -} // namespace cuvs::neighbors::mg diff --git a/cpp/src/neighbors/mg/mg_pq_float_int64_t.cu b/cpp/src/neighbors/mg/mg_pq_float_int64_t.cu index c71133ac4..3b5268c69 100644 --- a/cpp/src/neighbors/mg/mg_pq_float_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_pq_float_int64_t.cu @@ -25,74 +25,76 @@ #include "mg.cuh" -namespace cuvs::neighbors::mg { - -#define CUVS_INST_MG_PQ(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources_snmg& clique, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \ - cuvs::neighbors::mg::detail::build( \ - clique, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void extend(const raft::device_resources_snmg& clique, \ - index, T, IdxT>& index, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices) \ - { \ - cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ - } \ - \ - void search(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - clique, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_pq( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(clique, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_pq( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ - return idx; \ - } +#define CUVS_INST_MG_PQ(T, IdxT) \ + namespace cuvs::neighbors::ivf_pq { \ + using namespace cuvs::neighbors::mg; \ + \ + cuvs::neighbors::mg::index, T, IdxT> build( \ + const raft::device_resources_snmg& clique, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + cuvs::neighbors::mg::index, T, IdxT> index(index_params.mode, \ + clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::build( \ + clique, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void extend(const raft::device_resources_snmg& clique, \ + cuvs::neighbors::mg::index, T, IdxT>& index, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices) \ + { \ + cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ + } \ + \ + void search(const raft::device_resources_snmg& clique, \ + const cuvs::neighbors::mg::index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances, \ + int64_t n_rows_per_batch) \ + { \ + cuvs::neighbors::mg::detail::search( \ + clique, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances, \ + n_rows_per_batch); \ + } \ + \ + void serialize(const raft::device_resources_snmg& clique, \ + const cuvs::neighbors::mg::index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ + } \ + \ + template <> \ + cuvs::neighbors::mg::index, T, IdxT> deserialize( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg::index, T, IdxT>(clique, filename); \ + return idx; \ + } \ + \ + template <> \ + cuvs::neighbors::mg::index, T, IdxT> distribute( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg::index, T, IdxT>(REPLICATED, \ + clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ + return idx; \ + } \ + } // namespace cuvs::neighbors::ivf_pq CUVS_INST_MG_PQ(float, int64_t); #undef CUVS_INST_MG_PQ - -} // namespace cuvs::neighbors::mg diff --git a/cpp/src/neighbors/mg/mg_pq_half_int64_t.cu b/cpp/src/neighbors/mg/mg_pq_half_int64_t.cu index df148620f..e6d18bd30 100644 --- a/cpp/src/neighbors/mg/mg_pq_half_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_pq_half_int64_t.cu @@ -25,74 +25,76 @@ #include "mg.cuh" -namespace cuvs::neighbors::mg { - -#define CUVS_INST_MG_PQ(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources_snmg& clique, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \ - cuvs::neighbors::mg::detail::build( \ - clique, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void extend(const raft::device_resources_snmg& clique, \ - index, T, IdxT>& index, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices) \ - { \ - cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ - } \ - \ - void search(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - clique, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_pq( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(clique, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_pq( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ - return idx; \ - } +#define CUVS_INST_MG_PQ(T, IdxT) \ + namespace cuvs::neighbors::ivf_pq { \ + using namespace cuvs::neighbors::mg; \ + \ + cuvs::neighbors::mg::index, T, IdxT> build( \ + const raft::device_resources_snmg& clique, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + cuvs::neighbors::mg::index, T, IdxT> index(index_params.mode, \ + clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::build( \ + clique, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void extend(const raft::device_resources_snmg& clique, \ + cuvs::neighbors::mg::index, T, IdxT>& index, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices) \ + { \ + cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ + } \ + \ + void search(const raft::device_resources_snmg& clique, \ + const cuvs::neighbors::mg::index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances, \ + int64_t n_rows_per_batch) \ + { \ + cuvs::neighbors::mg::detail::search( \ + clique, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances, \ + n_rows_per_batch); \ + } \ + \ + void serialize(const raft::device_resources_snmg& clique, \ + const cuvs::neighbors::mg::index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ + } \ + \ + template <> \ + cuvs::neighbors::mg::index, T, IdxT> deserialize( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg::index, T, IdxT>(clique, filename); \ + return idx; \ + } \ + \ + template <> \ + cuvs::neighbors::mg::index, T, IdxT> distribute( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg::index, T, IdxT>(REPLICATED, \ + clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ + return idx; \ + } \ + } // namespace cuvs::neighbors::ivf_pq CUVS_INST_MG_PQ(half, int64_t); #undef CUVS_INST_MG_PQ - -} // namespace cuvs::neighbors::mg diff --git a/cpp/src/neighbors/mg/mg_pq_int8_t_int64_t.cu b/cpp/src/neighbors/mg/mg_pq_int8_t_int64_t.cu index afe5faa41..ead218677 100644 --- a/cpp/src/neighbors/mg/mg_pq_int8_t_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_pq_int8_t_int64_t.cu @@ -25,74 +25,76 @@ #include "mg.cuh" -namespace cuvs::neighbors::mg { - -#define CUVS_INST_MG_PQ(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources_snmg& clique, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \ - cuvs::neighbors::mg::detail::build( \ - clique, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void extend(const raft::device_resources_snmg& clique, \ - index, T, IdxT>& index, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices) \ - { \ - cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ - } \ - \ - void search(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - clique, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_pq( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(clique, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_pq( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ - return idx; \ - } +#define CUVS_INST_MG_PQ(T, IdxT) \ + namespace cuvs::neighbors::ivf_pq { \ + using namespace cuvs::neighbors::mg; \ + \ + cuvs::neighbors::mg::index, T, IdxT> build( \ + const raft::device_resources_snmg& clique, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + cuvs::neighbors::mg::index, T, IdxT> index(index_params.mode, \ + clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::build( \ + clique, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void extend(const raft::device_resources_snmg& clique, \ + cuvs::neighbors::mg::index, T, IdxT>& index, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices) \ + { \ + cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ + } \ + \ + void search(const raft::device_resources_snmg& clique, \ + const cuvs::neighbors::mg::index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances, \ + int64_t n_rows_per_batch) \ + { \ + cuvs::neighbors::mg::detail::search( \ + clique, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances, \ + n_rows_per_batch); \ + } \ + \ + void serialize(const raft::device_resources_snmg& clique, \ + const cuvs::neighbors::mg::index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ + } \ + \ + template <> \ + cuvs::neighbors::mg::index, T, IdxT> deserialize( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg::index, T, IdxT>(clique, filename); \ + return idx; \ + } \ + \ + template <> \ + cuvs::neighbors::mg::index, T, IdxT> distribute( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg::index, T, IdxT>(REPLICATED, \ + clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ + return idx; \ + } \ + } // namespace cuvs::neighbors::ivf_pq CUVS_INST_MG_PQ(int8_t, int64_t); #undef CUVS_INST_MG_PQ - -} // namespace cuvs::neighbors::mg diff --git a/cpp/src/neighbors/mg/mg_pq_uint8_t_int64_t.cu b/cpp/src/neighbors/mg/mg_pq_uint8_t_int64_t.cu index c725d2139..27d36c34b 100644 --- a/cpp/src/neighbors/mg/mg_pq_uint8_t_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_pq_uint8_t_int64_t.cu @@ -25,74 +25,76 @@ #include "mg.cuh" -namespace cuvs::neighbors::mg { - -#define CUVS_INST_MG_PQ(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources_snmg& clique, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \ - cuvs::neighbors::mg::detail::build( \ - clique, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void extend(const raft::device_resources_snmg& clique, \ - index, T, IdxT>& index, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices) \ - { \ - cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ - } \ - \ - void search(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - clique, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_pq( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(clique, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_pq( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ - return idx; \ - } +#define CUVS_INST_MG_PQ(T, IdxT) \ + namespace cuvs::neighbors::ivf_pq { \ + using namespace cuvs::neighbors::mg; \ + \ + cuvs::neighbors::mg::index, T, IdxT> build( \ + const raft::device_resources_snmg& clique, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + cuvs::neighbors::mg::index, T, IdxT> index(index_params.mode, \ + clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::build( \ + clique, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void extend(const raft::device_resources_snmg& clique, \ + cuvs::neighbors::mg::index, T, IdxT>& index, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices) \ + { \ + cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ + } \ + \ + void search(const raft::device_resources_snmg& clique, \ + const cuvs::neighbors::mg::index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances, \ + int64_t n_rows_per_batch) \ + { \ + cuvs::neighbors::mg::detail::search( \ + clique, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances, \ + n_rows_per_batch); \ + } \ + \ + void serialize(const raft::device_resources_snmg& clique, \ + const cuvs::neighbors::mg::index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ + } \ + \ + template <> \ + cuvs::neighbors::mg::index, T, IdxT> deserialize( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg::index, T, IdxT>(clique, filename); \ + return idx; \ + } \ + \ + template <> \ + cuvs::neighbors::mg::index, T, IdxT> distribute( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = cuvs::neighbors::mg::index, T, IdxT>(REPLICATED, \ + clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ + return idx; \ + } \ + } // namespace cuvs::neighbors::ivf_pq CUVS_INST_MG_PQ(uint8_t, int64_t); #undef CUVS_INST_MG_PQ - -} // namespace cuvs::neighbors::mg diff --git a/cpp/test/neighbors/mg.cuh b/cpp/test/neighbors/mg.cuh index b4131acdb..32532df3f 100644 --- a/cpp/test/neighbors/mg.cuh +++ b/cpp/test/neighbors/mg.cuh @@ -19,7 +19,9 @@ #include "ann_utils.cuh" #include "naive_knn.cuh" -#include +#include +#include +#include #include namespace cuvs::neighbors::mg { @@ -125,18 +127,18 @@ class AnnMGTest : public ::testing::TestWithParam { distances_snmg_ann.data(), ps.num_queries, ps.k); { - auto index = cuvs::neighbors::mg::build(clique_, index_params, index_dataset); - cuvs::neighbors::mg::extend(clique_, index, index_dataset, std::nullopt); - cuvs::neighbors::mg::serialize(clique_, index, "mg_ivf_flat_index"); + auto index = cuvs::neighbors::ivf_flat::build(clique_, index_params, index_dataset); + cuvs::neighbors::ivf_flat::extend(clique_, index, index_dataset, std::nullopt); + cuvs::neighbors::ivf_flat::serialize(clique_, index, "mg_ivf_flat_index"); } auto new_index = - cuvs::neighbors::mg::deserialize_flat(clique_, "mg_ivf_flat_index"); + cuvs::neighbors::ivf_flat::deserialize(clique_, "mg_ivf_flat_index"); if (ps.m_mode == m_mode_t::MERGE_ON_ROOT_RANK) search_params.merge_mode = MERGE_ON_ROOT_RANK; else search_params.merge_mode = TREE_MERGE; - cuvs::neighbors::mg::search( + cuvs::neighbors::ivf_flat::search( clique_, new_index, search_params, queries, neighbors, distances, n_rows_per_search_batch); resource::sync_stream(clique_); @@ -184,18 +186,18 @@ class AnnMGTest : public ::testing::TestWithParam { distances_snmg_ann.data(), ps.num_queries, ps.k); { - auto index = cuvs::neighbors::mg::build(clique_, index_params, index_dataset); - cuvs::neighbors::mg::extend(clique_, index, index_dataset, std::nullopt); - cuvs::neighbors::mg::serialize(clique_, index, "mg_ivf_pq_index"); + auto index = cuvs::neighbors::ivf_pq::build(clique_, index_params, index_dataset); + cuvs::neighbors::ivf_pq::extend(clique_, index, index_dataset, std::nullopt); + cuvs::neighbors::ivf_pq::serialize(clique_, index, "mg_ivf_pq_index"); } auto new_index = - cuvs::neighbors::mg::deserialize_pq(clique_, "mg_ivf_pq_index"); + cuvs::neighbors::ivf_pq::deserialize(clique_, "mg_ivf_pq_index"); if (ps.m_mode == m_mode_t::MERGE_ON_ROOT_RANK) search_params.merge_mode = MERGE_ON_ROOT_RANK; else search_params.merge_mode = TREE_MERGE; - cuvs::neighbors::mg::search( + cuvs::neighbors::ivf_pq::search( clique_, new_index, search_params, queries, neighbors, distances, n_rows_per_search_batch); resource::sync_stream(clique_); @@ -238,17 +240,17 @@ class AnnMGTest : public ::testing::TestWithParam { distances_snmg_ann.data(), ps.num_queries, ps.k); { - auto index = cuvs::neighbors::mg::build(clique_, index_params, index_dataset); - cuvs::neighbors::mg::serialize(clique_, index, "mg_cagra_index"); + auto index = cuvs::neighbors::cagra::build(clique_, index_params, index_dataset); + cuvs::neighbors::cagra::serialize(clique_, index, "mg_cagra_index"); } auto new_index = - cuvs::neighbors::mg::deserialize_cagra(clique_, "mg_cagra_index"); + cuvs::neighbors::cagra::deserialize(clique_, "mg_cagra_index"); if (ps.m_mode == m_mode_t::MERGE_ON_ROOT_RANK) search_params.merge_mode = MERGE_ON_ROOT_RANK; else search_params.merge_mode = TREE_MERGE; - cuvs::neighbors::mg::search( + cuvs::neighbors::cagra::search( clique_, new_index, search_params, queries, neighbors, distances, n_rows_per_search_batch); resource::sync_stream(clique_); @@ -293,15 +295,15 @@ class AnnMGTest : public ::testing::TestWithParam { distances_snmg_ann.data(), ps.num_queries, ps.k); auto distributed_index = - cuvs::neighbors::mg::distribute_flat(clique_, "local_ivf_flat_index"); + cuvs::neighbors::ivf_flat::distribute(clique_, "local_ivf_flat_index"); search_params.merge_mode = TREE_MERGE; - cuvs::neighbors::mg::search(clique_, - distributed_index, - search_params, - queries, - neighbors, - distances, - n_rows_per_search_batch); + cuvs::neighbors::ivf_flat::search(clique_, + distributed_index, + search_params, + queries, + neighbors, + distances, + n_rows_per_search_batch); resource::sync_stream(clique_); @@ -345,15 +347,15 @@ class AnnMGTest : public ::testing::TestWithParam { distances_snmg_ann.data(), ps.num_queries, ps.k); auto distributed_index = - cuvs::neighbors::mg::distribute_pq(clique_, "local_ivf_pq_index"); + cuvs::neighbors::ivf_pq::distribute(clique_, "local_ivf_pq_index"); search_params.merge_mode = TREE_MERGE; - cuvs::neighbors::mg::search(clique_, - distributed_index, - search_params, - queries, - neighbors, - distances, - n_rows_per_search_batch); + cuvs::neighbors::ivf_pq::search(clique_, + distributed_index, + search_params, + queries, + neighbors, + distances, + n_rows_per_search_batch); resource::sync_stream(clique_); @@ -392,16 +394,16 @@ class AnnMGTest : public ::testing::TestWithParam { distances_snmg_ann.data(), ps.num_queries, ps.k); auto distributed_index = - cuvs::neighbors::mg::distribute_cagra(clique_, "local_cagra_index"); + cuvs::neighbors::cagra::distribute(clique_, "local_cagra_index"); search_params.merge_mode = TREE_MERGE; - cuvs::neighbors::mg::search(clique_, - distributed_index, - search_params, - queries, - neighbors, - distances, - n_rows_per_search_batch); + cuvs::neighbors::cagra::search(clique_, + distributed_index, + search_params, + queries, + neighbors, + distances, + n_rows_per_search_batch); resource::sync_stream(clique_); @@ -439,8 +441,8 @@ class AnnMGTest : public ::testing::TestWithParam { auto small_batch_query = raft::make_host_matrix_view( h_queries.data(), ps.num_queries, ps.dim); - auto index = cuvs::neighbors::mg::build(clique_, index_params, index_dataset); - cuvs::neighbors::mg::extend(clique_, index, index_dataset, std::nullopt); + auto index = cuvs::neighbors::ivf_flat::build(clique_, index_params, index_dataset); + cuvs::neighbors::ivf_flat::extend(clique_, index, index_dataset, std::nullopt); int n_parallel_searches = 16; std::vector searches_correctness(n_parallel_searches); @@ -455,13 +457,13 @@ class AnnMGTest : public ::testing::TestWithParam { load_balancer_neighbors_snmg_ann.data() + offset, ps.num_queries, ps.k); auto small_batch_distances = raft::make_host_matrix_view( load_balancer_distances_snmg_ann.data() + offset, ps.num_queries, ps.k); - cuvs::neighbors::mg::search(clique_, - index, - search_params, - small_batch_query, - small_batch_neighbors, - small_batch_distances, - n_rows_per_search_batch); + cuvs::neighbors::ivf_flat::search(clique_, + index, + search_params, + small_batch_query, + small_batch_neighbors, + small_batch_distances, + n_rows_per_search_batch); std::vector small_batch_neighbors_vec( small_batch_neighbors.data_handle(), @@ -503,8 +505,8 @@ class AnnMGTest : public ::testing::TestWithParam { auto small_batch_query = raft::make_host_matrix_view( h_queries.data(), ps.num_queries, ps.dim); - auto index = cuvs::neighbors::mg::build(clique_, index_params, index_dataset); - cuvs::neighbors::mg::extend(clique_, index, index_dataset, std::nullopt); + auto index = cuvs::neighbors::ivf_pq::build(clique_, index_params, index_dataset); + cuvs::neighbors::ivf_pq::extend(clique_, index, index_dataset, std::nullopt); int n_parallel_searches = 16; std::vector searches_correctness(n_parallel_searches); @@ -519,13 +521,13 @@ class AnnMGTest : public ::testing::TestWithParam { load_balancer_neighbors_snmg_ann.data() + offset, ps.num_queries, ps.k); auto small_batch_distances = raft::make_host_matrix_view( load_balancer_distances_snmg_ann.data() + offset, ps.num_queries, ps.k); - cuvs::neighbors::mg::search(clique_, - index, - search_params, - small_batch_query, - small_batch_neighbors, - small_batch_distances, - n_rows_per_search_batch); + cuvs::neighbors::ivf_pq::search(clique_, + index, + search_params, + small_batch_query, + small_batch_neighbors, + small_batch_distances, + n_rows_per_search_batch); std::vector small_batch_neighbors_vec( small_batch_neighbors.data_handle(), @@ -563,7 +565,7 @@ class AnnMGTest : public ::testing::TestWithParam { auto small_batch_query = raft::make_host_matrix_view( h_queries.data(), ps.num_queries, ps.dim); - auto index = cuvs::neighbors::mg::build(clique_, index_params, index_dataset); + auto index = cuvs::neighbors::cagra::build(clique_, index_params, index_dataset); int n_parallel_searches = 16; std::vector searches_correctness(n_parallel_searches); @@ -578,13 +580,13 @@ class AnnMGTest : public ::testing::TestWithParam { load_balancer_neighbors_snmg_ann.data() + offset, ps.num_queries, ps.k); auto small_batch_distances = raft::make_host_matrix_view( load_balancer_distances_snmg_ann.data() + offset, ps.num_queries, ps.k); - cuvs::neighbors::mg::search(clique_, - index, - search_params, - small_batch_query, - small_batch_neighbors, - small_batch_distances, - n_rows_per_search_batch); + cuvs::neighbors::cagra::search(clique_, + index, + search_params, + small_batch_query, + small_batch_neighbors, + small_batch_distances, + n_rows_per_search_batch); std::vector small_batch_neighbors_vec( small_batch_neighbors.data_handle(), From db12ab92d7d2c538cee8b15a8e142f0d443fdff8 Mon Sep 17 00:00:00 2001 From: viclafargue Date: Mon, 20 Jan 2025 18:43:05 +0000 Subject: [PATCH 8/8] Adding CUVS_BUILD_MG_ALGOS macro back --- cpp/include/cuvs/neighbors/cagra.hpp | 9 ++++++++- cpp/include/cuvs/neighbors/ivf_flat.hpp | 9 ++++++++- cpp/include/cuvs/neighbors/ivf_pq.hpp | 10 +++++++++- 3 files changed, 25 insertions(+), 3 deletions(-) diff --git a/cpp/include/cuvs/neighbors/cagra.hpp b/cpp/include/cuvs/neighbors/cagra.hpp index f62fe7aea..7b311a1b5 100644 --- a/cpp/include/cuvs/neighbors/cagra.hpp +++ b/cpp/include/cuvs/neighbors/cagra.hpp @@ -20,7 +20,6 @@ #include #include #include -#include #include #include #include @@ -32,6 +31,10 @@ #include #include +#ifdef CUVS_BUILD_MG_ALGOS +#include +#endif + #include #include @@ -1753,6 +1756,8 @@ void serialize_to_hnswlib(raft::resources const& handle, * @} */ +#ifdef CUVS_BUILD_MG_ALGOS + /// \defgroup mg_cpp_index_build ANN MG index build /// \ingroup mg_cpp_index_build @@ -2214,4 +2219,6 @@ template auto distribute(const raft::device_resources_snmg& clique, const std::string& filename) -> cuvs::neighbors::mg::index, T, IdxT>; +#endif + } // namespace cuvs::neighbors::cagra diff --git a/cpp/include/cuvs/neighbors/ivf_flat.hpp b/cpp/include/cuvs/neighbors/ivf_flat.hpp index a9cb02de6..df77542c6 100644 --- a/cpp/include/cuvs/neighbors/ivf_flat.hpp +++ b/cpp/include/cuvs/neighbors/ivf_flat.hpp @@ -19,10 +19,13 @@ #include "common.hpp" #include #include -#include #include #include +#ifdef CUVS_BUILD_MG_ALGOS +#include +#endif + namespace cuvs::neighbors::ivf_flat { /** * @defgroup ivf_flat_cpp_index_params IVF-Flat index build parameters @@ -1599,6 +1602,8 @@ void deserialize(raft::resources const& handle, * @} */ +#ifdef CUVS_BUILD_MG_ALGOS + /// \defgroup mg_cpp_index_build ANN MG index build /// \ingroup mg_cpp_index_build @@ -1959,6 +1964,8 @@ template auto distribute(const raft::device_resources_snmg& clique, const std::string& filename) -> cuvs::neighbors::mg::index, T, IdxT>; +#endif + namespace helpers { /** diff --git a/cpp/include/cuvs/neighbors/ivf_pq.hpp b/cpp/include/cuvs/neighbors/ivf_pq.hpp index 1be9a1924..1b4fc87fe 100644 --- a/cpp/include/cuvs/neighbors/ivf_pq.hpp +++ b/cpp/include/cuvs/neighbors/ivf_pq.hpp @@ -19,7 +19,6 @@ #include #include -#include #include #include @@ -28,6 +27,10 @@ #include #include +#ifdef CUVS_BUILD_MG_ALGOS +#include +#endif + namespace cuvs::neighbors::ivf_pq { /** @@ -1729,6 +1732,8 @@ void deserialize(raft::resources const& handle, * @} */ +#ifdef CUVS_BUILD_MG_ALGOS + /// \defgroup mg_cpp_index_build ANN MG index build /// \ingroup mg_cpp_index_build @@ -2179,6 +2184,9 @@ auto deserialize(const raft::device_resources_snmg& clique, const std::string& f template auto distribute(const raft::device_resources_snmg& clique, const std::string& filename) -> cuvs::neighbors::mg::index, T, IdxT>; + +#endif + namespace helpers { /** * @defgroup ivf_pq_cpp_helpers IVF-PQ helper methods