diff --git a/cpp/src/neighbors/detail/cagra/cagra_build.cuh b/cpp/src/neighbors/detail/cagra/cagra_build.cuh
index a3e591b74..dce35da8e 100644
--- a/cpp/src/neighbors/detail/cagra/cagra_build.cuh
+++ b/cpp/src/neighbors/detail/cagra/cagra_build.cuh
@@ -128,8 +128,9 @@ void build_knn_graph(
   cuvs::neighbors::cagra::graph_build_params::ivf_pq_params pq)
 {
   RAFT_EXPECTS(pq.build_params.metric == cuvs::distance::DistanceType::L2Expanded ||
-                 pq.build_params.metric == cuvs::distance::DistanceType::InnerProduct,
-               "Currently only L2Expanded or InnerProduct metric are supported");
+                 pq.build_params.metric == cuvs::distance::DistanceType::InnerProduct ||
+                 pq.build_params.metric == cuvs::distance::DistanceType::CosineExpanded,
+               "Currently only L2Expanded, InnerProduct and CosineExpanded metrics are supported");
 
   uint32_t node_degree = knn_graph.extent(1);
   raft::common::nvtx::range fun_scope(
diff --git a/cpp/src/neighbors/detail/cagra/compute_distance.hpp b/cpp/src/neighbors/detail/cagra/compute_distance.hpp
index 2b0c750ff..752c3deb8 100644
--- a/cpp/src/neighbors/detail/cagra/compute_distance.hpp
+++ b/cpp/src/neighbors/detail/cagra/compute_distance.hpp
@@ -99,6 +99,13 @@ _RAFT_DEVICE void compute_distance_to_random_nodes(
                                                    TEAM_SIZE,
                                                    cuvs::distance::DistanceType::InnerProduct>(
             query_buffer, seed_index, valid_i);
+          break;
+        case cuvs::distance::DistanceType::CosineExpanded:
+          norm2 =
+            dataset_desc.template compute_similarity(
+              query_buffer, seed_index, valid_i);
           break;
         default: break;
       }
@@ -191,6 +198,13 @@ _RAFT_DEVICE void compute_distance_to_child_nodes(
                                                    cuvs::distance::DistanceType::InnerProduct>(
             query_buffer, child_id, child_id != invalid_index);
           break;
+        case cuvs::distance::DistanceType::CosineExpanded:
+          norm2 =
+            dataset_desc.template compute_similarity(
+              query_buffer, child_id, child_id != invalid_index);
+          break;
         default: break;
       }
 
@@ -275,9 +289,12 @@ struct standard_dataset_descriptor_t
   }
 
   template
-  __device__ DISTANCE_T compute_similarity(const QUERY_T* const query_ptr,
-                                           const INDEX_T dataset_i,
-                                           const bool valid) const
+  std::enable_if_t
+  __device__ compute_similarity(const QUERY_T* const query_ptr,
+                                const INDEX_T dataset_i,
+                                const bool valid) const
   {
     const auto dataset_ptr  = ptr + dataset_i * ld;
     const unsigned lane_id  = threadIdx.x % TEAM_SIZE;
@@ -286,7 +303,8 @@ struct standard_dataset_descriptor_t
     constexpr unsigned reg_nelem = raft::ceildiv(DATASET_BLOCK_DIM, TEAM_SIZE * vlen);
     raft::TxN_t dl_buff[reg_nelem];
 
-    DISTANCE_T norm2 = 0;
+    DISTANCE_T dist = 0;
+
     if (valid) {
       for (uint32_t elem_offset = 0; elem_offset < dim; elem_offset += DATASET_BLOCK_DIM) {
 #pragma unroll
@@ -307,16 +325,79 @@ struct standard_dataset_descriptor_t
             // - Above the last element (dataset_dim-1), the query array is filled with zeros.
             // - The data buffer has to be also padded with zeros.
             DISTANCE_T d = query_ptr[device::swizzling(kv)];
-            norm2 += dist_op(
+            dist += dist_op(
               d, cuvs::spatial::knn::detail::utils::mapping{}(dl_buff[e].val.data[v]));
           }
         }
       }
     }
     for (uint32_t offset = TEAM_SIZE / 2; offset > 0; offset >>= 1) {
+      dist += __shfl_xor_sync(0xffffffff, dist, offset);
+    }
+
+    return dist;
+  }
+
+  /**
+   * Cosine overload: accumulates dist_op(q, d) together with the squared norms of the query and
+   * dataset vectors, sums all three across the TEAM_SIZE lanes with XOR shuffles, and returns
+   * the accumulated value divided by the product of the two vector lengths.
+   */
+  template
+  std::enable_if_t __device__
+  compute_similarity(const QUERY_T* const query_ptr,
+                     const INDEX_T dataset_i,
+                     const bool valid) const
+  {
+    const auto dataset_ptr  = ptr + dataset_i * ld;
+    const unsigned lane_id  = threadIdx.x % TEAM_SIZE;
+    constexpr unsigned vlen = device::get_vlen();
+    // #include
+    constexpr unsigned reg_nelem = raft::ceildiv(DATASET_BLOCK_DIM, TEAM_SIZE * vlen);
+    raft::TxN_t dl_buff[reg_nelem];
+
+    DISTANCE_T dist  = 0;
+    DISTANCE_T norm1 = 0;
+    DISTANCE_T norm2 = 0;
+    if (valid) {
+      for (uint32_t elem_offset = 0; elem_offset < dim; elem_offset += DATASET_BLOCK_DIM) {
+#pragma unroll
+        for (uint32_t e = 0; e < reg_nelem; e++) {
+          const uint32_t k = (lane_id + (TEAM_SIZE * e)) * vlen + elem_offset;
+          if (k >= dim) break;
+          dl_buff[e].load(dataset_ptr, k);
+        }
+#pragma unroll
+        for (uint32_t e = 0; e < reg_nelem; e++) {
+          const uint32_t k = (lane_id + (TEAM_SIZE * e)) * vlen + elem_offset;
+          if (k >= dim) break;
+#pragma unroll
+          for (uint32_t v = 0; v < vlen; v++) {
+            const uint32_t kv = k + v;
+            // Note this loop can go above the dataset_dim for padded arrays. This is not a problem
+            // because:
+            // - Above the last element (dataset_dim-1), the query array is filled with zeros.
+            // - The data buffer has to be also padded with zeros.
+            DISTANCE_T q = query_ptr[device::swizzling(kv)];
+            DISTANCE_T d =
+              cuvs::spatial::knn::detail::utils::mapping{}(dl_buff[e].val.data[v]);
+            dist += dist_op(q, d);
+            norm1 += q * q;
+            norm2 += d * d;
+          }
+        }
+      }
+    }
+    for (uint32_t offset = TEAM_SIZE / 2; offset > 0; offset >>= 1) {
+      dist += __shfl_xor_sync(0xffffffff, dist, offset);
+      norm1 += __shfl_xor_sync(0xffffffff, norm1, offset);
       norm2 += __shfl_xor_sync(0xffffffff, norm2, offset);
     }
-    return norm2;
+
+    // norm1 and norm2 are sums of squares, so the vector lengths are their square roots.
+    // A zero denominator (nothing accumulated, e.g. `valid == false`, or an all-zero vector)
+    // returns 0 instead of the NaN that 0/0 would produce.
+    const DISTANCE_T denom = sqrt(norm1) * sqrt(norm2);
+    return denom > DISTANCE_T{0} ? dist / denom : DISTANCE_T{0};
   }
 };
 
