Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/branch-25.02' into rhdong/cagra-…
Browse files Browse the repository at this point in the history
…merge
  • Loading branch information
rhdong committed Jan 30, 2025
2 parents 1eb211a + 0dd7bde commit 0fb7dff
Show file tree
Hide file tree
Showing 18 changed files with 723 additions and 362 deletions.
9 changes: 8 additions & 1 deletion cpp/bench/ann/src/cuvs/cuvs_cagra_hnswlib_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@

#include "cuvs_cagra_wrapper.h"
#include <cuvs/neighbors/hnsw.hpp>
#include <raft/core/logger.hpp>

#include <chrono>
#include <memory>

namespace cuvs::bench {
Expand Down Expand Up @@ -90,8 +92,13 @@ void cuvs_cagra_hnswlib<T, IdxT>::build(const T* dataset, size_t nrow)
auto host_dataset_view = raft::make_host_matrix_view<const T, int64_t>(dataset, nrow, this->dim_);
auto opt_dataset_view =
std::optional<raft::host_matrix_view<const T, int64_t>>(std::move(host_dataset_view));
hnsw_index_ = cuvs::neighbors::hnsw::from_cagra(
const auto start_clock = std::chrono::system_clock::now();
hnsw_index_ = cuvs::neighbors::hnsw::from_cagra(
handle_, build_param_.hnsw_index_params, *cagra_index, opt_dataset_view);
int time =
std::chrono::duration_cast<std::chrono::seconds>(std::chrono::system_clock::now() - start_clock)
.count();
RAFT_LOG_DEBUG("Graph saved to HNSW format in %d:%d min", time / 60, time % 60);
}

template <typename T, typename IdxT>
Expand Down
69 changes: 48 additions & 21 deletions cpp/include/cuvs/neighbors/cagra.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1610,11 +1610,16 @@ void deserialize(raft::resources const& handle,
* @param[in] handle the raft handle
* @param[in] os output stream
* @param[in] index CAGRA index
* @param[in] dataset [optional] host array that stores the dataset, required if the index
* does not contain the dataset.
*
*/
void serialize_to_hnswlib(raft::resources const& handle,
std::ostream& os,
const cuvs::neighbors::cagra::index<float, uint32_t>& index);
void serialize_to_hnswlib(
raft::resources const& handle,
std::ostream& os,
const cuvs::neighbors::cagra::index<float, uint32_t>& index,
std::optional<raft::host_matrix_view<const float, int64_t, raft::row_major>> dataset =
std::nullopt);

/**
* Save a CAGRA build index in hnswlib base-layer-only serialized format
Expand All @@ -1639,11 +1644,16 @@ void serialize_to_hnswlib(raft::resources const& handle,
* @param[in] handle the raft handle
* @param[in] filename the file name for saving the index
* @param[in] index CAGRA index
* @param[in] dataset [optional] host array that stores the dataset, required if the index
* does not contain the dataset.
*
*/
void serialize_to_hnswlib(raft::resources const& handle,
const std::string& filename,
const cuvs::neighbors::cagra::index<float, uint32_t>& index);
void serialize_to_hnswlib(
raft::resources const& handle,
const std::string& filename,
const cuvs::neighbors::cagra::index<float, uint32_t>& index,
std::optional<raft::host_matrix_view<const float, int64_t, raft::row_major>> dataset =
std::nullopt);

/**
* Write the CAGRA built index as a base layer HNSW index to an output stream
Expand All @@ -1667,11 +1677,16 @@ void serialize_to_hnswlib(raft::resources const& handle,
* @param[in] handle the raft handle
* @param[in] os output stream
* @param[in] index CAGRA index
* @param[in] dataset [optional] host array that stores the dataset, required if the index
* does not contain the dataset.
*
*/
void serialize_to_hnswlib(raft::resources const& handle,
std::ostream& os,
const cuvs::neighbors::cagra::index<int8_t, uint32_t>& index);
void serialize_to_hnswlib(
raft::resources const& handle,
std::ostream& os,
const cuvs::neighbors::cagra::index<int8_t, uint32_t>& index,
std::optional<raft::host_matrix_view<const int8_t, int64_t, raft::row_major>> dataset =
std::nullopt);

/**
* Save a CAGRA build index in hnswlib base-layer-only serialized format
Expand All @@ -1696,11 +1711,16 @@ void serialize_to_hnswlib(raft::resources const& handle,
* @param[in] handle the raft handle
* @param[in] filename the file name for saving the index
* @param[in] index CAGRA index
* @param[in] dataset [optional] host array that stores the dataset, required if the index
* does not contain the dataset.
*
*/
void serialize_to_hnswlib(raft::resources const& handle,
const std::string& filename,
const cuvs::neighbors::cagra::index<int8_t, uint32_t>& index);
void serialize_to_hnswlib(
raft::resources const& handle,
const std::string& filename,
const cuvs::neighbors::cagra::index<int8_t, uint32_t>& index,
std::optional<raft::host_matrix_view<const int8_t, int64_t, raft::row_major>> dataset =
std::nullopt);

/**
* Write the CAGRA built index as a base layer HNSW index to an output stream
Expand All @@ -1724,11 +1744,16 @@ void serialize_to_hnswlib(raft::resources const& handle,
* @param[in] handle the raft handle
* @param[in] os output stream
* @param[in] index CAGRA index
* @param[in] dataset [optional] host array that stores the dataset, required if the index
* does not contain the dataset.
*
*/
void serialize_to_hnswlib(raft::resources const& handle,
std::ostream& os,
const cuvs::neighbors::cagra::index<uint8_t, uint32_t>& index);
void serialize_to_hnswlib(
raft::resources const& handle,
std::ostream& os,
const cuvs::neighbors::cagra::index<uint8_t, uint32_t>& index,
std::optional<raft::host_matrix_view<const uint8_t, int64_t, raft::row_major>> dataset =
std::nullopt);

/**
* Save a CAGRA build index in hnswlib base-layer-only serialized format
Expand All @@ -1753,14 +1778,16 @@ void serialize_to_hnswlib(raft::resources const& handle,
* @param[in] handle the raft handle
* @param[in] filename the file name for saving the index
* @param[in] index CAGRA index
* @param[in] dataset [optional] host array that stores the dataset, required if the index
* does not contain the dataset.
*
*/
void serialize_to_hnswlib(raft::resources const& handle,
const std::string& filename,
const cuvs::neighbors::cagra::index<uint8_t, uint32_t>& index);
/**
* @}
*/
void serialize_to_hnswlib(
raft::resources const& handle,
const std::string& filename,
const cuvs::neighbors::cagra::index<uint8_t, uint32_t>& index,
std::optional<raft::host_matrix_view<const uint8_t, int64_t, raft::row_major>> dataset =
std::nullopt);

/**
* @defgroup cagra_cpp_index_merge CAGRA index build functions
Expand Down
95 changes: 50 additions & 45 deletions cpp/src/neighbors/cagra_serialize.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -20,51 +20,56 @@

namespace cuvs::neighbors::cagra {

#define CUVS_INST_CAGRA_SERIALIZE(DTYPE) \
void serialize(raft::resources const& handle, \
const std::string& filename, \
const cuvs::neighbors::cagra::index<DTYPE, uint32_t>& index, \
bool include_dataset) \
{ \
cuvs::neighbors::cagra::detail::serialize<DTYPE, uint32_t>( \
handle, filename, index, include_dataset); \
}; \
\
void deserialize(raft::resources const& handle, \
const std::string& filename, \
cuvs::neighbors::cagra::index<DTYPE, uint32_t>* index) \
{ \
cuvs::neighbors::cagra::detail::deserialize<DTYPE, uint32_t>(handle, filename, index); \
}; \
void serialize(raft::resources const& handle, \
std::ostream& os, \
const cuvs::neighbors::cagra::index<DTYPE, uint32_t>& index, \
bool include_dataset) \
{ \
cuvs::neighbors::cagra::detail::serialize<DTYPE, uint32_t>( \
handle, os, index, include_dataset); \
} \
\
void deserialize(raft::resources const& handle, \
std::istream& is, \
cuvs::neighbors::cagra::index<DTYPE, uint32_t>* index) \
{ \
cuvs::neighbors::cagra::detail::deserialize<DTYPE, uint32_t>(handle, is, index); \
} \
\
void serialize_to_hnswlib(raft::resources const& handle, \
std::ostream& os, \
const cuvs::neighbors::cagra::index<DTYPE, uint32_t>& index) \
{ \
cuvs::neighbors::cagra::detail::serialize_to_hnswlib<DTYPE, uint32_t>(handle, os, index); \
} \
\
void serialize_to_hnswlib(raft::resources const& handle, \
const std::string& filename, \
const cuvs::neighbors::cagra::index<DTYPE, uint32_t>& index) \
{ \
cuvs::neighbors::cagra::detail::serialize_to_hnswlib<DTYPE, uint32_t>( \
handle, filename, index); \
#define CUVS_INST_CAGRA_SERIALIZE(DTYPE) \
void serialize(raft::resources const& handle, \
const std::string& filename, \
const cuvs::neighbors::cagra::index<DTYPE, uint32_t>& index, \
bool include_dataset) \
{ \
cuvs::neighbors::cagra::detail::serialize<DTYPE, uint32_t>( \
handle, filename, index, include_dataset); \
}; \
\
void deserialize(raft::resources const& handle, \
const std::string& filename, \
cuvs::neighbors::cagra::index<DTYPE, uint32_t>* index) \
{ \
cuvs::neighbors::cagra::detail::deserialize<DTYPE, uint32_t>(handle, filename, index); \
}; \
void serialize(raft::resources const& handle, \
std::ostream& os, \
const cuvs::neighbors::cagra::index<DTYPE, uint32_t>& index, \
bool include_dataset) \
{ \
cuvs::neighbors::cagra::detail::serialize<DTYPE, uint32_t>( \
handle, os, index, include_dataset); \
} \
\
void deserialize(raft::resources const& handle, \
std::istream& is, \
cuvs::neighbors::cagra::index<DTYPE, uint32_t>* index) \
{ \
cuvs::neighbors::cagra::detail::deserialize<DTYPE, uint32_t>(handle, is, index); \
} \
\
void serialize_to_hnswlib( \
raft::resources const& handle, \
std::ostream& os, \
const cuvs::neighbors::cagra::index<DTYPE, uint32_t>& index, \
std::optional<raft::host_matrix_view<const DTYPE, int64_t, raft::row_major>> dataset) \
{ \
cuvs::neighbors::cagra::detail::serialize_to_hnswlib<DTYPE, uint32_t>( \
handle, os, index, dataset); \
} \
\
void serialize_to_hnswlib( \
raft::resources const& handle, \
const std::string& filename, \
const cuvs::neighbors::cagra::index<DTYPE, uint32_t>& index, \
std::optional<raft::host_matrix_view<const DTYPE, int64_t, raft::row_major>> dataset) \
{ \
cuvs::neighbors::cagra::detail::serialize_to_hnswlib<DTYPE, uint32_t>( \
handle, filename, index, dataset); \
}

} // namespace cuvs::neighbors::cagra
Loading

0 comments on commit 0fb7dff

Please sign in to comment.