Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Inner Product for CAGRA-Q #458

Open
wants to merge 25 commits into
base: branch-25.02
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
ee60d95
first commit
tarang-jain Nov 10, 2024
205f84b
instantiations
tarang-jain Nov 11, 2024
7d16801
add new headers to CMakeLists.txt
tarang-jain Nov 12, 2024
6e9c77d
pytest
tarang-jain Nov 12, 2024
c58ade5
style
tarang-jain Nov 12, 2024
5da3fd9
separate header for distance_op
tarang-jain Nov 13, 2024
1bdb38e
update metric in vpq_predict
tarang-jain Nov 13, 2024
408e019
Merge branch 'cagraq-ip' of https://github.com/tarang-jain/cuvs into …
Nov 14, 2024
948bc13
Merge branch 'branch-24.12' of https://github.com/rapidsai/cuvs into …
Nov 18, 2024
078195d
Merge branch 'branch-24.12' into cagraq-ip
cjnolet Nov 18, 2024
b09a38e
Merge branch 'cagraq-ip' of https://github.com/tarang-jain/cuvs into …
Nov 19, 2024
a560511
rm debug comments
Nov 19, 2024
aa2dde6
dist_op for float32
Nov 19, 2024
e941ce0
Merge branch 'branch-24.12' of https://github.com/rapidsai/cuvs into …
Nov 19, 2024
5980c54
Merge branch 'branch-24.12' of https://github.com/rapidsai/cuvs into …
Dec 2, 2024
3328035
debug with cuvs_bench
Dec 2, 2024
37f294b
rm debug statements
tarang-jain Dec 6, 2024
16aa2d8
merge upstream
tarang-jain Jan 10, 2025
d58f218
0,6
tarang-jain Jan 10, 2025
253be4c
cleanup
tarang-jain Jan 10, 2025
ea577d5
Merge branch 'branch-25.02' into cagraq-ip
tarang-jain Jan 11, 2025
b63d224
testing
tarang-jain Jan 12, 2025
1a85aee
Merge branch 'cagraq-ip' of https://github.com/tarang-jain/cuvs into …
tarang-jain Jan 12, 2025
6e777c7
Merge branch 'branch-25.02' into cagraq-ip
tarang-jain Jan 14, 2025
8535ecd
Merge branch 'branch-25.02' into cagraq-ip
cjnolet Jan 16, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,30 @@ if(BUILD_SHARED_LIBS)
src/neighbors/detail/cagra/compute_distance_standard_L2Expanded_uint8_uint32_dim128_t8.cu
src/neighbors/detail/cagra/compute_distance_standard_L2Expanded_uint8_uint32_dim256_t16.cu
src/neighbors/detail/cagra/compute_distance_standard_L2Expanded_uint8_uint32_dim512_t32.cu
src/neighbors/detail/cagra/compute_distance_vpq_InnerProduct_float_uint32_dim128_t8_8pq_2subd_half.cu
src/neighbors/detail/cagra/compute_distance_vpq_InnerProduct_float_uint32_dim128_t8_8pq_4subd_half.cu
src/neighbors/detail/cagra/compute_distance_vpq_InnerProduct_float_uint32_dim256_t16_8pq_2subd_half.cu
src/neighbors/detail/cagra/compute_distance_vpq_InnerProduct_float_uint32_dim256_t16_8pq_4subd_half.cu
src/neighbors/detail/cagra/compute_distance_vpq_InnerProduct_float_uint32_dim512_t32_8pq_2subd_half.cu
src/neighbors/detail/cagra/compute_distance_vpq_InnerProduct_float_uint32_dim512_t32_8pq_4subd_half.cu
src/neighbors/detail/cagra/compute_distance_vpq_InnerProduct_half_uint32_dim128_t8_8pq_2subd_half.cu
src/neighbors/detail/cagra/compute_distance_vpq_InnerProduct_half_uint32_dim128_t8_8pq_4subd_half.cu
src/neighbors/detail/cagra/compute_distance_vpq_InnerProduct_half_uint32_dim256_t16_8pq_2subd_half.cu
src/neighbors/detail/cagra/compute_distance_vpq_InnerProduct_half_uint32_dim256_t16_8pq_4subd_half.cu
src/neighbors/detail/cagra/compute_distance_vpq_InnerProduct_half_uint32_dim512_t32_8pq_2subd_half.cu
src/neighbors/detail/cagra/compute_distance_vpq_InnerProduct_half_uint32_dim512_t32_8pq_4subd_half.cu
src/neighbors/detail/cagra/compute_distance_vpq_InnerProduct_int8_uint32_dim128_t8_8pq_2subd_half.cu
src/neighbors/detail/cagra/compute_distance_vpq_InnerProduct_int8_uint32_dim128_t8_8pq_4subd_half.cu
src/neighbors/detail/cagra/compute_distance_vpq_InnerProduct_int8_uint32_dim256_t16_8pq_2subd_half.cu
src/neighbors/detail/cagra/compute_distance_vpq_InnerProduct_int8_uint32_dim256_t16_8pq_4subd_half.cu
src/neighbors/detail/cagra/compute_distance_vpq_InnerProduct_int8_uint32_dim512_t32_8pq_2subd_half.cu
src/neighbors/detail/cagra/compute_distance_vpq_InnerProduct_int8_uint32_dim512_t32_8pq_4subd_half.cu
src/neighbors/detail/cagra/compute_distance_vpq_InnerProduct_uint8_uint32_dim128_t8_8pq_2subd_half.cu
src/neighbors/detail/cagra/compute_distance_vpq_InnerProduct_uint8_uint32_dim128_t8_8pq_4subd_half.cu
src/neighbors/detail/cagra/compute_distance_vpq_InnerProduct_uint8_uint32_dim256_t16_8pq_2subd_half.cu
src/neighbors/detail/cagra/compute_distance_vpq_InnerProduct_uint8_uint32_dim256_t16_8pq_4subd_half.cu
src/neighbors/detail/cagra/compute_distance_vpq_InnerProduct_uint8_uint32_dim512_t32_8pq_2subd_half.cu
src/neighbors/detail/cagra/compute_distance_vpq_InnerProduct_uint8_uint32_dim512_t32_8pq_4subd_half.cu
src/neighbors/detail/cagra/compute_distance_vpq_L2Expanded_float_uint32_dim128_t8_8pq_2subd_half.cu
src/neighbors/detail/cagra/compute_distance_vpq_L2Expanded_float_uint32_dim128_t8_8pq_4subd_half.cu
src/neighbors/detail/cagra/compute_distance_vpq_L2Expanded_float_uint32_dim256_t16_8pq_2subd_half.cu
Expand Down
34 changes: 17 additions & 17 deletions cpp/include/cuvs/neighbors/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,24 @@