diff --git a/cpp/src/neighbors/ivf_common.cuh b/cpp/src/neighbors/ivf_common.cuh
index 60d43bed6..4194f7c3b 100644
--- a/cpp/src/neighbors/ivf_common.cuh
+++ b/cpp/src/neighbors/ivf_common.cuh
@@ -267,6 +267,19 @@ void postprocess_distances(ScoreOutT* out,  // [n_queries, topk]
         raft::linalg::unaryOp(out, in, len, raft::cast_op{}, stream);
       }
     } break;
+    case distance::DistanceType::CosineExpanded: {
+      float factor = (account_for_max_close ? -1.0 : 1.0);
+      if (factor != 1.0) {
+        raft::linalg::unaryOp(
+          out,
+          in,
+          len,
+          raft::compose_op(raft::mul_const_op{factor}, raft::cast_op{}),
+          stream);
+      } else if (needs_cast || needs_copy) {
+        raft::linalg::unaryOp(out, in, len, raft::cast_op{}, stream);
+      }
+    } break;
     default: RAFT_FAIL("Unexpected metric.");
   }
 }
diff --git a/cpp/test/neighbors/ann_cagra.cuh b/cpp/test/neighbors/ann_cagra.cuh
index 63ac06911..52b6bf2e4 100644
--- a/cpp/test/neighbors/ann_cagra.cuh
+++ b/cpp/test/neighbors/ann_cagra.cuh
@@ -385,7 +385,9 @@ inline std::vector generate_inputs()
     {0},
     {256},
     {1},
-    {cuvs::distance::DistanceType::L2Expanded},
+    {cuvs::distance::DistanceType::L2Expanded,
+     cuvs::distance::DistanceType::InnerProduct,
+     cuvs::distance::DistanceType::CosineExpanded},
     {false},
     {true},
     {0.995});
@@ -401,7 +403,9 @@ inline std::vector generate_inputs()
     {0},
     {64},
     {1},
-    {cuvs::distance::DistanceType::L2Expanded},
+    {cuvs::distance::DistanceType::L2Expanded,
+     cuvs::distance::DistanceType::InnerProduct,
+     cuvs::distance::DistanceType::CosineExpanded},
     {false},
     {true},
     {0.995});
@@ -417,7 +421,9 @@ inline std::vector generate_inputs()
     {0, 4, 8, 16, 32},  // team_size
     {64},
     {1},
-    {cuvs::distance::DistanceType::L2Expanded},
+    {cuvs::distance::DistanceType::L2Expanded,
+     cuvs::distance::DistanceType::InnerProduct,
+     cuvs::distance::DistanceType::CosineExpanded},
     {false},
     {false},
     {0.995});
@@ -434,7 +440,9 @@ inline std::vector generate_inputs()
     {0},  // team_size
     {32, 64, 128, 256, 512, 768},
     {1},
-    {cuvs::distance::DistanceType::L2Expanded},
+    {cuvs::distance::DistanceType::L2Expanded,
+     cuvs::distance::DistanceType::InnerProduct,
+     cuvs::distance::DistanceType::CosineExpanded},
     {false},
     {true},
     {0.995});
@@ -469,7 +477,9 @@ inline std::vector generate_inputs()
     {0},
     {64},
     {1},
-    {cuvs::distance::DistanceType::L2Expanded},
+    {cuvs::distance::DistanceType::L2Expanded,
+     cuvs::distance::DistanceType::InnerProduct,
+     cuvs::distance::DistanceType::CosineExpanded},
     {false},
     {true},
     {0.6});  // don't demand high recall without refinement
@@ -497,7 +507,9 @@ inline std::vector generate_inputs()
     {0},  // team_size
     {64},
     {1},
-    {cuvs::distance::DistanceType::L2Expanded},
+    {cuvs::distance::DistanceType::L2Expanded,
+     cuvs::distance::DistanceType::InnerProduct,
+     cuvs::distance::DistanceType::CosineExpanded},
     {false, true},
     {false},
     {0.99},