diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index c493af488..ce925713e 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -298,7 +298,6 @@ if(BUILD_SHARED_LIBS) src/neighbors/mg/mg_cagra_int8_t_uint32_t.cu src/neighbors/mg/mg_cagra_uint8_t_uint32_t.cu src/neighbors/mg/omp_checks.cpp - src/neighbors/mg/nccl_comm.cpp ) endif() diff --git a/cpp/bench/ann/src/cuvs/cuvs_mg_cagra_wrapper.h b/cpp/bench/ann/src/cuvs/cuvs_mg_cagra_wrapper.h index 50c1ff4db..6a6580f4f 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_mg_cagra_wrapper.h +++ b/cpp/bench/ann/src/cuvs/cuvs_mg_cagra_wrapper.h @@ -18,7 +18,7 @@ #include "cuvs_ann_bench_utils.h" #include "cuvs_cagra_wrapper.h" #include -#include +#include namespace cuvs::bench { using namespace cuvs::neighbors; @@ -41,13 +41,12 @@ class cuvs_mg_cagra : public algo, public algo_gpu { }; cuvs_mg_cagra(Metric metric, int dim, const build_param& param, int concurrent_searches = 1) - : algo(metric, dim), index_params_(param) + : algo(metric, dim), index_params_(param), clique_() { index_params_.cagra_params.metric = parse_metric_type(metric); index_params_.ivf_pq_build_params->metric = parse_metric_type(metric); - // init nccl clique outside as to not affect benchmark - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle_); + clique_.set_memory_pool(80); } void build(const T* dataset, size_t nrow) final; @@ -69,7 +68,7 @@ class cuvs_mg_cagra : public algo, public algo_gpu { [[nodiscard]] auto get_sync_stream() const noexcept -> cudaStream_t override { - auto stream = raft::resource::get_cuda_stream(handle_); + auto stream = raft::resource::get_cuda_stream(clique_); return stream; } @@ -87,7 +86,7 @@ class cuvs_mg_cagra : public algo, public algo_gpu { std::unique_ptr> copy() override; private: - raft::device_resources handle_; + raft::device_resources_snmg clique_; float refine_ratio_; build_param index_params_; cuvs::neighbors::mg::search_params search_params_; @@ -105,7 +104,7 @@ void cuvs_mg_cagra::build(const T* dataset, size_t nrow) auto dataset_view = raft::make_host_matrix_view(dataset, nrow, dim_); - auto idx = cuvs::neighbors::mg::build(handle_, build_params, dataset_view); + auto idx = cuvs::neighbors::mg::build(clique_, build_params, dataset_view); index_ = std::make_shared, T, IdxT>>( std::move(idx)); @@ -132,7 +131,7 @@ void cuvs_mg_cagra::set_search_dataset(const T* dataset, size_t nrow) template void cuvs_mg_cagra::save(const std::string& file) const { - cuvs::neighbors::mg::serialize(handle_, *index_, file); + cuvs::neighbors::mg::serialize(clique_, *index_, file); } template @@ -140,7 +139,7 @@ void cuvs_mg_cagra::load(const std::string& file) { index_ = std::make_shared, T, IdxT>>( - std::move(cuvs::neighbors::mg::deserialize_cagra(handle_, file))); + std::move(cuvs::neighbors::mg::deserialize_cagra(clique_, file))); } template @@ -164,7 +163,7 @@ void cuvs_mg_cagra::search_base( raft::make_host_matrix_view(distances, batch_size, k); cuvs::neighbors::mg::search( - handle_, *index_, search_params_, queries_view, neighbors_view, distances_view); + clique_, *index_, search_params_, queries_view, neighbors_view, distances_view); } template diff --git a/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_flat_wrapper.h b/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_flat_wrapper.h index 54a0d2fac..a2b91bc0a 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_flat_wrapper.h +++ b/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_flat_wrapper.h @@ -19,7 +19,7 @@ #include "cuvs_ann_bench_utils.h" #include "cuvs_ivf_flat_wrapper.h" #include -#include +#include namespace cuvs::bench { using namespace cuvs::neighbors; @@ -37,11 +37,11 @@ class cuvs_mg_ivf_flat : public algo, public algo_gpu { }; cuvs_mg_ivf_flat(Metric metric, int dim, const build_param& param) - : algo(metric, dim), index_params_(param) + : algo(metric, dim), index_params_(param), clique_() { index_params_.metric = parse_metric_type(metric); - // init nccl clique outside as to not affect benchmark - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle_); + + clique_.set_memory_pool(80); } void build(const T* dataset, size_t nrow) final; @@ -62,7 +62,7 @@ class cuvs_mg_ivf_flat : public algo, public algo_gpu { [[nodiscard]] auto get_sync_stream() const noexcept -> cudaStream_t override { - auto stream = raft::resource::get_cuda_stream(handle_); + auto stream = raft::resource::get_cuda_stream(clique_); return stream; } @@ -73,7 +73,7 @@ class cuvs_mg_ivf_flat : public algo, public algo_gpu { std::unique_ptr> copy() override; private: - raft::device_resources handle_; + raft::device_resources_snmg clique_; build_param index_params_; cuvs::neighbors::mg::search_params search_params_; std::shared_ptr, T, IdxT>> @@ -85,7 +85,7 @@ void cuvs_mg_ivf_flat::build(const T* dataset, size_t nrow) { auto dataset_view = raft::make_host_matrix_view(dataset, IdxT(nrow), IdxT(dim_)); - auto idx = cuvs::neighbors::mg::build(handle_, index_params_, dataset_view); + auto idx = cuvs::neighbors::mg::build(clique_, index_params_, dataset_view); index_ = std::make_shared< cuvs::neighbors::mg::index, T, IdxT>>(std::move(idx)); } @@ -105,7 +105,7 @@ void cuvs_mg_ivf_flat::set_search_param(const search_param_base& param) template void cuvs_mg_ivf_flat::save(const std::string& file) const { - cuvs::neighbors::mg::serialize(handle_, *index_, file); + cuvs::neighbors::mg::serialize(clique_, *index_, file); } template @@ -113,7 +113,7 @@ void cuvs_mg_ivf_flat::load(const std::string& file) { index_ = std::make_shared< cuvs::neighbors::mg::index, T, IdxT>>( - std::move(cuvs::neighbors::mg::deserialize_flat(handle_, file))); + std::move(cuvs::neighbors::mg::deserialize_flat(clique_, file))); } template @@ -134,7 +134,7 @@ void cuvs_mg_ivf_flat::search( distances, IdxT(batch_size), IdxT(k)); cuvs::neighbors::mg::search( - handle_, *index_, search_params_, queries_view, neighbors_view, distances_view); + clique_, *index_, search_params_, queries_view, neighbors_view, distances_view); } } // namespace cuvs::bench \ No newline at end of file diff --git a/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_pq_wrapper.h b/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_pq_wrapper.h index 84aea7d4a..c2ce61cd8 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_pq_wrapper.h +++ b/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_pq_wrapper.h @@ -19,7 +19,7 @@ #include "cuvs_ann_bench_utils.h" #include "cuvs_ivf_pq_wrapper.h" #include -#include +#include namespace cuvs::bench { using namespace cuvs::neighbors; @@ -37,11 +37,11 @@ class cuvs_mg_ivf_pq : public algo, public algo_gpu { }; cuvs_mg_ivf_pq(Metric metric, int dim, const build_param& param) - : algo(metric, dim), index_params_(param) + : algo(metric, dim), index_params_(param), clique_() { index_params_.metric = parse_metric_type(metric); - // init nccl clique outside as to not affect benchmark - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle_); + + clique_.set_memory_pool(80); } void build(const T* dataset, size_t nrow) final; @@ -62,7 +62,7 @@ class cuvs_mg_ivf_pq : public algo, public algo_gpu { [[nodiscard]] auto get_sync_stream() const noexcept -> cudaStream_t override { - auto stream = raft::resource::get_cuda_stream(handle_); + auto stream = raft::resource::get_cuda_stream(clique_); return stream; } @@ -73,7 +73,7 @@ class cuvs_mg_ivf_pq : public algo, public algo_gpu { std::unique_ptr> copy() override; private: - raft::device_resources handle_; + raft::device_resources_snmg clique_; build_param index_params_; cuvs::neighbors::mg::search_params search_params_; std::shared_ptr, T, IdxT>> index_; @@ -84,7 +84,7 @@ void cuvs_mg_ivf_pq::build(const T* dataset, size_t nrow) { auto dataset_view = raft::make_host_matrix_view(dataset, IdxT(nrow), IdxT(dim_)); - auto idx = cuvs::neighbors::mg::build(handle_, index_params_, dataset_view); + auto idx = cuvs::neighbors::mg::build(clique_, index_params_, dataset_view); index_ = std::make_shared, T, IdxT>>( std::move(idx)); @@ -104,7 +104,7 @@ void cuvs_mg_ivf_pq::set_search_param(const search_param_base& param) template void cuvs_mg_ivf_pq::save(const std::string& file) const { - cuvs::neighbors::mg::serialize(handle_, *index_, file); + cuvs::neighbors::mg::serialize(clique_, *index_, file); } template @@ -112,7 +112,7 @@ void cuvs_mg_ivf_pq::load(const std::string& file) { index_ = std::make_shared, T, IdxT>>( - std::move(cuvs::neighbors::mg::deserialize_pq(handle_, file))); + std::move(cuvs::neighbors::mg::deserialize_pq(clique_, file))); } template @@ -133,7 +133,7 @@ void cuvs_mg_ivf_pq::search( distances, IdxT(batch_size), IdxT(k)); cuvs::neighbors::mg::search( - handle_, *index_, search_params_, queries_view, neighbors_view, distances_view); + clique_, *index_, search_params_, queries_view, neighbors_view, distances_view); } } // namespace cuvs::bench \ No newline at end of file diff --git a/cpp/include/cuvs/neighbors/mg.hpp b/cpp/include/cuvs/neighbors/mg.hpp index 4657fa8fb..86572adeb 100644 --- a/cpp/include/cuvs/neighbors/mg.hpp +++ b/cpp/include/cuvs/neighbors/mg.hpp @@ -21,7 +21,7 @@ #include #include -#include +#include #include #include @@ -101,7 +101,7 @@ using namespace raft; template struct index { index(distribution_mode mode, int num_ranks_); - index(const raft::device_resources& handle, const std::string& filename); + index(const raft::device_resources_snmg& clique, const std::string& filename); index(const index&) = delete; index(index&&) = default; @@ -124,18 +124,18 @@ struct index { * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index_params configure the index building * @param[in] index_dataset a row-major matrix on host [n_rows, dim] * * @return the constructed IVF-Flat MG index */ -auto build(const raft::device_resources& handle, +auto build(const raft::device_resources_snmg& clique, const mg::index_params& index_params, raft::host_matrix_view index_dataset) -> index, float, int64_t>; @@ -146,18 +146,18 @@ auto build(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index_params configure the index building * @param[in] index_dataset a row-major matrix on host [n_rows, dim] * * @return the constructed IVF-Flat MG index */ -auto build(const raft::device_resources& handle, +auto build(const raft::device_resources_snmg& clique, const mg::index_params& index_params, raft::host_matrix_view index_dataset) -> index, int8_t, int64_t>; @@ -168,18 +168,18 @@ auto build(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index_params configure the index building * @param[in] index_dataset a row-major matrix on host [n_rows, dim] * * @return the constructed IVF-Flat MG index */ -auto build(const raft::device_resources& handle, +auto build(const raft::device_resources_snmg& clique, const mg::index_params& index_params, raft::host_matrix_view index_dataset) -> index, uint8_t, int64_t>; @@ -190,18 +190,18 @@ auto build(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index_params configure the index building * @param[in] index_dataset a row-major matrix on host [n_rows, dim] * * @return the constructed IVF-PQ MG index */ -auto build(const raft::device_resources& handle, +auto build(const raft::device_resources_snmg& clique, const mg::index_params& index_params, raft::host_matrix_view index_dataset) -> index, float, int64_t>; @@ -212,18 +212,18 @@ auto build(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index_params configure the index building * @param[in] index_dataset a row-major matrix on host [n_rows, dim] * * @return the constructed IVF-PQ MG index */ -auto build(const raft::device_resources& handle, +auto build(const raft::device_resources_snmg& clique, const mg::index_params& index_params, raft::host_matrix_view index_dataset) -> index, half, int64_t>; @@ -234,18 +234,18 @@ auto build(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index_params configure the index building * @param[in] index_dataset a row-major matrix on host [n_rows, dim] * * @return the constructed IVF-PQ MG index */ -auto build(const raft::device_resources& handle, +auto build(const raft::device_resources_snmg& clique, const mg::index_params& index_params, raft::host_matrix_view index_dataset) -> index, int8_t, int64_t>; @@ -256,18 +256,18 @@ auto build(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index_params configure the index building * @param[in] index_dataset a row-major matrix on host [n_rows, dim] * * @return the constructed IVF-PQ MG index */ -auto build(const raft::device_resources& handle, +auto build(const raft::device_resources_snmg& clique, const mg::index_params& index_params, raft::host_matrix_view index_dataset) -> index, uint8_t, int64_t>; @@ -278,18 +278,18 @@ auto build(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index_params configure the index building * @param[in] index_dataset a row-major matrix on host [n_rows, dim] * * @return the constructed CAGRA MG index */ -auto build(const raft::device_resources& handle, +auto build(const raft::device_resources_snmg& clique, const mg::index_params& index_params, raft::host_matrix_view index_dataset) -> index, float, uint32_t>; @@ -300,18 +300,18 @@ auto build(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index_params configure the index building * @param[in] index_dataset a row-major matrix on host [n_rows, dim] * * @return the constructed CAGRA MG index */ -auto build(const raft::device_resources& handle, +auto build(const raft::device_resources_snmg& clique, const mg::index_params& index_params, raft::host_matrix_view index_dataset) -> index, half, uint32_t>; @@ -322,18 +322,18 @@ auto build(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index_params configure the index building * @param[in] index_dataset a row-major matrix on host [n_rows, dim] * * @return the constructed CAGRA MG index */ -auto build(const raft::device_resources& handle, +auto build(const raft::device_resources_snmg& clique, const mg::index_params& index_params, raft::host_matrix_view index_dataset) -> index, int8_t, uint32_t>; @@ -344,18 +344,18 @@ auto build(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index_params configure the index building * @param[in] index_dataset a row-major matrix on host [n_rows, dim] * * @return the constructed CAGRA MG index */ -auto build(const raft::device_resources& handle, +auto build(const raft::device_resources_snmg& clique, const mg::index_params& index_params, raft::host_matrix_view index_dataset) -> index, uint8_t, uint32_t>; @@ -368,20 +368,20 @@ auto build(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); - * cuvs::neighbors::mg::extend(handle, index, new_vectors, std::nullopt); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); + * cuvs::neighbors::mg::extend(clique, index, new_vectors, std::nullopt); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] new_vectors a row-major matrix on host [n_rows, dim] * @param[in] new_indices optional vector on host [n_rows], * `std::nullopt` means default continuous range `[0...n_rows)` * */ -void extend(const raft::device_resources& handle, +void extend(const raft::device_resources_snmg& clique, index, float, int64_t>& index, raft::host_matrix_view new_vectors, std::optional> new_indices); @@ -392,20 +392,20 @@ void extend(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); - * cuvs::neighbors::mg::extend(handle, index, new_vectors, std::nullopt); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); + * cuvs::neighbors::mg::extend(clique, index, new_vectors, std::nullopt); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] new_vectors a row-major matrix on host [n_rows, dim] * @param[in] new_indices optional vector on host [n_rows], * `std::nullopt` means default continuous range `[0...n_rows)` * */ -void extend(const raft::device_resources& handle, +void extend(const raft::device_resources_snmg& clique, index, int8_t, int64_t>& index, raft::host_matrix_view new_vectors, std::optional> new_indices); @@ -416,20 +416,20 @@ void extend(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); - * cuvs::neighbors::mg::extend(handle, index, new_vectors, std::nullopt); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); + * cuvs::neighbors::mg::extend(clique, index, new_vectors, std::nullopt); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] new_vectors a row-major matrix on host [n_rows, dim] * @param[in] new_indices optional vector on host [n_rows], * `std::nullopt` means default continuous range `[0...n_rows)` * */ -void extend(const raft::device_resources& handle, +void extend(const raft::device_resources_snmg& clique, index, uint8_t, int64_t>& index, raft::host_matrix_view new_vectors, std::optional> new_indices); @@ -440,20 +440,20 @@ void extend(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); - * cuvs::neighbors::mg::extend(handle, index, new_vectors, std::nullopt); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); + * cuvs::neighbors::mg::extend(clique, index, new_vectors, std::nullopt); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] new_vectors a row-major matrix on host [n_rows, dim] * @param[in] new_indices optional vector on host [n_rows], * `std::nullopt` means default continuous range `[0...n_rows)` * */ -void extend(const raft::device_resources& handle, +void extend(const raft::device_resources_snmg& clique, index, float, int64_t>& index, raft::host_matrix_view new_vectors, std::optional> new_indices); @@ -464,20 +464,20 @@ void extend(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); - * cuvs::neighbors::mg::extend(handle, index, new_vectors, std::nullopt); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); + * cuvs::neighbors::mg::extend(clique, index, new_vectors, std::nullopt); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] new_vectors a row-major matrix on host [n_rows, dim] * @param[in] new_indices optional vector on host [n_rows], * `std::nullopt` means default continuous range `[0...n_rows)` * */ -void extend(const raft::device_resources& handle, +void extend(const raft::device_resources_snmg& clique, index, half, int64_t>& index, raft::host_matrix_view new_vectors, std::optional> new_indices); @@ -488,20 +488,20 @@ void extend(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); - * cuvs::neighbors::mg::extend(handle, index, new_vectors, std::nullopt); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); + * cuvs::neighbors::mg::extend(clique, index, new_vectors, std::nullopt); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] new_vectors a row-major matrix on host [n_rows, dim] * @param[in] new_indices optional vector on host [n_rows], * `std::nullopt` means default continuous range `[0...n_rows)` * */ -void extend(const raft::device_resources& handle, +void extend(const raft::device_resources_snmg& clique, index, int8_t, int64_t>& index, raft::host_matrix_view new_vectors, std::optional> new_indices); @@ -512,20 +512,20 @@ void extend(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); - * cuvs::neighbors::mg::extend(handle, index, new_vectors, std::nullopt); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); + * cuvs::neighbors::mg::extend(clique, index, new_vectors, std::nullopt); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] new_vectors a row-major matrix on host [n_rows, dim] * @param[in] new_indices optional vector on host [n_rows], * `std::nullopt` means default continuous range `[0...n_rows)` * */ -void extend(const raft::device_resources& handle, +void extend(const raft::device_resources_snmg& clique, index, uint8_t, int64_t>& index, raft::host_matrix_view new_vectors, std::optional> new_indices); @@ -536,20 +536,20 @@ void extend(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); - * cuvs::neighbors::mg::extend(handle, index, new_vectors, std::nullopt); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); + * cuvs::neighbors::mg::extend(clique, index, new_vectors, std::nullopt); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] new_vectors a row-major matrix on host [n_rows, dim] * @param[in] new_indices optional vector on host [n_rows], * `std::nullopt` means default continuous range `[0...n_rows)` * */ -void extend(const raft::device_resources& handle, +void extend(const raft::device_resources_snmg& clique, index, float, uint32_t>& index, raft::host_matrix_view new_vectors, std::optional> new_indices); @@ -560,20 +560,20 @@ void extend(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); - * cuvs::neighbors::mg::extend(handle, index, new_vectors, std::nullopt); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); + * cuvs::neighbors::mg::extend(clique, index, new_vectors, std::nullopt); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] new_vectors a row-major matrix on host [n_rows, dim] * @param[in] new_indices optional vector on host [n_rows], * `std::nullopt` means default continuous range `[0...n_rows)` * */ -void extend(const raft::device_resources& handle, +void extend(const raft::device_resources_snmg& clique, index, half, uint32_t>& index, raft::host_matrix_view new_vectors, std::optional> new_indices); @@ -584,20 +584,20 @@ void extend(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); - * cuvs::neighbors::mg::extend(handle, index, new_vectors, std::nullopt); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); + * cuvs::neighbors::mg::extend(clique, index, new_vectors, std::nullopt); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] new_vectors a row-major matrix on host [n_rows, dim] * @param[in] new_indices optional vector on host [n_rows], * `std::nullopt` means default continuous range `[0...n_rows)` * */ -void extend(const raft::device_resources& handle, +void extend(const raft::device_resources_snmg& clique, index, int8_t, uint32_t>& index, raft::host_matrix_view new_vectors, std::optional> new_indices); @@ -608,20 +608,20 @@ void extend(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); - * cuvs::neighbors::mg::extend(handle, index, new_vectors, std::nullopt); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); + * cuvs::neighbors::mg::extend(clique, index, new_vectors, std::nullopt); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] new_vectors a row-major matrix on host [n_rows, dim] * @param[in] new_indices optional vector on host [n_rows], * `std::nullopt` means default continuous range `[0...n_rows)` * */ -void extend(const raft::device_resources& handle, +void extend(const raft::device_resources_snmg& clique, index, uint8_t, uint32_t>& index, raft::host_matrix_view new_vectors, std::optional> new_indices); @@ -634,15 +634,15 @@ void extend(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * cuvs::neighbors::mg::search_params search_params; - * cuvs::neighbors::mg::search(handle, index, search_params, queries, neighbors, + * cuvs::neighbors::mg::search(clique, index, search_params, queries, neighbors, * distances); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] search_params configure the index search * @param[in] queries a row-major matrix on host [n_rows, dim] @@ -651,7 +651,7 @@ void extend(const raft::device_resources& handle, * @param[in] n_rows_per_batch (optional) search batch size * */ -void search(const raft::device_resources& handle, +void search(const raft::device_resources_snmg& clique, const index, float, int64_t>& index, const mg::search_params& search_params, raft::host_matrix_view queries, @@ -665,15 +665,15 @@ void search(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * cuvs::neighbors::mg::search_params search_params; - * cuvs::neighbors::mg::search(handle, index, search_params, queries, neighbors, + * cuvs::neighbors::mg::search(clique, index, search_params, queries, neighbors, * distances); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] search_params configure the index search * @param[in] queries a row-major matrix on host [n_rows, dim] @@ -682,7 +682,7 @@ void search(const raft::device_resources& handle, * @param[in] n_rows_per_batch (optional) search batch size * */ -void search(const raft::device_resources& handle, +void search(const raft::device_resources_snmg& clique, const index, int8_t, int64_t>& index, const mg::search_params& search_params, raft::host_matrix_view queries, @@ -696,15 +696,15 @@ void search(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * cuvs::neighbors::mg::search_params search_params; - * cuvs::neighbors::mg::search(handle, index, search_params, queries, neighbors, + * cuvs::neighbors::mg::search(clique, index, search_params, queries, neighbors, * distances); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] search_params configure the index search * @param[in] queries a row-major matrix on host [n_rows, dim] @@ -713,7 +713,7 @@ void search(const raft::device_resources& handle, * @param[in] n_rows_per_batch (optional) search batch size * */ -void search(const raft::device_resources& handle, +void search(const raft::device_resources_snmg& clique, const index, uint8_t, int64_t>& index, const mg::search_params& search_params, raft::host_matrix_view queries, @@ -727,15 +727,15 @@ void search(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * cuvs::neighbors::mg::search_params search_params; - * cuvs::neighbors::mg::search(handle, index, search_params, queries, neighbors, + * cuvs::neighbors::mg::search(clique, index, search_params, queries, neighbors, * distances); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] search_params configure the index search * @param[in] queries a row-major matrix on host [n_rows, dim] @@ -744,7 +744,7 @@ void search(const raft::device_resources& handle, * @param[in] n_rows_per_batch (optional) search batch size * */ -void search(const raft::device_resources& handle, +void search(const raft::device_resources_snmg& clique, const index, float, int64_t>& index, const mg::search_params& search_params, raft::host_matrix_view queries, @@ -758,15 +758,15 @@ void search(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * cuvs::neighbors::mg::search_params search_params; - * cuvs::neighbors::mg::search(handle, index, search_params, queries, neighbors, + * cuvs::neighbors::mg::search(clique, index, search_params, queries, neighbors, * distances); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] search_params configure the index search * @param[in] queries a row-major matrix on host [n_rows, dim] @@ -775,7 +775,7 @@ void search(const raft::device_resources& handle, * @param[in] n_rows_per_batch (optional) search batch size * */ -void search(const raft::device_resources& handle, +void search(const raft::device_resources_snmg& clique, const index, half, int64_t>& index, const mg::search_params& search_params, raft::host_matrix_view queries, @@ -789,15 +789,15 @@ void search(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * cuvs::neighbors::mg::search_params search_params; - * cuvs::neighbors::mg::search(handle, index, search_params, queries, neighbors, + * cuvs::neighbors::mg::search(clique, index, search_params, queries, neighbors, * distances); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] search_params configure the index search * @param[in] queries a row-major matrix on host [n_rows, dim] @@ -806,7 +806,7 @@ void search(const raft::device_resources& handle, * @param[in] n_rows_per_batch (optional) search batch size * */ -void search(const raft::device_resources& handle, +void search(const raft::device_resources_snmg& clique, const index, int8_t, int64_t>& index, const mg::search_params& search_params, raft::host_matrix_view queries, @@ -820,15 +820,15 @@ void search(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * cuvs::neighbors::mg::search_params search_params; - * cuvs::neighbors::mg::search(handle, index, search_params, queries, neighbors, + * cuvs::neighbors::mg::search(clique, index, search_params, queries, neighbors, * distances); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] search_params configure the index search * @param[in] queries a row-major matrix on host [n_rows, dim] @@ -837,7 +837,7 @@ void search(const raft::device_resources& handle, * @param[in] n_rows_per_batch (optional) search batch size * */ -void search(const raft::device_resources& handle, +void search(const raft::device_resources_snmg& clique, const index, uint8_t, int64_t>& index, const mg::search_params& search_params, raft::host_matrix_view queries, @@ -851,15 +851,15 @@ void search(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * cuvs::neighbors::mg::search_params search_params; - * cuvs::neighbors::mg::search(handle, index, search_params, queries, neighbors, + * cuvs::neighbors::mg::search(clique, index, search_params, queries, neighbors, * distances); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] search_params configure the index search * @param[in] queries a row-major matrix on host [n_rows, dim] @@ -868,7 +868,7 @@ void search(const raft::device_resources& handle, * @param[in] n_rows_per_batch (optional) search batch size * */ -void search(const raft::device_resources& handle, +void search(const raft::device_resources_snmg& clique, const index, float, uint32_t>& index, const mg::search_params& search_params, raft::host_matrix_view queries, @@ -882,15 +882,15 @@ void search(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * cuvs::neighbors::mg::search_params search_params; - * cuvs::neighbors::mg::search(handle, index, search_params, queries, neighbors, + * cuvs::neighbors::mg::search(clique, index, search_params, queries, neighbors, * distances); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] search_params configure the index search * @param[in] queries a row-major matrix on host [n_rows, dim] @@ -899,7 +899,7 @@ void search(const raft::device_resources& handle, * @param[in] n_rows_per_batch (optional) search batch size * */ -void search(const raft::device_resources& handle, +void search(const raft::device_resources_snmg& clique, const index, half, uint32_t>& index, const mg::search_params& search_params, raft::host_matrix_view queries, @@ -913,15 +913,15 @@ void search(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * cuvs::neighbors::mg::search_params search_params; - * cuvs::neighbors::mg::search(handle, index, search_params, queries, neighbors, + * cuvs::neighbors::mg::search(clique, index, search_params, queries, neighbors, * distances); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] search_params configure the index search * @param[in] queries a row-major matrix on host [n_rows, dim] @@ -930,7 +930,7 @@ void search(const raft::device_resources& handle, * @param[in] n_rows_per_batch (optional) search batch size * */ -void search(const raft::device_resources& handle, +void search(const raft::device_resources_snmg& clique, const index, int8_t, uint32_t>& index, const mg::search_params& search_params, raft::host_matrix_view queries, @@ -944,15 +944,15 @@ void search(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * cuvs::neighbors::mg::search_params search_params; - * cuvs::neighbors::mg::search(handle, index, search_params, queries, neighbors, + * cuvs::neighbors::mg::search(clique, index, search_params, queries, neighbors, * distances); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] search_params configure the index search * @param[in] queries a row-major matrix on host [n_rows, dim] @@ -961,7 +961,7 @@ void search(const raft::device_resources& handle, * @param[in] n_rows_per_batch (optional) search batch size * */ -void search(const raft::device_resources& handle, +void search(const raft::device_resources_snmg& clique, const index, uint8_t, uint32_t>& index, const mg::search_params& search_params, raft::host_matrix_view queries, @@ -977,19 +977,19 @@ void search(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(handle, index, filename); + * cuvs::neighbors::mg::serialize(clique, index, filename); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] filename path to the file to be serialized * */ -void serialize(const raft::device_resources& handle, +void serialize(const raft::device_resources_snmg& clique, const index, float, int64_t>& index, const std::string& filename); @@ -999,19 +999,19 @@ void serialize(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(handle, index, filename); + * cuvs::neighbors::mg::serialize(clique, index, filename); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] filename path to the file to be serialized * */ -void serialize(const raft::device_resources& handle, +void serialize(const raft::device_resources_snmg& clique, const index, int8_t, int64_t>& index, const std::string& filename); @@ -1021,19 +1021,19 @@ void serialize(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(handle, index, filename); + * cuvs::neighbors::mg::serialize(clique, index, filename); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] filename path to the file to be serialized * */ -void serialize(const raft::device_resources& handle, +void serialize(const raft::device_resources_snmg& clique, const index, uint8_t, int64_t>& index, const std::string& filename); @@ -1043,19 +1043,19 @@ void serialize(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(handle, index, filename); + * cuvs::neighbors::mg::serialize(clique, index, filename); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] filename path to the file to be serialized * */ -void serialize(const raft::device_resources& handle, +void serialize(const raft::device_resources_snmg& clique, const index, float, int64_t>& index, const std::string& filename); @@ -1065,19 +1065,19 @@ void serialize(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(handle, index, filename); + * cuvs::neighbors::mg::serialize(clique, index, filename); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] filename path to the file to be serialized * */ -void serialize(const raft::device_resources& handle, +void serialize(const raft::device_resources_snmg& clique, const index, half, int64_t>& index, const std::string& filename); @@ -1087,19 +1087,19 @@ void serialize(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(handle, index, filename); + * cuvs::neighbors::mg::serialize(clique, index, filename); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] filename path to the file to be serialized * */ -void serialize(const raft::device_resources& handle, +void serialize(const raft::device_resources_snmg& clique, const index, int8_t, int64_t>& index, const std::string& filename); @@ -1109,19 +1109,19 @@ void serialize(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(handle, index, filename); + * cuvs::neighbors::mg::serialize(clique, index, filename); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] filename path to the file to be serialized * */ -void serialize(const raft::device_resources& handle, +void serialize(const raft::device_resources_snmg& clique, const index, uint8_t, int64_t>& index, const std::string& filename); @@ -1131,19 +1131,19 @@ void serialize(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(handle, index, filename); + * cuvs::neighbors::mg::serialize(clique, index, filename); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] filename path to the file to be serialized * */ -void serialize(const raft::device_resources& handle, +void serialize(const raft::device_resources_snmg& clique, const index, float, uint32_t>& index, const std::string& filename); @@ -1153,19 +1153,19 @@ void serialize(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(handle, index, filename); + * cuvs::neighbors::mg::serialize(clique, index, filename); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] filename path to the file to be serialized * */ -void serialize(const raft::device_resources& handle, +void serialize(const raft::device_resources_snmg& clique, const index, half, uint32_t>& index, const std::string& filename); @@ -1175,19 +1175,19 @@ void serialize(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(handle, index, filename); + * cuvs::neighbors::mg::serialize(clique, index, filename); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] filename path to the file to be serialized * */ -void serialize(const raft::device_resources& handle, +void serialize(const raft::device_resources_snmg& clique, const index, int8_t, uint32_t>& index, const std::string& filename); @@ -1197,19 +1197,19 @@ void serialize(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(handle, index, filename); + * cuvs::neighbors::mg::serialize(clique, index, filename); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] index the pre-built index * @param[in] filename path to the file to be serialized * */ -void serialize(const raft::device_resources& handle, +void serialize(const raft::device_resources_snmg& clique, const index, uint8_t, uint32_t>& index, const std::string& filename); @@ -1221,21 +1221,21 @@ void serialize(const raft::device_resources& handle, * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(handle, index, filename); - * auto new_index = cuvs::neighbors::mg::deserialize_flat(handle, filename); + * cuvs::neighbors::mg::serialize(clique, index, filename); + * auto new_index = cuvs::neighbors::mg::deserialize_flat(clique, filename); * * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] filename path to the file to be deserialized * */ template -auto deserialize_flat(const raft::device_resources& handle, const std::string& filename) +auto deserialize_flat(const raft::device_resources_snmg& clique, const std::string& filename) -> index, T, IdxT>; /// \ingroup mg_cpp_deserialize @@ -1244,20 +1244,20 @@ auto deserialize_flat(const raft::device_resources& handle, const std::string& f * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(handle, index, filename); - * auto new_index = cuvs::neighbors::mg::deserialize_pq(handle, filename); + * cuvs::neighbors::mg::serialize(clique, index, filename); + * auto new_index = cuvs::neighbors::mg::deserialize_pq(clique, filename); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] filename path to the file to be deserialized * */ template -auto deserialize_pq(const raft::device_resources& handle, const std::string& filename) +auto deserialize_pq(const raft::device_resources_snmg& clique, const std::string& filename) -> index, T, IdxT>; /// \ingroup mg_cpp_deserialize @@ -1266,21 +1266,21 @@ auto deserialize_pq(const raft::device_resources& handle, const std::string& fil * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::mg::index_params index_params; - * auto index = cuvs::neighbors::mg::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::mg::build(clique, index_params, index_dataset); * const std::string filename = "mg_index.cuvs"; - * cuvs::neighbors::mg::serialize(handle, index, filename); - * auto new_index = cuvs::neighbors::mg::deserialize_cagra(handle, filename); + * cuvs::neighbors::mg::serialize(clique, index, filename); + * auto new_index = cuvs::neighbors::mg::deserialize_cagra(clique, filename); * * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] filename path to the file to be deserialized * */ template -auto deserialize_cagra(const raft::device_resources& handle, const std::string& filename) +auto deserialize_cagra(const raft::device_resources_snmg& clique, const std::string& filename) -> index, T, IdxT>; /// \defgroup mg_cpp_distribute ANN MG local index distribution @@ -1292,21 +1292,21 @@ auto deserialize_cagra(const raft::device_resources& handle, const std::string& * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::ivf_flat::index_params index_params; - * auto index = cuvs::neighbors::ivf_flat::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::ivf_flat::build(clique, index_params, index_dataset); * const std::string filename = "local_index.cuvs"; - * cuvs::neighbors::ivf_flat::serialize(handle, filename, index); - * auto new_index = cuvs::neighbors::mg::distribute_flat(handle, filename); + * cuvs::neighbors::ivf_flat::serialize(clique, filename, index); + * auto new_index = cuvs::neighbors::mg::distribute_flat(clique, filename); * * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] filename path to the file to be deserialized : a local index * */ template -auto distribute_flat(const raft::device_resources& handle, const std::string& filename) +auto distribute_flat(const raft::device_resources_snmg& clique, const std::string& filename) -> index, T, IdxT>; /// \ingroup mg_cpp_distribute @@ -1316,20 +1316,20 @@ auto distribute_flat(const raft::device_resources& handle, const std::string& fi * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::ivf_pq::index_params index_params; - * auto index = cuvs::neighbors::ivf_pq::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::ivf_pq::build(clique, index_params, index_dataset); * const std::string filename = "local_index.cuvs"; - * cuvs::neighbors::ivf_pq::serialize(handle, filename, index); - * auto new_index = cuvs::neighbors::mg::distribute_pq(handle, filename); + * cuvs::neighbors::ivf_pq::serialize(clique, filename, index); + * auto new_index = cuvs::neighbors::mg::distribute_pq(clique, filename); * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] filename path to the file to be deserialized : a local index * */ template -auto distribute_pq(const raft::device_resources& handle, const std::string& filename) +auto distribute_pq(const raft::device_resources_snmg& clique, const std::string& filename) -> index, T, IdxT>; /// \ingroup mg_cpp_distribute @@ -1339,21 +1339,21 @@ auto distribute_pq(const raft::device_resources& handle, const std::string& file * * Usage example: * @code{.cpp} - * raft::handle_t handle; + * raft::device_resources_snmg clique; * cuvs::neighbors::cagra::index_params index_params; - * auto index = cuvs::neighbors::cagra::build(handle, index_params, index_dataset); + * auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset); * const std::string filename = "local_index.cuvs"; - * cuvs::neighbors::cagra::serialize(handle, filename, index); - * auto new_index = cuvs::neighbors::mg::distribute_cagra(handle, filename); + * cuvs::neighbors::cagra::serialize(clique, filename, index); + * auto new_index = cuvs::neighbors::mg::distribute_cagra(clique, filename); * * @endcode * - * @param[in] handle + * @param[in] clique * @param[in] filename path to the file to be deserialized : a local index * */ template -auto distribute_cagra(const raft::device_resources& handle, const std::string& filename) +auto distribute_cagra(const raft::device_resources_snmg& clique, const std::string& filename) -> index, T, IdxT>; } // namespace cuvs::neighbors::mg diff --git a/cpp/src/neighbors/mg/generate_mg.py b/cpp/src/neighbors/mg/generate_mg.py index af5e60545..26e81da16 100644 --- a/cpp/src/neighbors/mg/generate_mg.py +++ b/cpp/src/neighbors/mg/generate_mg.py @@ -53,27 +53,26 @@ flat_macro = """ #define CUVS_INST_MG_FLAT(T, IdxT) \\ - index, T, IdxT> build(const raft::device_resources& handle, \\ + index, T, IdxT> build(const raft::device_resources_snmg& clique, \\ const mg::index_params& index_params, \\ raft::host_matrix_view index_dataset) \\ { \\ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \\ - index, T, IdxT> index(index_params.mode, clique.num_ranks_); \\ - cuvs::neighbors::mg::detail::build(handle, index, \\ + index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \\ + cuvs::neighbors::mg::detail::build(clique, index, \\ static_cast(&index_params), \\ index_dataset); \\ return index; \\ } \\ \\ - void extend(const raft::device_resources& handle, \\ + void extend(const raft::device_resources_snmg& clique, \\ index, T, IdxT>& index, \\ raft::host_matrix_view new_vectors, \\ std::optional> new_indices) \\ { \\ - cuvs::neighbors::mg::detail::extend(handle, index, new_vectors, new_indices); \\ + cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \\ } \\ \\ - void search(const raft::device_resources& handle, \\ + void search(const raft::device_resources_snmg& clique, \\ const index, T, IdxT>& index, \\ const mg::search_params& search_params, \\ raft::host_matrix_view queries, \\ @@ -81,60 +80,58 @@ raft::host_matrix_view distances, \\ int64_t n_rows_per_batch) \\ { \\ - cuvs::neighbors::mg::detail::search(handle, index, \\ + cuvs::neighbors::mg::detail::search(clique, index, \\ static_cast(&search_params), \\ queries, neighbors, distances, n_rows_per_batch); \\ } \\ \\ - void serialize(const raft::device_resources& handle, \\ + void serialize(const raft::device_resources_snmg& clique, \\ const index, T, IdxT>& index, \\ const std::string& filename) \\ { \\ - cuvs::neighbors::mg::detail::serialize(handle, index, filename); \\ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \\ } \\ \\ template<> \\ - index, T, IdxT> deserialize_flat(const raft::device_resources& handle, \\ + index, T, IdxT> deserialize_flat(const raft::device_resources_snmg& clique, \\ const std::string& filename) \\ { \\ - auto idx = index, T, IdxT>(handle, filename); \\ + auto idx = index, T, IdxT>(clique, filename); \\ return idx; \\ } \\ \\ template<> \\ - index, T, IdxT> distribute_flat(const raft::device_resources& handle, \\ + index, T, IdxT> distribute_flat(const raft::device_resources_snmg& clique, \\ const std::string& filename) \\ { \\ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \\ - auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \\ - cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \\ + auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \\ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \\ return idx; \\ } """ pq_macro = """ #define CUVS_INST_MG_PQ(T, IdxT) \\ - index, T, IdxT> build(const raft::device_resources& handle, \\ + index, T, IdxT> build(const raft::device_resources_snmg& clique, \\ const mg::index_params& index_params, \\ raft::host_matrix_view index_dataset) \\ { \\ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \\ - index, T, IdxT> index(index_params.mode, clique.num_ranks_); \\ - cuvs::neighbors::mg::detail::build(handle, index, \\ + index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \\ + cuvs::neighbors::mg::detail::build(clique, index, \\ static_cast(&index_params), \\ index_dataset); \\ return index; \\ } \\ \\ - void extend(const raft::device_resources& handle, \\ + void extend(const raft::device_resources_snmg& clique, \\ index, T, IdxT>& index, \\ raft::host_matrix_view new_vectors, \\ std::optional> new_indices) \\ { \\ - cuvs::neighbors::mg::detail::extend(handle, index, new_vectors, new_indices); \\ + cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \\ } \\ \\ - void search(const raft::device_resources& handle, \\ + void search(const raft::device_resources_snmg& clique, \\ const index, T, IdxT>& index, \\ const mg::search_params& search_params, \\ raft::host_matrix_view queries, \\ @@ -142,52 +139,50 @@ raft::host_matrix_view distances, \\ int64_t n_rows_per_batch) \\ { \\ - cuvs::neighbors::mg::detail::search(handle, index, \\ + cuvs::neighbors::mg::detail::search(clique, index, \\ static_cast(&search_params), \\ queries, neighbors, distances, n_rows_per_batch); \\ } \\ \\ - void serialize(const raft::device_resources& handle, \\ + void serialize(const raft::device_resources_snmg& clique, \\ const index, T, IdxT>& index, \\ const std::string& filename) \\ { \\ - cuvs::neighbors::mg::detail::serialize(handle, index, filename); \\ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \\ } \\ \\ template<> \\ - index, T, IdxT> deserialize_pq(const raft::device_resources& handle, \\ + index, T, IdxT> deserialize_pq(const raft::device_resources_snmg& clique, \\ const std::string& filename) \\ { \\ - auto idx = index, T, IdxT>(handle, filename); \\ + auto idx = index, T, IdxT>(clique, filename); \\ return idx; \\ } \\ \\ template<> \\ - index, T, IdxT> distribute_pq(const raft::device_resources& handle, \\ + index, T, IdxT> distribute_pq(const raft::device_resources_snmg& clique, \\ const std::string& filename) \\ { \\ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \\ - auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \\ - cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \\ + auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \\ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \\ return idx; \\ } """ cagra_macro = """ #define CUVS_INST_MG_CAGRA(T, IdxT) \\ - index, T, IdxT> build(const raft::device_resources& handle, \\ + index, T, IdxT> build(const raft::device_resources_snmg& clique, \\ const mg::index_params& index_params, \\ raft::host_matrix_view index_dataset) \\ { \\ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \\ - index, T, IdxT> index(index_params.mode, clique.num_ranks_); \\ - cuvs::neighbors::mg::detail::build(handle, index, \\ + index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \\ + cuvs::neighbors::mg::detail::build(clique, index, \\ static_cast(&index_params), \\ index_dataset); \\ return index; \\ } \\ \\ - void search(const raft::device_resources& handle, \\ + void search(const raft::device_resources_snmg& clique, \\ const index, T, IdxT>& index, \\ const mg::search_params& search_params, \\ raft::host_matrix_view queries, \\ @@ -195,33 +190,32 @@ raft::host_matrix_view distances, \\ int64_t n_rows_per_batch) \\ { \\ - cuvs::neighbors::mg::detail::search(handle, index, \\ + cuvs::neighbors::mg::detail::search(clique, index, \\ static_cast(&search_params), \\ queries, neighbors, distances, n_rows_per_batch); \\ } \\ \\ - void serialize(const raft::device_resources& handle, \\ + void serialize(const raft::device_resources_snmg& clique, \\ const index, T, IdxT>& index, \\ const std::string& filename) \\ { \\ - cuvs::neighbors::mg::detail::serialize(handle, index, filename); \\ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \\ } \\ \\ template<> \\ - index, T, IdxT> deserialize_cagra(const raft::device_resources& handle, \\ + index, T, IdxT> deserialize_cagra(const raft::device_resources_snmg& clique, \\ const std::string& filename) \\ { \\ - auto idx = index, T, IdxT>(handle, filename); \\ + auto idx = index, T, IdxT>(clique, filename); \\ return idx; \\ } \\ \\ template<> \\ - index, T, IdxT> distribute_cagra(const raft::device_resources& handle, \\ + index, T, IdxT> distribute_cagra(const raft::device_resources_snmg& clique, \\ const std::string& filename) \\ { \\ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \\ - auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \\ - cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \\ + auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \\ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \\ return idx; \\ } """ diff --git a/cpp/src/neighbors/mg/mg.cuh b/cpp/src/neighbors/mg/mg.cuh index e9cdc30f6..14ffbce93 100644 --- a/cpp/src/neighbors/mg/mg.cuh +++ b/cpp/src/neighbors/mg/mg.cuh @@ -17,7 +17,6 @@ #pragma once #include "../detail/knn_merge_parts.cuh" -#include #include #include #include @@ -49,45 +48,39 @@ using namespace raft; // local index deserialization and distribution template -void deserialize_and_distribute(const raft::device_resources& handle, +void deserialize_and_distribute(const raft::device_resources_snmg& clique, index& index, const std::string& filename) { - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); for (int rank = 0; rank < index.num_ranks_; rank++) { - int dev_id = clique.device_ids_[rank]; - const raft::device_resources& dev_res = clique.device_resources_[rank]; - RAFT_CUDA_TRY(cudaSetDevice(dev_id)); - auto& ann_if = index.ann_interfaces_.emplace_back(); + const raft::device_resources& dev_res = clique.set_current_device_to_rank(rank); + auto& ann_if = index.ann_interfaces_.emplace_back(); cuvs::neighbors::deserialize(dev_res, ann_if, filename); } } // MG index deserialization template -void deserialize(const raft::device_resources& handle, +void deserialize(const raft::device_resources_snmg& clique, index& index, const std::string& filename) { std::ifstream is(filename, std::ios::in | std::ios::binary); if (!is) { RAFT_FAIL("Cannot open file %s", filename.c_str()); } - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); + const auto& handle = clique.set_current_device_to_root_rank(); + index.mode_ = (cuvs::neighbors::mg::distribution_mode)deserialize_scalar(handle, is); + index.num_ranks_ = deserialize_scalar(handle, is); - index.mode_ = (cuvs::neighbors::mg::distribution_mode)deserialize_scalar(handle, is); - index.num_ranks_ = deserialize_scalar(handle, is); - - if (index.num_ranks_ != clique.num_ranks_) { + if (index.num_ranks_ != clique.get_num_ranks()) { RAFT_FAIL("Serialized index has %d ranks whereas NCCL clique has %d ranks", index.num_ranks_, - clique.num_ranks_); + clique.get_num_ranks()); } for (int rank = 0; rank < index.num_ranks_; rank++) { - int dev_id = clique.device_ids_[rank]; - const raft::device_resources& dev_res = clique.device_resources_[rank]; - RAFT_CUDA_TRY(cudaSetDevice(dev_id)); - auto& ann_if = index.ann_interfaces_.emplace_back(); + const raft::device_resources& dev_res = clique.set_current_device_to_rank(rank); + auto& ann_if = index.ann_interfaces_.emplace_back(); cuvs::neighbors::deserialize(dev_res, ann_if, is); } @@ -95,24 +88,20 @@ void deserialize(const raft::device_resources& handle, } template -void build(const raft::device_resources& handle, +void build(const raft::device_resources_snmg& clique, index& index, const cuvs::neighbors::index_params* index_params, raft::host_matrix_view index_dataset) { - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); - if (index.mode_ == REPLICATED) { int64_t n_rows = index_dataset.extent(0); - RAFT_LOG_INFO("REPLICATED BUILD: %d*%drows", index.num_ranks_, n_rows); + RAFT_LOG_DEBUG("REPLICATED BUILD: %d*%drows", index.num_ranks_, n_rows); index.ann_interfaces_.resize(index.num_ranks_); #pragma omp parallel for for (int rank = 0; rank < index.num_ranks_; rank++) { - int dev_id = clique.device_ids_[rank]; - const raft::device_resources& dev_res = clique.device_resources_[rank]; - RAFT_CUDA_TRY(cudaSetDevice(dev_id)); - auto& ann_if = index.ann_interfaces_[rank]; + const raft::device_resources& dev_res = clique.set_current_device_to_rank(rank); + auto& ann_if = index.ann_interfaces_[rank]; cuvs::neighbors::build(dev_res, ann_if, index_params, index_dataset); resource::sync_stream(dev_res); } @@ -121,18 +110,16 @@ void build(const raft::device_resources& handle, int64_t n_cols = index_dataset.extent(1); int64_t n_rows_per_shard = raft::ceildiv(n_rows, (int64_t)index.num_ranks_); - RAFT_LOG_INFO("SHARDED BUILD: %d*%drows", index.num_ranks_, n_rows_per_shard); + RAFT_LOG_DEBUG("SHARDED BUILD: %d*%drows", index.num_ranks_, n_rows_per_shard); index.ann_interfaces_.resize(index.num_ranks_); #pragma omp parallel for for (int rank = 0; rank < index.num_ranks_; rank++) { - int dev_id = clique.device_ids_[rank]; - const raft::device_resources& dev_res = clique.device_resources_[rank]; - RAFT_CUDA_TRY(cudaSetDevice(dev_id)); - int64_t offset = rank * n_rows_per_shard; - int64_t n_rows_of_current_shard = std::min(n_rows_per_shard, n_rows - offset); - const T* partition_ptr = index_dataset.data_handle() + (offset * n_cols); - auto partition = raft::make_host_matrix_view( + const raft::device_resources& dev_res = clique.set_current_device_to_rank(rank); + int64_t offset = rank * n_rows_per_shard; + int64_t n_rows_of_current_shard = std::min(n_rows_per_shard, n_rows - offset); + const T* partition_ptr = index_dataset.data_handle() + (offset * n_cols); + auto partition = raft::make_host_matrix_view( partition_ptr, n_rows_of_current_shard, n_cols); auto& ann_if = index.ann_interfaces_[rank]; cuvs::neighbors::build(dev_res, ann_if, index_params, partition); @@ -142,23 +129,19 @@ void build(const raft::device_resources& handle, } template -void extend(const raft::device_resources& handle, +void extend(const raft::device_resources_snmg& clique, index& index, raft::host_matrix_view new_vectors, std::optional> new_indices) { - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); - int64_t n_rows = new_vectors.extent(0); if (index.mode_ == REPLICATED) { - RAFT_LOG_INFO("REPLICATED EXTEND: %d*%drows", index.num_ranks_, n_rows); + RAFT_LOG_DEBUG("REPLICATED EXTEND: %d*%drows", index.num_ranks_, n_rows); #pragma omp parallel for for (int rank = 0; rank < index.num_ranks_; rank++) { - int dev_id = clique.device_ids_[rank]; - const raft::device_resources& dev_res = clique.device_resources_[rank]; - RAFT_CUDA_TRY(cudaSetDevice(dev_id)); - auto& ann_if = index.ann_interfaces_[rank]; + const raft::device_resources& dev_res = clique.set_current_device_to_rank(rank); + auto& ann_if = index.ann_interfaces_[rank]; cuvs::neighbors::extend(dev_res, ann_if, new_vectors, new_indices); resource::sync_stream(dev_res); } @@ -166,17 +149,15 @@ void extend(const raft::device_resources& handle, int64_t n_cols = new_vectors.extent(1); int64_t n_rows_per_shard = raft::ceildiv(n_rows, (int64_t)index.num_ranks_); - RAFT_LOG_INFO("SHARDED EXTEND: %d*%drows", index.num_ranks_, n_rows_per_shard); + RAFT_LOG_DEBUG("SHARDED EXTEND: %d*%drows", index.num_ranks_, n_rows_per_shard); #pragma omp parallel for for (int rank = 0; rank < index.num_ranks_; rank++) { - int dev_id = clique.device_ids_[rank]; - const raft::device_resources& dev_res = clique.device_resources_[rank]; - RAFT_CUDA_TRY(cudaSetDevice(dev_id)); - int64_t offset = rank * n_rows_per_shard; - int64_t n_rows_of_current_shard = std::min(n_rows_per_shard, n_rows - offset); - const T* new_vectors_ptr = new_vectors.data_handle() + (offset * n_cols); - auto new_vectors_part = raft::make_host_matrix_view( + const raft::device_resources& dev_res = clique.set_current_device_to_rank(rank); + int64_t offset = rank * n_rows_per_shard; + int64_t n_rows_of_current_shard = std::min(n_rows_per_shard, n_rows - offset); + const T* new_vectors_ptr = new_vectors.data_handle() + (offset * n_cols); + auto new_vectors_part = raft::make_host_matrix_view( new_vectors_ptr, n_rows_of_current_shard, n_cols); std::optional> new_indices_part = std::nullopt; @@ -193,7 +174,7 @@ void extend(const raft::device_resources& handle, } template -void sharded_search_with_direct_merge(const raft::comms::nccl_clique& clique, +void sharded_search_with_direct_merge(const raft::device_resources_snmg& clique, const index& index, const cuvs::neighbors::search_params* search_params, raft::host_matrix_view queries, @@ -228,13 +209,11 @@ void sharded_search_with_direct_merge(const raft::comms::nccl_clique& clique, check_omp_threads(requirements); // should use at least num_ranks_ threads to avoid NCCL hang #pragma omp parallel for num_threads(index.num_ranks_) for (int rank = 0; rank < index.num_ranks_; rank++) { - int dev_id = clique.device_ids_[rank]; - const raft::device_resources& dev_res = clique.device_resources_[rank]; + const raft::device_resources& dev_res = clique.set_current_device_to_rank(rank); auto& ann_if = index.ann_interfaces_[rank]; - RAFT_CUDA_TRY(cudaSetDevice(dev_id)); - if (rank == clique.root_rank_) { // root rank - uint64_t batch_offset = clique.root_rank_ * part_size; + if (rank == clique.get_root_rank()) { // root rank + uint64_t batch_offset = clique.get_root_rank() * part_size; auto d_neighbors = raft::make_device_matrix_view( in_neighbors.data_handle() + batch_offset, n_rows_of_current_batch, n_neighbors); auto d_distances = raft::make_device_matrix_view( @@ -245,20 +224,20 @@ void sharded_search_with_direct_merge(const raft::comms::nccl_clique& clique, // wait for other ranks ncclGroupStart(); for (int from_rank = 0; from_rank < index.num_ranks_; from_rank++) { - if (from_rank == clique.root_rank_) continue; + if (from_rank == clique.get_root_rank()) continue; batch_offset = from_rank * part_size; ncclRecv(in_neighbors.data_handle() + batch_offset, part_size * sizeof(IdxT), ncclUint8, from_rank, - clique.nccl_comms_[rank], + clique.get_nccl_comm(rank), resource::get_cuda_stream(dev_res)); ncclRecv(in_distances.data_handle() + batch_offset, part_size * sizeof(float), ncclUint8, from_rank, - clique.nccl_comms_[rank], + clique.get_nccl_comm(rank), resource::get_cuda_stream(dev_res)); } ncclGroupEnd(); @@ -276,14 +255,14 @@ void sharded_search_with_direct_merge(const raft::comms::nccl_clique& clique, ncclSend(d_neighbors.data_handle(), part_size * sizeof(IdxT), ncclUint8, - clique.root_rank_, - clique.nccl_comms_[rank], + clique.get_root_rank(), + clique.get_nccl_comm(rank), resource::get_cuda_stream(dev_res)); ncclSend(d_distances.data_handle(), part_size * sizeof(float), ncclUint8, - clique.root_rank_, - clique.nccl_comms_[rank], + clique.get_root_rank(), + clique.get_nccl_comm(rank), resource::get_cuda_stream(dev_res)); ncclGroupEnd(); resource::sync_stream(dev_res); @@ -327,7 +306,7 @@ void sharded_search_with_direct_merge(const raft::comms::nccl_clique& clique, } template -void sharded_search_with_tree_merge(const raft::comms::nccl_clique& clique, +void sharded_search_with_tree_merge(const raft::device_resources_snmg& clique, const index& index, const cuvs::neighbors::search_params* search_params, raft::host_matrix_view queries, @@ -351,10 +330,8 @@ void sharded_search_with_tree_merge(const raft::comms::nccl_clique& clique, check_omp_threads(requirements); // should use at least num_ranks_ threads to avoid NCCL hang #pragma omp parallel for num_threads(index.num_ranks_) for (int rank = 0; rank < index.num_ranks_; rank++) { - int dev_id = clique.device_ids_[rank]; - const raft::device_resources& dev_res = clique.device_resources_[rank]; + const raft::device_resources& dev_res = clique.set_current_device_to_rank(rank); auto& ann_if = index.ann_interfaces_[rank]; - RAFT_CUDA_TRY(cudaSetDevice(dev_id)); int64_t part_size = n_rows_of_current_batch * n_neighbors; @@ -399,13 +376,13 @@ void sharded_search_with_tree_merge(const raft::comms::nccl_clique& clique, part_size * sizeof(IdxT), ncclUint8, other_id, - clique.nccl_comms_[rank], + clique.get_nccl_comm(rank), resource::get_cuda_stream(dev_res)); ncclRecv(tmp_distances.data_handle() + part_size, part_size * sizeof(float), ncclUint8, other_id, - clique.nccl_comms_[rank], + clique.get_nccl_comm(rank), resource::get_cuda_stream(dev_res)); received_something = true; } @@ -416,13 +393,13 @@ void sharded_search_with_tree_merge(const raft::comms::nccl_clique& clique, part_size * sizeof(IdxT), ncclUint8, other_id, - clique.nccl_comms_[rank], + clique.get_nccl_comm(rank), resource::get_cuda_stream(dev_res)); ncclSend(tmp_distances.data_handle(), part_size * sizeof(float), ncclUint8, other_id, - clique.nccl_comms_[rank], + clique.get_nccl_comm(rank), resource::get_cuda_stream(dev_res)); } ncclGroupEnd(); @@ -462,7 +439,7 @@ void sharded_search_with_tree_merge(const raft::comms::nccl_clique& clique, } template -void run_search_batch(const raft::comms::nccl_clique& clique, +void run_search_batch(const raft::device_resources_snmg& clique, const index& index, int rank, const cuvs::neighbors::search_params* search_params, @@ -475,9 +452,7 @@ void run_search_batch(const raft::comms::nccl_clique& clique, int64_t n_cols, int64_t n_neighbors) { - int dev_id = clique.device_ids_[rank]; - RAFT_CUDA_TRY(cudaSetDevice(dev_id)); - const raft::device_resources& dev_res = clique.device_resources_[rank]; + const raft::device_resources& dev_res = clique.set_current_device_to_rank(rank); auto& ann_if = index.ann_interfaces_[rank]; auto query_partition = raft::make_host_matrix_view( @@ -503,7 +478,7 @@ void run_search_batch(const raft::comms::nccl_clique& clique, } template -void search(const raft::device_resources& handle, +void search(const raft::device_resources_snmg& clique, const index& index, const cuvs::neighbors::search_params* search_params, raft::host_matrix_view queries, @@ -511,8 +486,6 @@ void search(const raft::device_resources& handle, raft::host_matrix_view distances, int64_t n_rows_per_batch) { - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); - int64_t n_rows = queries.extent(0); int64_t n_cols = queries.extent(1); int64_t n_neighbors = neighbors.extent(1); @@ -542,7 +515,7 @@ void search(const raft::device_resources& handle, int64_t n_batches = raft::ceildiv(n_rows, (int64_t)n_rows_per_batch); if (n_batches <= 1) n_rows_per_batch = n_rows; - RAFT_LOG_INFO( + RAFT_LOG_DEBUG( "REPLICATED SEARCH IN LOAD BALANCER MODE: %d*%drows", n_batches, n_rows_per_batch); #pragma omp parallel for @@ -567,7 +540,7 @@ void search(const raft::device_resources& handle, n_neighbors); } } else if (search_mode == ROUND_ROBIN) { - RAFT_LOG_INFO("REPLICATED SEARCH IN ROUND ROBIN MODE: %d*%drows", 1, n_rows); + RAFT_LOG_DEBUG("REPLICATED SEARCH IN ROUND ROBIN MODE: %d*%drows", 1, n_rows); ASSERT(n_rows <= n_rows_per_batch, "In round-robin mode, n_rows must lower or equal to n_rows_per_batch"); @@ -611,9 +584,9 @@ void search(const raft::device_resources& handle, if (n_batches <= 1) n_rows_per_batch = n_rows; if (merge_mode == MERGE_ON_ROOT_RANK) { - RAFT_LOG_INFO("SHARDED SEARCH WITH MERGE_ON_ROOT_RANK MERGE MODE: %d*%drows", - n_batches, - n_rows_per_batch); + RAFT_LOG_DEBUG("SHARDED SEARCH WITH MERGE_ON_ROOT_RANK MERGE MODE: %d*%drows", + n_batches, + n_rows_per_batch); sharded_search_with_direct_merge(clique, index, search_params, @@ -626,7 +599,7 @@ void search(const raft::device_resources& handle, n_neighbors, n_batches); } else if (merge_mode == TREE_MERGE) { - RAFT_LOG_INFO( + RAFT_LOG_DEBUG( "SHARDED SEARCH WITH TREE_MERGE MERGE MODE %d*%drows", n_batches, n_rows_per_batch); sharded_search_with_tree_merge(clique, index, @@ -644,23 +617,20 @@ void search(const raft::device_resources& handle, } template -void serialize(const raft::device_resources& handle, +void serialize(const raft::device_resources_snmg& clique, const index& index, const std::string& filename) { std::ofstream of(filename, std::ios::out | std::ios::binary); if (!of) { RAFT_FAIL("Cannot open file %s", filename.c_str()); } - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); - + const auto& handle = clique.set_current_device_to_root_rank(); serialize_scalar(handle, of, (int)index.mode_); serialize_scalar(handle, of, index.num_ranks_); for (int rank = 0; rank < index.num_ranks_; rank++) { - int dev_id = clique.device_ids_[rank]; - const raft::device_resources& dev_res = clique.device_resources_[rank]; - RAFT_CUDA_TRY(cudaSetDevice(dev_id)); - auto& ann_if = index.ann_interfaces_[rank]; + const raft::device_resources& dev_res = clique.set_current_device_to_rank(rank); + auto& ann_if = index.ann_interfaces_[rank]; cuvs::neighbors::serialize(dev_res, ann_if, of); } @@ -683,10 +653,10 @@ index::index(distribution_mode mode, int num_ranks_) } template -index::index(const raft::device_resources& handle, +index::index(const raft::device_resources_snmg& clique, const std::string& filename) : round_robin_counter_(std::make_shared>(0)) { - cuvs::neighbors::mg::detail::deserialize(handle, *this, filename); + cuvs::neighbors::mg::detail::deserialize(clique, *this, filename); } } // namespace cuvs::neighbors::mg diff --git a/cpp/src/neighbors/mg/mg_cagra_float_uint32_t.cu b/cpp/src/neighbors/mg/mg_cagra_float_uint32_t.cu index b11610fb4..e179a56e3 100644 --- a/cpp/src/neighbors/mg/mg_cagra_float_uint32_t.cu +++ b/cpp/src/neighbors/mg/mg_cagra_float_uint32_t.cu @@ -27,63 +27,61 @@ namespace cuvs::neighbors::mg { -#define CUVS_INST_MG_CAGRA(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources& handle, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ - index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::build( \ - handle, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void search(const raft::device_resources& handle, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - handle, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources& handle, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(handle, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_cagra( \ - const raft::device_resources& handle, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(handle, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_cagra( \ - const raft::device_resources& handle, const std::string& filename) \ - { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ - auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \ - return idx; \ +#define CUVS_INST_MG_CAGRA(T, IdxT) \ + index, T, IdxT> build( \ + const raft::device_resources_snmg& clique, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::build( \ + clique, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void search(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances, \ + int64_t n_rows_per_batch) \ + { \ + cuvs::neighbors::mg::detail::search( \ + clique, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances, \ + n_rows_per_batch); \ + } \ + \ + void serialize(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ + } \ + \ + template <> \ + index, T, IdxT> deserialize_cagra( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(clique, filename); \ + return idx; \ + } \ + \ + template <> \ + index, T, IdxT> distribute_cagra( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ + return idx; \ } CUVS_INST_MG_CAGRA(float, uint32_t); diff --git a/cpp/src/neighbors/mg/mg_cagra_half_uint32_t.cu b/cpp/src/neighbors/mg/mg_cagra_half_uint32_t.cu index 8f76c69a3..3e369d9ac 100644 --- a/cpp/src/neighbors/mg/mg_cagra_half_uint32_t.cu +++ b/cpp/src/neighbors/mg/mg_cagra_half_uint32_t.cu @@ -27,63 +27,61 @@ namespace cuvs::neighbors::mg { -#define CUVS_INST_MG_CAGRA(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources& handle, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ - index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::build( \ - handle, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void search(const raft::device_resources& handle, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - handle, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources& handle, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(handle, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_cagra( \ - const raft::device_resources& handle, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(handle, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_cagra( \ - const raft::device_resources& handle, const std::string& filename) \ - { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ - auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \ - return idx; \ +#define CUVS_INST_MG_CAGRA(T, IdxT) \ + index, T, IdxT> build( \ + const raft::device_resources_snmg& clique, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::build( \ + clique, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void search(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances, \ + int64_t n_rows_per_batch) \ + { \ + cuvs::neighbors::mg::detail::search( \ + clique, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances, \ + n_rows_per_batch); \ + } \ + \ + void serialize(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ + } \ + \ + template <> \ + index, T, IdxT> deserialize_cagra( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(clique, filename); \ + return idx; \ + } \ + \ + template <> \ + index, T, IdxT> distribute_cagra( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ + return idx; \ } CUVS_INST_MG_CAGRA(half, uint32_t); diff --git a/cpp/src/neighbors/mg/mg_cagra_int8_t_uint32_t.cu b/cpp/src/neighbors/mg/mg_cagra_int8_t_uint32_t.cu index 67b88d742..5ebf223d1 100644 --- a/cpp/src/neighbors/mg/mg_cagra_int8_t_uint32_t.cu +++ b/cpp/src/neighbors/mg/mg_cagra_int8_t_uint32_t.cu @@ -27,63 +27,61 @@ namespace cuvs::neighbors::mg { -#define CUVS_INST_MG_CAGRA(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources& handle, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ - index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::build( \ - handle, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void search(const raft::device_resources& handle, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - handle, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources& handle, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(handle, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_cagra( \ - const raft::device_resources& handle, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(handle, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_cagra( \ - const raft::device_resources& handle, const std::string& filename) \ - { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ - auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \ - return idx; \ +#define CUVS_INST_MG_CAGRA(T, IdxT) \ + index, T, IdxT> build( \ + const raft::device_resources_snmg& clique, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::build( \ + clique, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void search(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances, \ + int64_t n_rows_per_batch) \ + { \ + cuvs::neighbors::mg::detail::search( \ + clique, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances, \ + n_rows_per_batch); \ + } \ + \ + void serialize(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ + } \ + \ + template <> \ + index, T, IdxT> deserialize_cagra( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(clique, filename); \ + return idx; \ + } \ + \ + template <> \ + index, T, IdxT> distribute_cagra( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ + return idx; \ } CUVS_INST_MG_CAGRA(int8_t, uint32_t); diff --git a/cpp/src/neighbors/mg/mg_cagra_uint8_t_uint32_t.cu b/cpp/src/neighbors/mg/mg_cagra_uint8_t_uint32_t.cu index f72174923..923031b1c 100644 --- a/cpp/src/neighbors/mg/mg_cagra_uint8_t_uint32_t.cu +++ b/cpp/src/neighbors/mg/mg_cagra_uint8_t_uint32_t.cu @@ -27,63 +27,61 @@ namespace cuvs::neighbors::mg { -#define CUVS_INST_MG_CAGRA(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources& handle, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ - index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::build( \ - handle, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void search(const raft::device_resources& handle, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - handle, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources& handle, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(handle, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_cagra( \ - const raft::device_resources& handle, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(handle, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_cagra( \ - const raft::device_resources& handle, const std::string& filename) \ - { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ - auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \ - return idx; \ +#define CUVS_INST_MG_CAGRA(T, IdxT) \ + index, T, IdxT> build( \ + const raft::device_resources_snmg& clique, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::build( \ + clique, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void search(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances, \ + int64_t n_rows_per_batch) \ + { \ + cuvs::neighbors::mg::detail::search( \ + clique, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances, \ + n_rows_per_batch); \ + } \ + \ + void serialize(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ + } \ + \ + template <> \ + index, T, IdxT> deserialize_cagra( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(clique, filename); \ + return idx; \ + } \ + \ + template <> \ + index, T, IdxT> distribute_cagra( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ + return idx; \ } CUVS_INST_MG_CAGRA(uint8_t, uint32_t); diff --git a/cpp/src/neighbors/mg/mg_flat_float_int64_t.cu b/cpp/src/neighbors/mg/mg_flat_float_int64_t.cu index 4495e2527..f90f6fcfb 100644 --- a/cpp/src/neighbors/mg/mg_flat_float_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_flat_float_int64_t.cu @@ -27,71 +27,69 @@ namespace cuvs::neighbors::mg { -#define CUVS_INST_MG_FLAT(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources& handle, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ - index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::build( \ - handle, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void extend(const raft::device_resources& handle, \ - index, T, IdxT>& index, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices) \ - { \ - cuvs::neighbors::mg::detail::extend(handle, index, new_vectors, new_indices); \ - } \ - \ - void search(const raft::device_resources& handle, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - handle, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources& handle, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(handle, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_flat( \ - const raft::device_resources& handle, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(handle, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_flat( \ - const raft::device_resources& handle, const std::string& filename) \ - { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ - auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \ - return idx; \ +#define CUVS_INST_MG_FLAT(T, IdxT) \ + index, T, IdxT> build( \ + const raft::device_resources_snmg& clique, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::build( \ + clique, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void extend(const raft::device_resources_snmg& clique, \ + index, T, IdxT>& index, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices) \ + { \ + cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ + } \ + \ + void search(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances, \ + int64_t n_rows_per_batch) \ + { \ + cuvs::neighbors::mg::detail::search( \ + clique, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances, \ + n_rows_per_batch); \ + } \ + \ + void serialize(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ + } \ + \ + template <> \ + index, T, IdxT> deserialize_flat( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(clique, filename); \ + return idx; \ + } \ + \ + template <> \ + index, T, IdxT> distribute_flat( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ + return idx; \ } CUVS_INST_MG_FLAT(float, int64_t); diff --git a/cpp/src/neighbors/mg/mg_flat_int8_t_int64_t.cu b/cpp/src/neighbors/mg/mg_flat_int8_t_int64_t.cu index 5494414a6..2eefad5d5 100644 --- a/cpp/src/neighbors/mg/mg_flat_int8_t_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_flat_int8_t_int64_t.cu @@ -27,71 +27,69 @@ namespace cuvs::neighbors::mg { -#define CUVS_INST_MG_FLAT(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources& handle, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ - index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::build( \ - handle, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void extend(const raft::device_resources& handle, \ - index, T, IdxT>& index, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices) \ - { \ - cuvs::neighbors::mg::detail::extend(handle, index, new_vectors, new_indices); \ - } \ - \ - void search(const raft::device_resources& handle, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - handle, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources& handle, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(handle, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_flat( \ - const raft::device_resources& handle, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(handle, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_flat( \ - const raft::device_resources& handle, const std::string& filename) \ - { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ - auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \ - return idx; \ +#define CUVS_INST_MG_FLAT(T, IdxT) \ + index, T, IdxT> build( \ + const raft::device_resources_snmg& clique, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::build( \ + clique, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void extend(const raft::device_resources_snmg& clique, \ + index, T, IdxT>& index, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices) \ + { \ + cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ + } \ + \ + void search(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances, \ + int64_t n_rows_per_batch) \ + { \ + cuvs::neighbors::mg::detail::search( \ + clique, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances, \ + n_rows_per_batch); \ + } \ + \ + void serialize(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ + } \ + \ + template <> \ + index, T, IdxT> deserialize_flat( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(clique, filename); \ + return idx; \ + } \ + \ + template <> \ + index, T, IdxT> distribute_flat( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ + return idx; \ } CUVS_INST_MG_FLAT(int8_t, int64_t); diff --git a/cpp/src/neighbors/mg/mg_flat_uint8_t_int64_t.cu b/cpp/src/neighbors/mg/mg_flat_uint8_t_int64_t.cu index 35df2146b..9684f19d8 100644 --- a/cpp/src/neighbors/mg/mg_flat_uint8_t_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_flat_uint8_t_int64_t.cu @@ -27,71 +27,69 @@ namespace cuvs::neighbors::mg { -#define CUVS_INST_MG_FLAT(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources& handle, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ - index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::build( \ - handle, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void extend(const raft::device_resources& handle, \ - index, T, IdxT>& index, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices) \ - { \ - cuvs::neighbors::mg::detail::extend(handle, index, new_vectors, new_indices); \ - } \ - \ - void search(const raft::device_resources& handle, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - handle, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources& handle, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(handle, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_flat( \ - const raft::device_resources& handle, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(handle, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_flat( \ - const raft::device_resources& handle, const std::string& filename) \ - { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ - auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \ - return idx; \ +#define CUVS_INST_MG_FLAT(T, IdxT) \ + index, T, IdxT> build( \ + const raft::device_resources_snmg& clique, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::build( \ + clique, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void extend(const raft::device_resources_snmg& clique, \ + index, T, IdxT>& index, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices) \ + { \ + cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ + } \ + \ + void search(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances, \ + int64_t n_rows_per_batch) \ + { \ + cuvs::neighbors::mg::detail::search( \ + clique, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances, \ + n_rows_per_batch); \ + } \ + \ + void serialize(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ + } \ + \ + template <> \ + index, T, IdxT> deserialize_flat( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(clique, filename); \ + return idx; \ + } \ + \ + template <> \ + index, T, IdxT> distribute_flat( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ + return idx; \ } CUVS_INST_MG_FLAT(uint8_t, int64_t); diff --git a/cpp/src/neighbors/mg/mg_pq_float_int64_t.cu b/cpp/src/neighbors/mg/mg_pq_float_int64_t.cu index c671740e6..c71133ac4 100644 --- a/cpp/src/neighbors/mg/mg_pq_float_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_pq_float_int64_t.cu @@ -27,71 +27,69 @@ namespace cuvs::neighbors::mg { -#define CUVS_INST_MG_PQ(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources& handle, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ - index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::build( \ - handle, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void extend(const raft::device_resources& handle, \ - index, T, IdxT>& index, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices) \ - { \ - cuvs::neighbors::mg::detail::extend(handle, index, new_vectors, new_indices); \ - } \ - \ - void search(const raft::device_resources& handle, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - handle, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources& handle, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(handle, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_pq( \ - const raft::device_resources& handle, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(handle, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_pq(const raft::device_resources& handle, \ - const std::string& filename) \ - { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ - auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \ - return idx; \ +#define CUVS_INST_MG_PQ(T, IdxT) \ + index, T, IdxT> build( \ + const raft::device_resources_snmg& clique, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::build( \ + clique, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void extend(const raft::device_resources_snmg& clique, \ + index, T, IdxT>& index, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices) \ + { \ + cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ + } \ + \ + void search(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances, \ + int64_t n_rows_per_batch) \ + { \ + cuvs::neighbors::mg::detail::search( \ + clique, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances, \ + n_rows_per_batch); \ + } \ + \ + void serialize(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ + } \ + \ + template <> \ + index, T, IdxT> deserialize_pq( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(clique, filename); \ + return idx; \ + } \ + \ + template <> \ + index, T, IdxT> distribute_pq( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ + return idx; \ } CUVS_INST_MG_PQ(float, int64_t); diff --git a/cpp/src/neighbors/mg/mg_pq_half_int64_t.cu b/cpp/src/neighbors/mg/mg_pq_half_int64_t.cu index b167239c6..df148620f 100644 --- a/cpp/src/neighbors/mg/mg_pq_half_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_pq_half_int64_t.cu @@ -27,71 +27,69 @@ namespace cuvs::neighbors::mg { -#define CUVS_INST_MG_PQ(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources& handle, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ - index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::build( \ - handle, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void extend(const raft::device_resources& handle, \ - index, T, IdxT>& index, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices) \ - { \ - cuvs::neighbors::mg::detail::extend(handle, index, new_vectors, new_indices); \ - } \ - \ - void search(const raft::device_resources& handle, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - handle, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources& handle, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(handle, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_pq( \ - const raft::device_resources& handle, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(handle, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_pq(const raft::device_resources& handle, \ - const std::string& filename) \ - { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ - auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \ - return idx; \ +#define CUVS_INST_MG_PQ(T, IdxT) \ + index, T, IdxT> build( \ + const raft::device_resources_snmg& clique, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::build( \ + clique, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void extend(const raft::device_resources_snmg& clique, \ + index, T, IdxT>& index, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices) \ + { \ + cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ + } \ + \ + void search(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances, \ + int64_t n_rows_per_batch) \ + { \ + cuvs::neighbors::mg::detail::search( \ + clique, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances, \ + n_rows_per_batch); \ + } \ + \ + void serialize(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ + } \ + \ + template <> \ + index, T, IdxT> deserialize_pq( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(clique, filename); \ + return idx; \ + } \ + \ + template <> \ + index, T, IdxT> distribute_pq( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ + return idx; \ } CUVS_INST_MG_PQ(half, int64_t); diff --git a/cpp/src/neighbors/mg/mg_pq_int8_t_int64_t.cu b/cpp/src/neighbors/mg/mg_pq_int8_t_int64_t.cu index 127baf8fd..afe5faa41 100644 --- a/cpp/src/neighbors/mg/mg_pq_int8_t_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_pq_int8_t_int64_t.cu @@ -27,71 +27,69 @@ namespace cuvs::neighbors::mg { -#define CUVS_INST_MG_PQ(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources& handle, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ - index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::build( \ - handle, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void extend(const raft::device_resources& handle, \ - index, T, IdxT>& index, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices) \ - { \ - cuvs::neighbors::mg::detail::extend(handle, index, new_vectors, new_indices); \ - } \ - \ - void search(const raft::device_resources& handle, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - handle, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources& handle, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(handle, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_pq( \ - const raft::device_resources& handle, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(handle, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_pq(const raft::device_resources& handle, \ - const std::string& filename) \ - { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ - auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \ - return idx; \ +#define CUVS_INST_MG_PQ(T, IdxT) \ + index, T, IdxT> build( \ + const raft::device_resources_snmg& clique, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::build( \ + clique, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void extend(const raft::device_resources_snmg& clique, \ + index, T, IdxT>& index, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices) \ + { \ + cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ + } \ + \ + void search(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances, \ + int64_t n_rows_per_batch) \ + { \ + cuvs::neighbors::mg::detail::search( \ + clique, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances, \ + n_rows_per_batch); \ + } \ + \ + void serialize(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ + } \ + \ + template <> \ + index, T, IdxT> deserialize_pq( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(clique, filename); \ + return idx; \ + } \ + \ + template <> \ + index, T, IdxT> distribute_pq( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ + return idx; \ } CUVS_INST_MG_PQ(int8_t, int64_t); diff --git a/cpp/src/neighbors/mg/mg_pq_uint8_t_int64_t.cu b/cpp/src/neighbors/mg/mg_pq_uint8_t_int64_t.cu index 869e009a5..c725d2139 100644 --- a/cpp/src/neighbors/mg/mg_pq_uint8_t_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_pq_uint8_t_int64_t.cu @@ -27,71 +27,69 @@ namespace cuvs::neighbors::mg { -#define CUVS_INST_MG_PQ(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources& handle, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ - index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::build( \ - handle, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void extend(const raft::device_resources& handle, \ - index, T, IdxT>& index, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices) \ - { \ - cuvs::neighbors::mg::detail::extend(handle, index, new_vectors, new_indices); \ - } \ - \ - void search(const raft::device_resources& handle, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - handle, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources& handle, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(handle, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_pq( \ - const raft::device_resources& handle, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(handle, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_pq(const raft::device_resources& handle, \ - const std::string& filename) \ - { \ - const raft::comms::nccl_clique& clique = raft::resource::get_nccl_clique(handle); \ - auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(handle, idx, filename); \ - return idx; \ +#define CUVS_INST_MG_PQ(T, IdxT) \ + index, T, IdxT> build( \ + const raft::device_resources_snmg& clique, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::build( \ + clique, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void extend(const raft::device_resources_snmg& clique, \ + index, T, IdxT>& index, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices) \ + { \ + cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ + } \ + \ + void search(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances, \ + int64_t n_rows_per_batch) \ + { \ + cuvs::neighbors::mg::detail::search( \ + clique, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances, \ + n_rows_per_batch); \ + } \ + \ + void serialize(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ + } \ + \ + template <> \ + index, T, IdxT> deserialize_pq( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(clique, filename); \ + return idx; \ + } \ + \ + template <> \ + index, T, IdxT> distribute_pq( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ + return idx; \ } CUVS_INST_MG_PQ(uint8_t, int64_t); diff --git a/cpp/src/neighbors/mg/nccl_comm.cpp b/cpp/src/neighbors/mg/nccl_comm.cpp deleted file mode 100644 index c4556957a..000000000 --- a/cpp/src/neighbors/mg/nccl_comm.cpp +++ /dev/null @@ -1,8 +0,0 @@ -#include -#include - -namespace raft::comms { -void build_comms_nccl_only(raft::resources* handle, ncclComm_t nccl_comm, int num_ranks, int rank) -{ -} -} // namespace raft::comms diff --git a/cpp/test/neighbors/mg.cuh b/cpp/test/neighbors/mg.cuh index be30ca615..b4131acdb 100644 --- a/cpp/test/neighbors/mg.cuh +++ b/cpp/test/neighbors/mg.cuh @@ -20,7 +20,7 @@ #include "naive_knn.cuh" #include -#include +#include namespace cuvs::neighbors::mg { @@ -46,14 +46,14 @@ template class AnnMGTest : public ::testing::TestWithParam { public: AnnMGTest() - : stream_(resource::get_cuda_stream(handle_)), - clique_(raft::resource::get_nccl_clique(handle_)), + : clique_(), ps(::testing::TestWithParam::GetParam()), - d_index_dataset(0, stream_), - d_queries(0, stream_), + d_index_dataset(0, resource::get_cuda_stream(clique_)), + d_queries(0, resource::get_cuda_stream(clique_)), h_index_dataset(0), h_queries(0) { + clique_.set_memory_pool(80); } void testAnnMG() @@ -67,9 +67,10 @@ class AnnMGTest : public ::testing::TestWithParam { std::vector neighbors_snmg_ann_32bits(queries_size); { - rmm::device_uvector distances_ref_dev(queries_size, stream_); - rmm::device_uvector neighbors_ref_dev(queries_size, stream_); - cuvs::neighbors::naive_knn(handle_, + rmm::device_uvector distances_ref_dev(queries_size, resource::get_cuda_stream(clique_)); + rmm::device_uvector neighbors_ref_dev(queries_size, + resource::get_cuda_stream(clique_)); + cuvs::neighbors::naive_knn(clique_, distances_ref_dev.data(), neighbors_ref_dev.data(), d_queries.data(), @@ -79,9 +80,15 @@ class AnnMGTest : public ::testing::TestWithParam { ps.dim, ps.k, ps.metric); - update_host(distances_ref.data(), distances_ref_dev.data(), queries_size, stream_); - update_host(neighbors_ref.data(), neighbors_ref_dev.data(), queries_size, stream_); - resource::sync_stream(handle_); + update_host(distances_ref.data(), + distances_ref_dev.data(), + queries_size, + resource::get_cuda_stream(clique_)); + update_host(neighbors_ref.data(), + neighbors_ref_dev.data(), + queries_size, + resource::get_cuda_stream(clique_)); + resource::sync_stream(clique_); } int64_t n_rows_per_search_batch = 3000; // [3000, 3000, 1000] == 7000 rows @@ -118,20 +125,20 @@ class AnnMGTest : public ::testing::TestWithParam { distances_snmg_ann.data(), ps.num_queries, ps.k); { - auto index = cuvs::neighbors::mg::build(handle_, index_params, index_dataset); - cuvs::neighbors::mg::extend(handle_, index, index_dataset, std::nullopt); - cuvs::neighbors::mg::serialize(handle_, index, "mg_ivf_flat_index"); + auto index = cuvs::neighbors::mg::build(clique_, index_params, index_dataset); + cuvs::neighbors::mg::extend(clique_, index, index_dataset, std::nullopt); + cuvs::neighbors::mg::serialize(clique_, index, "mg_ivf_flat_index"); } auto new_index = - cuvs::neighbors::mg::deserialize_flat(handle_, "mg_ivf_flat_index"); + cuvs::neighbors::mg::deserialize_flat(clique_, "mg_ivf_flat_index"); if (ps.m_mode == m_mode_t::MERGE_ON_ROOT_RANK) search_params.merge_mode = MERGE_ON_ROOT_RANK; else search_params.merge_mode = TREE_MERGE; cuvs::neighbors::mg::search( - handle_, new_index, search_params, queries, neighbors, distances, n_rows_per_search_batch); - resource::sync_stream(handle_); + clique_, new_index, search_params, queries, neighbors, distances, n_rows_per_search_batch); + resource::sync_stream(clique_); double min_recall = static_cast(ps.nprobe) / static_cast(ps.nlist); ASSERT_TRUE(eval_neighbours(neighbors_ref, @@ -177,20 +184,20 @@ class AnnMGTest : public ::testing::TestWithParam { distances_snmg_ann.data(), ps.num_queries, ps.k); { - auto index = cuvs::neighbors::mg::build(handle_, index_params, index_dataset); - cuvs::neighbors::mg::extend(handle_, index, index_dataset, std::nullopt); - cuvs::neighbors::mg::serialize(handle_, index, "mg_ivf_pq_index"); + auto index = cuvs::neighbors::mg::build(clique_, index_params, index_dataset); + cuvs::neighbors::mg::extend(clique_, index, index_dataset, std::nullopt); + cuvs::neighbors::mg::serialize(clique_, index, "mg_ivf_pq_index"); } auto new_index = - cuvs::neighbors::mg::deserialize_pq(handle_, "mg_ivf_pq_index"); + cuvs::neighbors::mg::deserialize_pq(clique_, "mg_ivf_pq_index"); if (ps.m_mode == m_mode_t::MERGE_ON_ROOT_RANK) search_params.merge_mode = MERGE_ON_ROOT_RANK; else search_params.merge_mode = TREE_MERGE; cuvs::neighbors::mg::search( - handle_, new_index, search_params, queries, neighbors, distances, n_rows_per_search_batch); - resource::sync_stream(handle_); + clique_, new_index, search_params, queries, neighbors, distances, n_rows_per_search_batch); + resource::sync_stream(clique_); double min_recall = static_cast(ps.nprobe) / static_cast(ps.nlist); ASSERT_TRUE(eval_neighbours(neighbors_ref, @@ -231,19 +238,19 @@ class AnnMGTest : public ::testing::TestWithParam { distances_snmg_ann.data(), ps.num_queries, ps.k); { - auto index = cuvs::neighbors::mg::build(handle_, index_params, index_dataset); - cuvs::neighbors::mg::serialize(handle_, index, "mg_cagra_index"); + auto index = cuvs::neighbors::mg::build(clique_, index_params, index_dataset); + cuvs::neighbors::mg::serialize(clique_, index, "mg_cagra_index"); } auto new_index = - cuvs::neighbors::mg::deserialize_cagra(handle_, "mg_cagra_index"); + cuvs::neighbors::mg::deserialize_cagra(clique_, "mg_cagra_index"); if (ps.m_mode == m_mode_t::MERGE_ON_ROOT_RANK) search_params.merge_mode = MERGE_ON_ROOT_RANK; else search_params.merge_mode = TREE_MERGE; cuvs::neighbors::mg::search( - handle_, new_index, search_params, queries, neighbors, distances, n_rows_per_search_batch); - resource::sync_stream(handle_); + clique_, new_index, search_params, queries, neighbors, distances, n_rows_per_search_batch); + resource::sync_stream(clique_); double min_recall = static_cast(ps.nprobe) / static_cast(ps.nlist); ASSERT_TRUE(eval_neighbours(neighbors_ref_32bits, @@ -274,8 +281,8 @@ class AnnMGTest : public ::testing::TestWithParam { { auto index_dataset = raft::make_device_matrix_view( d_index_dataset.data(), ps.num_db_vecs, ps.dim); - auto index = cuvs::neighbors::ivf_flat::build(handle_, index_params, index_dataset); - ivf_flat::serialize(handle_, "local_ivf_flat_index", index); + auto index = cuvs::neighbors::ivf_flat::build(clique_, index_params, index_dataset); + ivf_flat::serialize(clique_, "local_ivf_flat_index", index); } auto queries = raft::make_host_matrix_view( @@ -286,9 +293,9 @@ class AnnMGTest : public ::testing::TestWithParam { distances_snmg_ann.data(), ps.num_queries, ps.k); auto distributed_index = - cuvs::neighbors::mg::distribute_flat(handle_, "local_ivf_flat_index"); + cuvs::neighbors::mg::distribute_flat(clique_, "local_ivf_flat_index"); search_params.merge_mode = TREE_MERGE; - cuvs::neighbors::mg::search(handle_, + cuvs::neighbors::mg::search(clique_, distributed_index, search_params, queries, @@ -296,7 +303,7 @@ class AnnMGTest : public ::testing::TestWithParam { distances, n_rows_per_search_batch); - resource::sync_stream(handle_); + resource::sync_stream(clique_); double min_recall = static_cast(ps.nprobe) / static_cast(ps.nlist); ASSERT_TRUE(eval_neighbours(neighbors_ref, @@ -326,8 +333,8 @@ class AnnMGTest : public ::testing::TestWithParam { { auto index_dataset = raft::make_device_matrix_view( d_index_dataset.data(), ps.num_db_vecs, ps.dim); - auto index = cuvs::neighbors::ivf_pq::build(handle_, index_params, index_dataset); - ivf_pq::serialize(handle_, "local_ivf_pq_index", index); + auto index = cuvs::neighbors::ivf_pq::build(clique_, index_params, index_dataset); + ivf_pq::serialize(clique_, "local_ivf_pq_index", index); } auto queries = raft::make_host_matrix_view( @@ -338,9 +345,9 @@ class AnnMGTest : public ::testing::TestWithParam { distances_snmg_ann.data(), ps.num_queries, ps.k); auto distributed_index = - cuvs::neighbors::mg::distribute_pq(handle_, "local_ivf_pq_index"); + cuvs::neighbors::mg::distribute_pq(clique_, "local_ivf_pq_index"); search_params.merge_mode = TREE_MERGE; - cuvs::neighbors::mg::search(handle_, + cuvs::neighbors::mg::search(clique_, distributed_index, search_params, queries, @@ -348,7 +355,7 @@ class AnnMGTest : public ::testing::TestWithParam { distances, n_rows_per_search_batch); - resource::sync_stream(handle_); + resource::sync_stream(clique_); double min_recall = static_cast(ps.nprobe) / static_cast(ps.nlist); ASSERT_TRUE(eval_neighbours(neighbors_ref, @@ -373,8 +380,8 @@ class AnnMGTest : public ::testing::TestWithParam { { auto index_dataset = raft::make_device_matrix_view( d_index_dataset.data(), ps.num_db_vecs, ps.dim); - auto index = cuvs::neighbors::cagra::build(handle_, index_params, index_dataset); - cuvs::neighbors::cagra::serialize(handle_, "local_cagra_index", index); + auto index = cuvs::neighbors::cagra::build(clique_, index_params, index_dataset); + cuvs::neighbors::cagra::serialize(clique_, "local_cagra_index", index); } auto queries = raft::make_host_matrix_view( @@ -385,10 +392,10 @@ class AnnMGTest : public ::testing::TestWithParam { distances_snmg_ann.data(), ps.num_queries, ps.k); auto distributed_index = - cuvs::neighbors::mg::distribute_cagra(handle_, "local_cagra_index"); + cuvs::neighbors::mg::distribute_cagra(clique_, "local_cagra_index"); search_params.merge_mode = TREE_MERGE; - cuvs::neighbors::mg::search(handle_, + cuvs::neighbors::mg::search(clique_, distributed_index, search_params, queries, @@ -396,7 +403,7 @@ class AnnMGTest : public ::testing::TestWithParam { distances, n_rows_per_search_batch); - resource::sync_stream(handle_); + resource::sync_stream(clique_); double min_recall = static_cast(ps.nprobe) / static_cast(ps.nlist); ASSERT_TRUE(eval_neighbours(neighbors_ref_32bits, @@ -432,8 +439,8 @@ class AnnMGTest : public ::testing::TestWithParam { auto small_batch_query = raft::make_host_matrix_view( h_queries.data(), ps.num_queries, ps.dim); - auto index = cuvs::neighbors::mg::build(handle_, index_params, index_dataset); - cuvs::neighbors::mg::extend(handle_, index, index_dataset, std::nullopt); + auto index = cuvs::neighbors::mg::build(clique_, index_params, index_dataset); + cuvs::neighbors::mg::extend(clique_, index, index_dataset, std::nullopt); int n_parallel_searches = 16; std::vector searches_correctness(n_parallel_searches); @@ -448,7 +455,7 @@ class AnnMGTest : public ::testing::TestWithParam { load_balancer_neighbors_snmg_ann.data() + offset, ps.num_queries, ps.k); auto small_batch_distances = raft::make_host_matrix_view( load_balancer_distances_snmg_ann.data() + offset, ps.num_queries, ps.k); - cuvs::neighbors::mg::search(handle_, + cuvs::neighbors::mg::search(clique_, index, search_params, small_batch_query, @@ -496,8 +503,8 @@ class AnnMGTest : public ::testing::TestWithParam { auto small_batch_query = raft::make_host_matrix_view( h_queries.data(), ps.num_queries, ps.dim); - auto index = cuvs::neighbors::mg::build(handle_, index_params, index_dataset); - cuvs::neighbors::mg::extend(handle_, index, index_dataset, std::nullopt); + auto index = cuvs::neighbors::mg::build(clique_, index_params, index_dataset); + cuvs::neighbors::mg::extend(clique_, index, index_dataset, std::nullopt); int n_parallel_searches = 16; std::vector searches_correctness(n_parallel_searches); @@ -512,7 +519,7 @@ class AnnMGTest : public ::testing::TestWithParam { load_balancer_neighbors_snmg_ann.data() + offset, ps.num_queries, ps.k); auto small_batch_distances = raft::make_host_matrix_view( load_balancer_distances_snmg_ann.data() + offset, ps.num_queries, ps.k); - cuvs::neighbors::mg::search(handle_, + cuvs::neighbors::mg::search(clique_, index, search_params, small_batch_query, @@ -556,7 +563,7 @@ class AnnMGTest : public ::testing::TestWithParam { auto small_batch_query = raft::make_host_matrix_view( h_queries.data(), ps.num_queries, ps.dim); - auto index = cuvs::neighbors::mg::build(handle_, index_params, index_dataset); + auto index = cuvs::neighbors::mg::build(clique_, index_params, index_dataset); int n_parallel_searches = 16; std::vector searches_correctness(n_parallel_searches); @@ -571,7 +578,7 @@ class AnnMGTest : public ::testing::TestWithParam { load_balancer_neighbors_snmg_ann.data() + offset, ps.num_queries, ps.k); auto small_batch_distances = raft::make_host_matrix_view( load_balancer_distances_snmg_ann.data() + offset, ps.num_queries, ps.k); - cuvs::neighbors::mg::search(handle_, + cuvs::neighbors::mg::search(clique_, index, search_params, small_batch_query, @@ -602,37 +609,35 @@ class AnnMGTest : public ::testing::TestWithParam { void SetUp() override { - d_index_dataset.resize(ps.num_db_vecs * ps.dim, stream_); - d_queries.resize(ps.num_queries * ps.dim, stream_); + d_index_dataset.resize(ps.num_db_vecs * ps.dim, resource::get_cuda_stream(clique_)); + d_queries.resize(ps.num_queries * ps.dim, resource::get_cuda_stream(clique_)); h_index_dataset.resize(ps.num_db_vecs * ps.dim); h_queries.resize(ps.num_queries * ps.dim); raft::random::RngState r(1234ULL); if constexpr (std::is_same{}) { raft::random::uniform( - handle_, r, d_index_dataset.data(), d_index_dataset.size(), DataT(0.1), DataT(2.0)); - raft::random::uniform(handle_, r, d_queries.data(), d_queries.size(), DataT(0.1), DataT(2.0)); + clique_, r, d_index_dataset.data(), d_index_dataset.size(), DataT(0.1), DataT(2.0)); + raft::random::uniform(clique_, r, d_queries.data(), d_queries.size(), DataT(0.1), DataT(2.0)); } else { raft::random::uniformInt( - handle_, r, d_index_dataset.data(), d_index_dataset.size(), DataT(1), DataT(20)); - raft::random::uniformInt(handle_, r, d_queries.data(), d_queries.size(), DataT(1), DataT(20)); + clique_, r, d_index_dataset.data(), d_index_dataset.size(), DataT(1), DataT(20)); + raft::random::uniformInt(clique_, r, d_queries.data(), d_queries.size(), DataT(1), DataT(20)); } raft::copy(h_index_dataset.data(), d_index_dataset.data(), d_index_dataset.size(), - resource::get_cuda_stream(handle_)); + resource::get_cuda_stream(clique_)); raft::copy( - h_queries.data(), d_queries.data(), d_queries.size(), resource::get_cuda_stream(handle_)); - resource::sync_stream(handle_); + h_queries.data(), d_queries.data(), d_queries.size(), resource::get_cuda_stream(clique_)); + resource::sync_stream(clique_); } void TearDown() override {} private: - raft::device_resources handle_; - rmm::cuda_stream_view stream_; - raft::comms::nccl_clique clique_; + raft::device_resources_snmg clique_; AnnMGInputs ps; std::vector h_index_dataset; std::vector h_queries;