Skip to content

Commit

Permalink
Merge branch 'branch-25.02' into extend-c-api
Browse files Browse the repository at this point in the history
  • Loading branch information
ajit283 authored Jan 8, 2025
2 parents 69f0f47 + 2a10353 commit 12ac03d
Show file tree
Hide file tree
Showing 13 changed files with 283 additions and 208 deletions.
3 changes: 1 addition & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,7 @@ repos:
[.](cmake|cpp|cu|cuh|h|hpp|sh|pxd|py|pyx|rs)$|
CMakeLists[.]txt$|
CMakeLists_standalone[.]txt$|
meta[.]yaml$|
setup[.]cfg$
meta[.]yaml$
exclude: |
(?x)
docs/source/sphinxext/github_link\.py|
Expand Down
2 changes: 1 addition & 1 deletion cpp/cmake/thirdparty/get_raft.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ function(find_and_configure_raft)
# Invoke CPM find_package()
#-----------------------------------------------------
rapids_cpm_find(raft ${PKG_VERSION}
GLOBAL_TARGETS raft::raft
GLOBAL_TARGETS raft::raft raft::raft_logger raft::raft_logger_impl
BUILD_EXPORT_SET cuvs-exports
INSTALL_EXPORT_SET cuvs-exports
COMPONENTS ${RAFT_COMPONENTS}
Expand Down
139 changes: 105 additions & 34 deletions cpp/include/cuvs/neighbors/brute_force.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

#pragma once

#include "common.hpp"
#include <cuvs/neighbors/common.hpp>
#include <raft/core/device_csr_matrix.hpp>
#include <raft/core/device_mdarray.hpp>
Expand All @@ -28,6 +27,10 @@

namespace cuvs::neighbors::brute_force {

struct index_params : cuvs::neighbors::index_params {};

struct search_params : cuvs::neighbors::search_params {};

/**
* @defgroup bruteforce_cpp_index Bruteforce index
* @{
Expand All @@ -41,6 +44,11 @@ namespace cuvs::neighbors::brute_force {
*/
template <typename T, typename DistT = T>
struct index : cuvs::neighbors::index {
using index_params_type = brute_force::index_params;
using search_params_type = brute_force::search_params;
using index_type = int64_t;
using value_type = T;

public:
index(const index&) = delete;
index(index&&) = default;
Expand Down Expand Up @@ -181,83 +189,105 @@ struct index : cuvs::neighbors::index {
* @code{.cpp}
* using namespace cuvs::neighbors;
* // create and fill the index from a [N, D] dataset
* auto index = brute_force::build(handle, dataset, metric);
* brute_force::index_params index_params;
* auto index = brute_force::build(handle, index_params, dataset);
* @endcode
*
* @param[in] handle
* @param[in] index_params parameters such as the distance metric to use
* @param[in] dataset a device pointer to a row-major matrix [n_rows, dim]
* @param[in] metric cuvs::distance::DistanceType
* @param[in] metric_arg metric argument
*
* @return the constructed brute-force index
*/
auto build(raft::resources const& handle,
raft::device_matrix_view<const float, int64_t, raft::row_major> dataset,
cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Unexpanded,
float metric_arg = 0) -> cuvs::neighbors::brute_force::index<float, float>;
const cuvs::neighbors::brute_force::index_params& index_params,
raft::device_matrix_view<const float, int64_t, raft::row_major> dataset)
-> cuvs::neighbors::brute_force::index<float, float>;

[[deprecated]] auto build(
raft::resources const& handle,
raft::device_matrix_view<const float, int64_t, raft::row_major> dataset,
cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Unexpanded,
float metric_arg = 0) -> cuvs::neighbors::brute_force::index<float, float>;
/**
* @brief Build the index from the dataset for efficient search.
*
* Usage example:
* @code{.cpp}
* using namespace cuvs::neighbors;
* // create and fill the index from a [N, D] dataset
* auto index = brute_force::build(handle, dataset, metric);
* brute_force::index_params index_params;
* auto index = brute_force::build(handle, index_params, dataset);
* @endcode
*
* @param[in] handle
* @param[in] index_params parameters such as the distance metric to use
* @param[in] dataset a device pointer to a row-major matrix [n_rows, dim]
* @param[in] metric cuvs::distance::DistanceType
* @param[in] metric_arg metric argument
*
* @return the constructed ivf-flat index
* @return the constructed brute force index
*/
auto build(raft::resources const& handle,
raft::device_matrix_view<const half, int64_t, raft::row_major> dataset,
cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Unexpanded,
float metric_arg = 0) -> cuvs::neighbors::brute_force::index<half, float>;
const cuvs::neighbors::brute_force::index_params& index_params,
raft::device_matrix_view<const half, int64_t, raft::row_major> dataset)
-> cuvs::neighbors::brute_force::index<half, float>;

[[deprecated]] auto build(
raft::resources const& handle,
raft::device_matrix_view<const half, int64_t, raft::row_major> dataset,
cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Unexpanded,
float metric_arg = 0) -> cuvs::neighbors::brute_force::index<half, float>;

/**
* @brief Build the index from the dataset for efficient search.
*
* Usage example:
* @code{.cpp}
* using namespace cuvs::neighbors;
* // create and fill the index from a [N, D] dataset
* auto index = brute_force::build(handle, dataset, metric);
* brute_force::index_params index_params;
* auto index = brute_force::build(handle, index_params, dataset);
* @endcode
*
* @param[in] handle
* @param[in] dataset a device pointer to a col-major matrix [n_rows, dim]
* @param[in] metric cuvs::distance::DistanceType
* @param[in] metric_arg metric argument
* @param[in] index_params parameters such as the distance metric to use
* @param[in] dataset a device pointer to a row-major matrix [n_rows, dim]
*
* @return the constructed bruteforce index
* @return the constructed brute force index
*/
auto build(raft::resources const& handle,
raft::device_matrix_view<const float, int64_t, raft::col_major> dataset,
cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Unexpanded,
float metric_arg = 0) -> cuvs::neighbors::brute_force::index<float, float>;
const cuvs::neighbors::brute_force::index_params& index_params,
raft::device_matrix_view<const float, int64_t, raft::col_major> dataset)
-> cuvs::neighbors::brute_force::index<float, float>;

[[deprecated]] auto build(
raft::resources const& handle,
raft::device_matrix_view<const float, int64_t, raft::col_major> dataset,
cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Unexpanded,
float metric_arg = 0) -> cuvs::neighbors::brute_force::index<float, float>;

/**
* @brief Build the index from the dataset for efficient search.
*
* Usage example:
* @code{.cpp}
* using namespace cuvs::neighbors;
* // create and fill the index from a [N, D] dataset
* auto index = brute_force::build(handle, dataset, metric);
* brute_force::index_params index_params;
* auto index = brute_force::build(handle, index_params, dataset);
* @endcode
*
* @param[in] handle
* @param[in] dataset a device pointer to a col-major matrix [n_rows, dim]
* @param[in] metric cuvs::distance::DistanceType
* @param[in] metric_arg metric argument
* @param[in] index_params parameters such as the distance metric to use
* @param[in] dataset a device pointer to a row-major matrix [n_rows, dim]
*
* @return the constructed bruteforce index
* @return the constructed brute force index
*/
auto build(raft::resources const& handle,
raft::device_matrix_view<const half, int64_t, raft::col_major> dataset,
cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Unexpanded,
float metric_arg = 0) -> cuvs::neighbors::brute_force::index<half, float>;
const cuvs::neighbors::brute_force::index_params& index_params,
raft::device_matrix_view<const half, int64_t, raft::col_major> dataset)
-> cuvs::neighbors::brute_force::index<half, float>;

