Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change brute_force api to match ivf*/cagra #536

Merged
merged 6 commits into from
Jan 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 105 additions & 34 deletions cpp/include/cuvs/neighbors/brute_force.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

#pragma once

#include "common.hpp"
#include <cuvs/neighbors/common.hpp>
#include <raft/core/device_csr_matrix.hpp>
#include <raft/core/device_mdarray.hpp>
Expand All @@ -28,6 +27,10 @@

namespace cuvs::neighbors::brute_force {

struct index_params : cuvs::neighbors::index_params {};

struct search_params : cuvs::neighbors::search_params {};

/**
* @defgroup bruteforce_cpp_index Bruteforce index
* @{
Expand All @@ -41,6 +44,11 @@ namespace cuvs::neighbors::brute_force {
*/
template <typename T, typename DistT = T>
struct index : cuvs::neighbors::index {
using index_params_type = brute_force::index_params;
using search_params_type = brute_force::search_params;
using index_type = int64_t;
using value_type = T;

public:
index(const index&) = delete;
index(index&&) = default;
Expand Down Expand Up @@ -181,83 +189,105 @@ struct index : cuvs::neighbors::index {
* @code{.cpp}
* using namespace cuvs::neighbors;
* // create and fill the index from a [N, D] dataset
* auto index = brute_force::build(handle, dataset, metric);
* brute_force::index_params index_params;
* auto index = brute_force::build(handle, index_params, dataset);
* @endcode
*
* @param[in] handle
* @param[in] index_params parameters such as the distance metric to use
* @param[in] dataset a device pointer to a row-major matrix [n_rows, dim]
* @param[in] metric cuvs::distance::DistanceType
* @param[in] metric_arg metric argument
*
* @return the constructed brute-force index
*/
auto build(raft::resources const& handle,
raft::device_matrix_view<const float, int64_t, raft::row_major> dataset,
cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Unexpanded,
float metric_arg = 0) -> cuvs::neighbors::brute_force::index<float, float>;
const cuvs::neighbors::brute_force::index_params& index_params,
raft::device_matrix_view<const float, int64_t, raft::row_major> dataset)
-> cuvs::neighbors::brute_force::index<float, float>;

[[deprecated]] auto build(
raft::resources const& handle,
raft::device_matrix_view<const float, int64_t, raft::row_major> dataset,
cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Unexpanded,
float metric_arg = 0) -> cuvs::neighbors::brute_force::index<float, float>;
/**
* @brief Build the index from the dataset for efficient search.
*
* Usage example:
* @code{.cpp}
* using namespace cuvs::neighbors;
* // create and fill the index from a [N, D] dataset
* auto index = brute_force::build(handle, dataset, metric);
* brute_force::index_params index_params;
* auto index = brute_force::build(handle, index_params, dataset);
* @endcode
*
* @param[in] handle
* @param[in] index_params parameters such as the distance metric to use
* @param[in] dataset a device pointer to a row-major matrix [n_rows, dim]
* @param[in] metric cuvs::distance::DistanceType
* @param[in] metric_arg metric argument
*
* @return the constructed ivf-flat index
* @return the constructed brute force index
*/
auto build(raft::resources const& handle,
raft::device_matrix_view<const half, int64_t, raft::row_major> dataset,
cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Unexpanded,
float metric_arg = 0) -> cuvs::neighbors::brute_force::index<half, float>;
const cuvs::neighbors::brute_force::index_params& index_params,
raft::device_matrix_view<const half, int64_t, raft::row_major> dataset)
-> cuvs::neighbors::brute_force::index<half, float>;

[[deprecated]] auto build(
raft::resources const& handle,
raft::device_matrix_view<const half, int64_t, raft::row_major> dataset,
cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Unexpanded,
float metric_arg = 0) -> cuvs::neighbors::brute_force::index<half, float>;

/**
* @brief Build the index from the dataset for efficient search.
*
* Usage example:
* @code{.cpp}
* using namespace cuvs::neighbors;
* // create and fill the index from a [N, D] dataset
* auto index = brute_force::build(handle, dataset, metric);
* brute_force::index_params index_params;
* auto index = brute_force::build(handle, index_params, dataset);
* @endcode
*
* @param[in] handle
* @param[in] dataset a device pointer to a col-major matrix [n_rows, dim]
* @param[in] metric cuvs::distance::DistanceType
* @param[in] metric_arg metric argument
* @param[in] index_params parameters such as the distance metric to use
* @param[in] dataset a device pointer to a row-major matrix [n_rows, dim]
*
* @return the constructed bruteforce index
* @return the constructed brute force index
*/
auto build(raft::resources const& handle,
raft::device_matrix_view<const float, int64_t, raft::col_major> dataset,
cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Unexpanded,
float metric_arg = 0) -> cuvs::neighbors::brute_force::index<float, float>;
const cuvs::neighbors::brute_force::index_params& index_params,
raft::device_matrix_view<const float, int64_t, raft::col_major> dataset)
-> cuvs::neighbors::brute_force::index<float, float>;