namespace cuvs::neighbors {

/**
* @defgroup neighbors_index Approximate Nearest Neighbors Types
* @{
*/

/** The base for approximate KNN index structures. */
struct index {};

/** The base for KNN index parameters. */
struct index_params {
/** Distance type. */
cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Expanded;
/** The argument used by some distance metrics. */
float metric_arg = 2.0f;
};

/** Parameters for VPQ compression. */
struct vpq_params {
struct vpq_params : index_params {
/**
* The bit length of the vector element after compression by PQ.
*
Expand Down Expand Up @@ -77,22 +93,6 @@ struct vpq_params {
double pq_kmeans_trainset_fraction = 0;
};

/**
* @defgroup neighbors_index Approximate Nearest Neighbors Types
* @{
*/

/** The base for approximate KNN index structures. */
struct index {};

/** The base for KNN index parameters. */
struct index_params {
/** Distance type. */
cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Expanded;
/** The argument used by some distance metrics. */
float metric_arg = 2.0f;
};

struct search_params {};

/** @} */ // end group neighbors_index
Expand Down
10 changes: 7 additions & 3 deletions cpp/src/neighbors/detail/cagra/cagra_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -486,15 +486,19 @@ index<T, IdxT> build(

// Construct an index from dataset and optimized knn graph.
if (params.compression.has_value()) {
RAFT_EXPECTS(params.metric == cuvs::distance::DistanceType::L2Expanded,
"VPQ compression is only supported with L2Expanded distance mertric");
RAFT_EXPECTS(
params.metric == cuvs::distance::DistanceType::L2Expanded ||
params.metric == cuvs::distance::DistanceType::InnerProduct,
"VPQ compression is only supported with L2Expanded and InnerProduct distance mertric");
index<T, IdxT> idx(res, params.metric);
idx.update_graph(res, raft::make_const_mdspan(cagra_graph.view()));
auto compression_params = *params.compression;
compression_params.metric = params.metric;
idx.update_dataset(
res,
// TODO: hardcoding codebook math to `half`, we can do runtime dispatching later
cuvs::neighbors::vpq_build<decltype(dataset), half, int64_t>(
res, *params.compression, dataset));
res, compression_params, dataset));

return idx;
}
Expand Down
Loading
Loading