[[deprecated]] auto build(
raft::resources const& handle,
raft::device_matrix_view<const half, int64_t, raft::col_major> dataset,
cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Unexpanded,
float metric_arg = 0) -> cuvs::neighbors::brute_force::index<half, float>;
/**
* @}
*/
Expand Down Expand Up @@ -286,6 +316,7 @@ auto build(raft::resources const& handle,
* @endcode
*
* @param[in] handle
* @param[in] params parameters configuring the search
* @param[in] index brute-force constructed index
* @param[in] queries a device pointer to a row-major matrix [n_queries, index->dim()]
* @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset
Expand All @@ -296,13 +327,22 @@ auto build(raft::resources const& handle,
* `index->size()` bits to indicate whether queries[0] should compute the distance with dataset.
*/
void search(raft::resources const& handle,
const cuvs::neighbors::brute_force::search_params& params,
const cuvs::neighbors::brute_force::index<float, float>& index,
raft::device_matrix_view<const float, int64_t, raft::row_major> queries,
raft::device_matrix_view<int64_t, int64_t, raft::row_major> neighbors,
raft::device_matrix_view<float, int64_t, raft::row_major> distances,
const cuvs::neighbors::filtering::base_filter& sample_filter =
cuvs::neighbors::filtering::none_sample_filter{});

[[deprecated]] void search(raft::resources const& handle,
const cuvs::neighbors::brute_force::index<float, float>& index,
raft::device_matrix_view<const float, int64_t, raft::row_major> queries,
raft::device_matrix_view<int64_t, int64_t, raft::row_major> neighbors,
raft::device_matrix_view<float, int64_t, raft::row_major> distances,
const cuvs::neighbors::filtering::base_filter& sample_filter =
cuvs::neighbors::filtering::none_sample_filter{});

/**
* @brief Search ANN using the constructed index.
*
Expand All @@ -323,6 +363,7 @@ void search(raft::resources const& handle,
* @endcode
*
* @param[in] handle
* @param[in] params parameters configuring the search
* @param[in] index ivf-flat constructed index
* @param[in] queries a device pointer to a row-major matrix [n_queries, index->dim()]
* @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset
Expand All @@ -332,18 +373,28 @@ void search(raft::resources const& handle,
* given
*/
void search(raft::resources const& handle,
const cuvs::neighbors::brute_force::search_params& params,
const cuvs::neighbors::brute_force::index<half, float>& index,
raft::device_matrix_view<const half, int64_t, raft::row_major> queries,
raft::device_matrix_view<int64_t, int64_t, raft::row_major> neighbors,
raft::device_matrix_view<float, int64_t, raft::row_major> distances,
const cuvs::neighbors::filtering::base_filter& sample_filter =
cuvs::neighbors::filtering::none_sample_filter{});

[[deprecated]] void search(raft::resources const& handle,
const cuvs::neighbors::brute_force::index<half, float>& index,
raft::device_matrix_view<const half, int64_t, raft::row_major> queries,
raft::device_matrix_view<int64_t, int64_t, raft::row_major> neighbors,
raft::device_matrix_view<float, int64_t, raft::row_major> distances,
const cuvs::neighbors::filtering::base_filter& sample_filter =
cuvs::neighbors::filtering::none_sample_filter{});
/**
* @brief Search ANN using the constructed index.
*
* See the [brute_force::build](#brute_force::build) documentation for a usage example.
*
* @param[in] handle
* @param[in] params parameters configuring the search
* @param[in] index bruteforce constructed index
* @param[in] queries a device pointer to a col-major matrix [n_queries, index->dim()]
* @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset
Expand All @@ -353,18 +404,28 @@ void search(raft::resources const& handle,
* given query
*/
void search(raft::resources const& handle,
const cuvs::neighbors::brute_force::search_params& params,
const cuvs::neighbors::brute_force::index<float, float>& index,
raft::device_matrix_view<const float, int64_t, raft::col_major> queries,
raft::device_matrix_view<int64_t, int64_t, raft::row_major> neighbors,
raft::device_matrix_view<float, int64_t, raft::row_major> distances,
const cuvs::neighbors::filtering::base_filter& sample_filter =
cuvs::neighbors::filtering::none_sample_filter{});

[[deprecated]] void search(raft::resources const& handle,
const cuvs::neighbors::brute_force::index<float, float>& index,
raft::device_matrix_view<const float, int64_t, raft::col_major> queries,
raft::device_matrix_view<int64_t, int64_t, raft::row_major> neighbors,
raft::device_matrix_view<float, int64_t, raft::row_major> distances,
const cuvs::neighbors::filtering::base_filter& sample_filter =
cuvs::neighbors::filtering::none_sample_filter{});
/**
* @brief Search ANN using the constructed index.
*
* See the [brute_force::build](#brute_force::build) documentation for a usage example.
*
* @param[in] handle
* @param[in] params parameters configuring the search
* @param[in] index bruteforce constructed index
* @param[in] queries a device pointer to a col-major matrix [n_queries, index->dim()]
* @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset
Expand All @@ -374,12 +435,21 @@ void search(raft::resources const& handle,
* given query
*/
void search(raft::resources const& handle,
const cuvs::neighbors::brute_force::search_params& params,
const cuvs::neighbors::brute_force::index<half, float>& index,
raft::device_matrix_view<const half, int64_t, raft::col_major> queries,
raft::device_matrix_view<int64_t, int64_t, raft::row_major> neighbors,
raft::device_matrix_view<float, int64_t, raft::row_major> distances,
const cuvs::neighbors::filtering::base_filter& sample_filter =
cuvs::neighbors::filtering::none_sample_filter{});

[[deprecated]] void search(raft::resources const& handle,
const cuvs::neighbors::brute_force::index<half, float>& index,
raft::device_matrix_view<const half, int64_t, raft::col_major> queries,
raft::device_matrix_view<int64_t, int64_t, raft::row_major> neighbors,
raft::device_matrix_view<float, int64_t, raft::row_major> distances,
const cuvs::neighbors::filtering::base_filter& sample_filter =
cuvs::neighbors::filtering::none_sample_filter{});
/**
* @}
*/
Expand Down Expand Up @@ -472,6 +542,7 @@ struct sparse_search_params {
* @brief Search the sparse bruteforce index for nearest neighbors
*
* @param[in] handle
* @param[in] params parameters configuring the search
* @param[in] index Sparse brute-force constructed index
* @param[in] queries a sparse CSR matrix on the device to query
* @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset
Expand Down
Loading

0 comments on commit 12ac03d

Please sign in to comment.