[[deprecated]] auto build(
raft::resources const& handle,
raft::device_matrix_view<const float, int64_t, raft::col_major> dataset,
cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Unexpanded,
float metric_arg = 0) -> cuvs::neighbors::brute_force::index<float, float>;

/**
* @brief Build the index from the dataset for efficient search.
*
* Usage example:
* @code{.cpp}
* using namespace cuvs::neighbors;
* // create and fill the index from a [N, D] dataset
* auto index = brute_force::build(handle, dataset, metric);
* brute_force::index_params index_params;
* auto index = brute_force::build(handle, index_params, dataset);
* @endcode
*
* @param[in] handle
* @param[in] dataset a device pointer to a col-major matrix [n_rows, dim]
* @param[in] metric cuvs::distance::DistanceType
* @param[in] metric_arg metric argument
* @param[in] index_params parameters such as the distance metric to use
* @param[in] dataset a device pointer to a row-major matrix [n_rows, dim]
*
* @return the constructed bruteforce index
* @return the constructed brute force index
*/
auto build(raft::resources const& handle,
raft::device_matrix_view<const half, int64_t, raft::col_major> dataset,
cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Unexpanded,
float metric_arg = 0) -> cuvs::neighbors::brute_force::index<half, float>;
const cuvs::neighbors::brute_force::index_params& index_params,
raft::device_matrix_view<const half, int64_t, raft::col_major> dataset)
-> cuvs::neighbors::brute_force::index<half, float>;

