Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix cagra_hnsw serialization when dataset is not part of index #591

Merged
merged 9 commits into from
Jan 30, 2025
9 changes: 8 additions & 1 deletion cpp/bench/ann/src/cuvs/cuvs_cagra_hnswlib_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@

#include "cuvs_cagra_wrapper.h"
#include <cuvs/neighbors/hnsw.hpp>
#include <raft/core/logger.hpp>

#include <chrono>
#include <memory>

namespace cuvs::bench {
Expand Down Expand Up @@ -90,8 +92,13 @@ void cuvs_cagra_hnswlib<T, IdxT>::build(const T* dataset, size_t nrow)
auto host_dataset_view = raft::make_host_matrix_view<const T, int64_t>(dataset, nrow, this->dim_);
auto opt_dataset_view =
std::optional<raft::host_matrix_view<const T, int64_t>>(std::move(host_dataset_view));
hnsw_index_ = cuvs::neighbors::hnsw::from_cagra(
const auto start_clock = std::chrono::system_clock::now();
hnsw_index_ = cuvs::neighbors::hnsw::from_cagra(
handle_, build_param_.hnsw_index_params, *cagra_index, opt_dataset_view);
int time =
std::chrono::duration_cast<std::chrono::seconds>(std::chrono::system_clock::now() - start_clock)
.count();
RAFT_LOG_DEBUG("Graph saved to HNSW format in %d:%d min", time / 60, time % 60);
}

template <typename T, typename IdxT>
Expand Down
54 changes: 36 additions & 18 deletions cpp/include/cuvs/neighbors/cagra.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1601,9 +1601,12 @@ void deserialize(raft::resources const& handle,
* @param[in] index CAGRA index
*
*/
void serialize_to_hnswlib(raft::resources const& handle,
std::ostream& os,
const cuvs::neighbors::cagra::index<float, uint32_t>& index);
// `dataset` is an optional host-side, row-major view of float vectors; it defaults to
// std::nullopt, so existing three-argument calls keep compiling.
// NOTE(review): presumably supplies the vectors when `index` does not hold the dataset itself;
// confirm against detail::serialize_to_hnswlib and add a matching @param entry to the Doxygen
// block of this overload.
void serialize_to_hnswlib(
raft::resources const& handle,
std::ostream& os,
const cuvs::neighbors::cagra::index<float, uint32_t>& index,
std::optional<raft::host_matrix_view<const float, int64_t, raft::row_major>> dataset =
std::nullopt);

/**
* Save a CAGRA build index in hnswlib base-layer-only serialized format
Expand All @@ -1630,9 +1633,12 @@ void serialize_to_hnswlib(raft::resources const& handle,
* @param[in] index CAGRA index
*
*/
void serialize_to_hnswlib(raft::resources const& handle,
const std::string& filename,
const cuvs::neighbors::cagra::index<float, uint32_t>& index);
// `dataset` is an optional host-side, row-major view of float vectors; it defaults to
// std::nullopt, so existing three-argument calls keep compiling.
// NOTE(review): presumably supplies the vectors when `index` does not hold the dataset itself;
// confirm against detail::serialize_to_hnswlib and add a matching @param entry to the Doxygen
// block of this overload.
void serialize_to_hnswlib(
raft::resources const& handle,
const std::string& filename,
const cuvs::neighbors::cagra::index<float, uint32_t>& index,
std::optional<raft::host_matrix_view<const float, int64_t, raft::row_major>> dataset =
std::nullopt);

/**
* Write the CAGRA built index as a base layer HNSW index to an output stream
Expand All @@ -1658,9 +1664,12 @@ void serialize_to_hnswlib(raft::resources const& handle,
* @param[in] index CAGRA index
*
*/
void serialize_to_hnswlib(raft::resources const& handle,
std::ostream& os,
const cuvs::neighbors::cagra::index<int8_t, uint32_t>& index);
// `dataset` is an optional host-side, row-major view of int8_t vectors; it defaults to
// std::nullopt, so existing three-argument calls keep compiling.
// NOTE(review): presumably supplies the vectors when `index` does not hold the dataset itself;
// confirm against detail::serialize_to_hnswlib and add a matching @param entry to the Doxygen
// block of this overload.
void serialize_to_hnswlib(
raft::resources const& handle,
std::ostream& os,
const cuvs::neighbors::cagra::index<int8_t, uint32_t>& index,
std::optional<raft::host_matrix_view<const int8_t, int64_t, raft::row_major>> dataset =
std::nullopt);

/**
* Save a CAGRA build index in hnswlib base-layer-only serialized format
Expand All @@ -1687,9 +1696,12 @@ void serialize_to_hnswlib(raft::resources const& handle,
* @param[in] index CAGRA index
*
*/
void serialize_to_hnswlib(raft::resources const& handle,
const std::string& filename,
const cuvs::neighbors::cagra::index<int8_t, uint32_t>& index);
// `dataset` is an optional host-side, row-major view of int8_t vectors; it defaults to
// std::nullopt, so existing three-argument calls keep compiling.
// NOTE(review): presumably supplies the vectors when `index` does not hold the dataset itself;
// confirm against detail::serialize_to_hnswlib and add a matching @param entry to the Doxygen
// block of this overload.
void serialize_to_hnswlib(
raft::resources const& handle,
const std::string& filename,
const cuvs::neighbors::cagra::index<int8_t, uint32_t>& index,
std::optional<raft::host_matrix_view<const int8_t, int64_t, raft::row_major>> dataset =
std::nullopt);

/**
* Write the CAGRA built index as a base layer HNSW index to an output stream
Expand All @@ -1715,9 +1727,12 @@ void serialize_to_hnswlib(raft::resources const& handle,
* @param[in] index CAGRA index
*
*/
void serialize_to_hnswlib(raft::resources const& handle,
std::ostream& os,
const cuvs::neighbors::cagra::index<uint8_t, uint32_t>& index);
// `dataset` is an optional host-side, row-major view of uint8_t vectors; it defaults to
// std::nullopt, so existing three-argument calls keep compiling.
// NOTE(review): presumably supplies the vectors when `index` does not hold the dataset itself;
// confirm against detail::serialize_to_hnswlib and add a matching @param entry to the Doxygen
// block of this overload.
void serialize_to_hnswlib(
raft::resources const& handle,
std::ostream& os,
const cuvs::neighbors::cagra::index<uint8_t, uint32_t>& index,
std::optional<raft::host_matrix_view<const uint8_t, int64_t, raft::row_major>> dataset =
std::nullopt);

/**
* Save a CAGRA build index in hnswlib base-layer-only serialized format
Expand All @@ -1744,9 +1759,12 @@ void serialize_to_hnswlib(raft::resources const& handle,
* @param[in] index CAGRA index
*
*/
void serialize_to_hnswlib(raft::resources const& handle,
const std::string& filename,
const cuvs::neighbors::cagra::index<uint8_t, uint32_t>& index);
// `dataset` is an optional host-side, row-major view of uint8_t vectors; it defaults to
// std::nullopt, so existing three-argument calls keep compiling.
// NOTE(review): presumably supplies the vectors when `index` does not hold the dataset itself;
// confirm against detail::serialize_to_hnswlib and add a matching @param entry to the Doxygen
// block of this overload.
void serialize_to_hnswlib(
raft::resources const& handle,
const std::string& filename,
const cuvs::neighbors::cagra::index<uint8_t, uint32_t>& index,
std::optional<raft::host_matrix_view<const uint8_t, int64_t, raft::row_major>> dataset =
std::nullopt);

/**
* @}
Expand Down
95 changes: 50 additions & 45 deletions cpp/src/neighbors/cagra_serialize.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -20,51 +20,56 @@

namespace cuvs::neighbors::cagra {

#define CUVS_INST_CAGRA_SERIALIZE(DTYPE) \
void serialize(raft::resources const& handle, \
const std::string& filename, \
const cuvs::neighbors::cagra::index<DTYPE, uint32_t>& index, \
bool include_dataset) \
{ \
cuvs::neighbors::cagra::detail::serialize<DTYPE, uint32_t>( \
handle, filename, index, include_dataset); \
}; \
\
void deserialize(raft::resources const& handle, \
const std::string& filename, \
cuvs::neighbors::cagra::index<DTYPE, uint32_t>* index) \
{ \
cuvs::neighbors::cagra::detail::deserialize<DTYPE, uint32_t>(handle, filename, index); \
}; \
void serialize(raft::resources const& handle, \
std::ostream& os, \
const cuvs::neighbors::cagra::index<DTYPE, uint32_t>& index, \
bool include_dataset) \
{ \
cuvs::neighbors::cagra::detail::serialize<DTYPE, uint32_t>( \
handle, os, index, include_dataset); \
} \
\
void deserialize(raft::resources const& handle, \
std::istream& is, \
cuvs::neighbors::cagra::index<DTYPE, uint32_t>* index) \
{ \
cuvs::neighbors::cagra::detail::deserialize<DTYPE, uint32_t>(handle, is, index); \
} \
\
void serialize_to_hnswlib(raft::resources const& handle, \
std::ostream& os, \
const cuvs::neighbors::cagra::index<DTYPE, uint32_t>& index) \
{ \
cuvs::neighbors::cagra::detail::serialize_to_hnswlib<DTYPE, uint32_t>(handle, os, index); \
} \
\
void serialize_to_hnswlib(raft::resources const& handle, \
const std::string& filename, \
const cuvs::neighbors::cagra::index<DTYPE, uint32_t>& index) \
{ \
cuvs::neighbors::cagra::detail::serialize_to_hnswlib<DTYPE, uint32_t>( \
handle, filename, index); \
// Defines the non-template CAGRA (de)serialization entry points for one element type.
//
// Expanding CUVS_INST_CAGRA_SERIALIZE(DTYPE) emits six functions operating on
// cuvs::neighbors::cagra::index<DTYPE, uint32_t>:
//   - serialize            (file name and std::ostream), forwarding `include_dataset`
//   - deserialize          (file name and std::istream)
//   - serialize_to_hnswlib (std::ostream and file name), forwarding the optional host `dataset`
// Each function forwards its arguments unchanged to the cuvs::neighbors::cagra::detail template
// of the same name. The generated function names are unqualified, so they are defined in
// whichever namespace the macro is expanded in.
//
// No function body is followed by a semicolon: a `;` after a function definition is a stray
// empty declaration that -Wextra-semi / -Wpedantic report.
// Keep comments out of the macro body; a `//` comment there would swallow the continued lines.
#define CUVS_INST_CAGRA_SERIALIZE(DTYPE)                                                         \
  void serialize(raft::resources const& handle,                                                  \
                 const std::string& filename,                                                    \
                 const cuvs::neighbors::cagra::index<DTYPE, uint32_t>& index,                    \
                 bool include_dataset)                                                           \
  {                                                                                              \
    cuvs::neighbors::cagra::detail::serialize<DTYPE, uint32_t>(                                  \
      handle, filename, index, include_dataset);                                                 \
  }                                                                                              \
                                                                                                 \
  void deserialize(raft::resources const& handle,                                                \
                   const std::string& filename,                                                  \
                   cuvs::neighbors::cagra::index<DTYPE, uint32_t>* index)                        \
  {                                                                                              \
    cuvs::neighbors::cagra::detail::deserialize<DTYPE, uint32_t>(handle, filename, index);       \
  }                                                                                              \
                                                                                                 \
  void serialize(raft::resources const& handle,                                                  \
                 std::ostream& os,                                                               \
                 const cuvs::neighbors::cagra::index<DTYPE, uint32_t>& index,                    \
                 bool include_dataset)                                                           \
  {                                                                                              \
    cuvs::neighbors::cagra::detail::serialize<DTYPE, uint32_t>(                                  \
      handle, os, index, include_dataset);                                                       \
  }                                                                                              \
                                                                                                 \
  void deserialize(raft::resources const& handle,                                                \
                   std::istream& is,                                                             \
                   cuvs::neighbors::cagra::index<DTYPE, uint32_t>* index)                        \
  {                                                                                              \
    cuvs::neighbors::cagra::detail::deserialize<DTYPE, uint32_t>(handle, is, index);             \
  }                                                                                              \
                                                                                                 \
  void serialize_to_hnswlib(                                                                     \
    raft::resources const& handle,                                                               \
    std::ostream& os,                                                                            \
    const cuvs::neighbors::cagra::index<DTYPE, uint32_t>& index,                                 \
    std::optional<raft::host_matrix_view<const DTYPE, int64_t, raft::row_major>> dataset)        \
  {                                                                                              \
    cuvs::neighbors::cagra::detail::serialize_to_hnswlib<DTYPE, uint32_t>(                       \
      handle, os, index, dataset);                                                               \
  }                                                                                              \
                                                                                                 \
  void serialize_to_hnswlib(                                                                     \
    raft::resources const& handle,                                                               \
    const std::string& filename,                                                                 \
    const cuvs::neighbors::cagra::index<DTYPE, uint32_t>& index,                                 \
    std::optional<raft::host_matrix_view<const DTYPE, int64_t, raft::row_major>> dataset)        \
  {                                                                                              \
    cuvs::neighbors::cagra::detail::serialize_to_hnswlib<DTYPE, uint32_t>(                       \
      handle, filename, index, dataset);                                                         \
  }

} // namespace cuvs::neighbors::cagra
Loading
Loading