[[deprecated]] auto build(
raft::resources const& handle,
raft::device_matrix_view<const half, int64_t, raft::col_major> dataset,
cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Unexpanded,
float metric_arg = 0) -> cuvs::neighbors::brute_force::index<half, float>;
/**
* @}
*/
Expand Down Expand Up @@ -286,6 +316,7 @@ auto build(raft::resources const& handle,
* @endcode
*
* @param[in] handle
* @param[in] params parameters configuring the search
* @param[in] index brute-force constructed index
* @param[in] queries a device pointer to a row-major matrix [n_queries, index->dim()]
* @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset
Expand All @@ -296,13 +327,22 @@ auto build(raft::resources const& handle,
* `index->size()` bits to indicate whether queries[0] should compute the distance with dataset.
*/
void search(raft::resources const& handle,
const cuvs::neighbors::brute_force::search_params& params,
const cuvs::neighbors::brute_force::index<float, float>& index,
raft::device_matrix_view<const float, int64_t, raft::row_major> queries,
raft::device_matrix_view<int64_t, int64_t, raft::row_major> neighbors,
raft::device_matrix_view<float, int64_t, raft::row_major> distances,
const cuvs::neighbors::filtering::base_filter& sample_filter =
cuvs::neighbors::filtering::none_sample_filter{});

[[deprecated]] void search(raft::resources const& handle,
const cuvs::neighbors::brute_force::index<float, float>& index,
raft::device_matrix_view<const float, int64_t, raft::row_major> queries,
raft::device_matrix_view<int64_t, int64_t, raft::row_major> neighbors,
raft::device_matrix_view<float, int64_t, raft::row_major> distances,
const cuvs::neighbors::filtering::base_filter& sample_filter =
cuvs::neighbors::filtering::none_sample_filter{});

/**
* @brief Search ANN using the constructed index.
*
Expand All @@ -323,6 +363,7 @@ void search(raft::resources const& handle,
* @endcode
*
* @param[in] handle
* @param[in] params parameters configuring the search
* @param[in] index ivf-flat constructed index
* @param[in] queries a device pointer to a row-major matrix [n_queries, index->dim()]
* @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset
Expand All @@ -332,18 +373,28 @@ void search(raft::resources const& handle,
* given
*/
void search(raft::resources const& handle,
const cuvs::neighbors::brute_force::search_params& params,
const cuvs::neighbors::brute_force::index<half, float>& index,
raft::device_matrix_view<const half, int64_t, raft::row_major> queries,
raft::device_matrix_view<int64_t, int64_t, raft::row_major> neighbors,
raft::device_matrix_view<float, int64_t, raft::row_major> distances,
const cuvs::neighbors::filtering::base_filter& sample_filter =
cuvs::neighbors::filtering::none_sample_filter{});

[[deprecated]] void search(raft::resources const& handle,
const cuvs::neighbors::brute_force::index<half, float>& index,
raft::device_matrix_view<const half, int64_t, raft::row_major> queries,
raft::device_matrix_view<int64_t, int64_t, raft::row_major> neighbors,
raft::device_matrix_view<float, int64_t, raft::row_major> distances,
const cuvs::neighbors::filtering::base_filter& sample_filter =
cuvs::neighbors::filtering::none_sample_filter{});
/**
* @brief Search ANN using the constructed index.
*
* See the [brute_force::build](#brute_force::build) documentation for a usage example.
*
* @param[in] handle
* @param[in] params parameters configuring the search
* @param[in] index bruteforce constructed index
* @param[in] queries a device pointer to a col-major matrix [n_queries, index->dim()]
* @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset
Expand All @@ -353,18 +404,28 @@ void search(raft::resources const& handle,
* given query
*/
void search(raft::resources const& handle,
const cuvs::neighbors::brute_force::search_params& params,
const cuvs::neighbors::brute_force::index<float, float>& index,
raft::device_matrix_view<const float, int64_t, raft::col_major> queries,
raft::device_matrix_view<int64_t, int64_t, raft::row_major> neighbors,
raft::device_matrix_view<float, int64_t, raft::row_major> distances,
const cuvs::neighbors::filtering::base_filter& sample_filter =
cuvs::neighbors::filtering::none_sample_filter{});

[[deprecated]] void search(raft::resources const& handle,
const cuvs::neighbors::brute_force::index<float, float>& index,
raft::device_matrix_view<const float, int64_t, raft::col_major> queries,
raft::device_matrix_view<int64_t, int64_t, raft::row_major> neighbors,
raft::device_matrix_view<float, int64_t, raft::row_major> distances,
const cuvs::neighbors::filtering::base_filter& sample_filter =
cuvs::neighbors::filtering::none_sample_filter{});
/**
* @brief Search ANN using the constructed index.
*
* See the [brute_force::build](#brute_force::build) documentation for a usage example.
*
* @param[in] handle
* @param[in] params parameters configuring the search
* @param[in] index bruteforce constructed index
* @param[in] queries a device pointer to a col-major matrix [n_queries, index->dim()]
* @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset
Expand All @@ -374,12 +435,21 @@ void search(raft::resources const& handle,
* given query
*/
void search(raft::resources const& handle,
const cuvs::neighbors::brute_force::search_params& params,
const cuvs::neighbors::brute_force::index<half, float>& index,
raft::device_matrix_view<const half, int64_t, raft::col_major> queries,
raft::device_matrix_view<int64_t, int64_t, raft::row_major> neighbors,
raft::device_matrix_view<float, int64_t, raft::row_major> distances,
const cuvs::neighbors::filtering::base_filter& sample_filter =
cuvs::neighbors::filtering::none_sample_filter{});

[[deprecated]] void search(raft::resources const& handle,
const cuvs::neighbors::brute_force::index<half, float>& index,
raft::device_matrix_view<const half, int64_t, raft::col_major> queries,
raft::device_matrix_view<int64_t, int64_t, raft::row_major> neighbors,
raft::device_matrix_view<float, int64_t, raft::row_major> distances,
const cuvs::neighbors::filtering::base_filter& sample_filter =
cuvs::neighbors::filtering::none_sample_filter{});
/**
* @}
*/
Expand Down Expand Up @@ -472,6 +542,7 @@ struct sparse_search_params {
* @brief Search the sparse bruteforce index for nearest neighbors
*
* @param[in] handle
* @param[in] params parameters configuring the search
* @param[in] index Sparse brute-force constructed index
* @param[in] queries a sparse CSR matrix on the device to query
* @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset
Expand Down
Loading
Loading