Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CosineExpanded Distance Metric for CAGRA #197

Draft
wants to merge 4 commits into
base: branch-24.12
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions cpp/src/neighbors/detail/cagra/cagra_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,9 @@ void build_knn_graph(
cuvs::neighbors::cagra::graph_build_params::ivf_pq_params pq)
{
RAFT_EXPECTS(pq.build_params.metric == cuvs::distance::DistanceType::L2Expanded ||
pq.build_params.metric == cuvs::distance::DistanceType::InnerProduct,
"Currently only L2Expanded or InnerProduct metric are supported");
pq.build_params.metric == cuvs::distance::DistanceType::InnerProduct ||
pq.build_params.metric == cuvs::distance::DistanceType::CosineExpanded,
"Currently only L2Expanded, InnerProduct and CosineExpanded metrics are supported");

uint32_t node_degree = knn_graph.extent(1);
raft::common::nvtx::range<raft::common::nvtx::domain::raft> fun_scope(
Expand Down
84 changes: 78 additions & 6 deletions cpp/src/neighbors/detail/cagra/compute_distance.hpp
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @tarang-jain, a heads up here: #296 does a major refactoring of related code; let's have a look together how we can proceed with this PR once you're back to it, ok?
I have similar performance concerns as the ones we discussed on IVF-PQ; maybe it makes sense to keep the dataset normalized for cosine distance (and reuse the inner-product code path)?
Then we can either normalize the query at the time we copy it to the shared memory (pre-processing) or divide by the query norm at the post-processing/filtering step at the end of the kernel.

Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,12 @@ _RAFT_DEVICE void compute_distance_to_random_nodes(
TEAM_SIZE,
cuvs::distance::DistanceType::InnerProduct>(
query_buffer, seed_index, valid_i);
case cuvs::distance::DistanceType::CosineExpanded:
norm2 =
dataset_desc.template compute_similarity<DATASET_BLOCK_DIM,
TEAM_SIZE,
cuvs::distance::DistanceType::CosineExpanded>(
query_buffer, seed_index, valid_i);
break;
default: break;
}
Expand Down Expand Up @@ -191,6 +197,13 @@ _RAFT_DEVICE void compute_distance_to_child_nodes(
cuvs::distance::DistanceType::InnerProduct>(
query_buffer, child_id, child_id != invalid_index);
break;
case cuvs::distance::DistanceType::CosineExpanded:
norm2 =
dataset_desc.template compute_similarity<DATASET_BLOCK_DIM,
TEAM_SIZE,
cuvs::distance::DistanceType::CosineExpanded>(
query_buffer, child_id, child_id != invalid_index);
break;
default: break;
}

Expand Down Expand Up @@ -275,9 +288,12 @@ struct standard_dataset_descriptor_t
}

template <uint32_t DATASET_BLOCK_DIM, uint32_t TEAM_SIZE, cuvs::distance::DistanceType METRIC>
__device__ DISTANCE_T compute_similarity(const QUERY_T* const query_ptr,
const INDEX_T dataset_i,
const bool valid) const
std::enable_if_t<METRIC == cuvs::distance::DistanceType::L2Expanded ||
METRIC == cuvs::distance::DistanceType::InnerProduct,
DISTANCE_T>
__device__ compute_similarity(const QUERY_T* const query_ptr,
const INDEX_T dataset_i,
const bool valid) const
{
const auto dataset_ptr = ptr + dataset_i * ld;
const unsigned lane_id = threadIdx.x % TEAM_SIZE;
Expand All @@ -286,7 +302,8 @@ struct standard_dataset_descriptor_t
constexpr unsigned reg_nelem = raft::ceildiv<unsigned>(DATASET_BLOCK_DIM, TEAM_SIZE * vlen);
raft::TxN_t<DATA_T, vlen> dl_buff[reg_nelem];

DISTANCE_T norm2 = 0;
DISTANCE_T dist = 0;

if (valid) {
for (uint32_t elem_offset = 0; elem_offset < dim; elem_offset += DATASET_BLOCK_DIM) {
#pragma unroll
Expand All @@ -307,16 +324,71 @@ struct standard_dataset_descriptor_t
// - Above the last element (dataset_dim-1), the query array is filled with zeros.
// - The data buffer has to be also padded with zeros.
DISTANCE_T d = query_ptr[device::swizzling(kv)];
norm2 += dist_op<DISTANCE_T, METRIC>(
dist += dist_op<DISTANCE_T, METRIC>(
d, cuvs::spatial::knn::detail::utils::mapping<float>{}(dl_buff[e].val.data[v]));
}
}
}
}
for (uint32_t offset = TEAM_SIZE / 2; offset > 0; offset >>= 1) {
dist += __shfl_xor_sync(0xffffffff, dist, offset);
}

return dist;
}

template <uint32_t DATASET_BLOCK_DIM, uint32_t TEAM_SIZE, cuvs::distance::DistanceType METRIC>
std::enable_if_t<METRIC == cuvs::distance::DistanceType::CosineExpanded, DISTANCE_T> __device__
compute_similarity(const QUERY_T* const query_ptr,
const INDEX_T dataset_i,
const bool valid) const
{
const auto dataset_ptr = ptr + dataset_i * ld;
const unsigned lane_id = threadIdx.x % TEAM_SIZE;
constexpr unsigned vlen = device::get_vlen<LOAD_T, DATA_T>();
// #include <raft/util/cuda_dev_essentials.cuh
constexpr unsigned reg_nelem = raft::ceildiv<unsigned>(DATASET_BLOCK_DIM, TEAM_SIZE * vlen);
raft::TxN_t<DATA_T, vlen> dl_buff[reg_nelem];

DISTANCE_T dist = 0;
DISTANCE_T norm1 = 0;
DISTANCE_T norm2 = 0;
if (valid) {
for (uint32_t elem_offset = 0; elem_offset < dim; elem_offset += DATASET_BLOCK_DIM) {
#pragma unroll
for (uint32_t e = 0; e < reg_nelem; e++) {
const uint32_t k = (lane_id + (TEAM_SIZE * e)) * vlen + elem_offset;
if (k >= dim) break;
dl_buff[e].load(dataset_ptr, k);
}
#pragma unroll
for (uint32_t e = 0; e < reg_nelem; e++) {
const uint32_t k = (lane_id + (TEAM_SIZE * e)) * vlen + elem_offset;
if (k >= dim) break;
#pragma unroll
for (uint32_t v = 0; v < vlen; v++) {
const uint32_t kv = k + v;
// Note this loop can go above the dataset_dim for padded arrays. This is not a problem
// because:
// - Above the last element (dataset_dim-1), the query array is filled with zeros.
// - The data buffer has to be also padded with zeros.
DISTANCE_T q = query_ptr[device::swizzling(kv)];
DISTANCE_T d =
cuvs::spatial::knn::detail::utils::mapping<float>{}(dl_buff[e].val.data[v]);
dist += dist_op<DISTANCE_T, cuvs::distance::DistanceType::InnerProduct>(q, d);
norm1 += q * q;
norm2 += d * d;
}
}
}
}
for (uint32_t offset = TEAM_SIZE / 2; offset > 0; offset >>= 1) {
dist += __shfl_xor_sync(0xffffffff, dist, offset);
norm1 += __shfl_xor_sync(0xffffffff, norm1, offset);
norm2 += __shfl_xor_sync(0xffffffff, norm2, offset);
}
return norm2;

return dist / (norm1 * norm2);
}
};

Expand Down
13 changes: 13 additions & 0 deletions cpp/src/neighbors/ivf_common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,19 @@ void postprocess_distances(ScoreOutT* out, // [n_queries, topk]
raft::linalg::unaryOp(out, in, len, raft::cast_op<ScoreOutT>{}, stream);
}
} break;
case distance::DistanceType::CosineExpanded: {
float factor = (account_for_max_close ? -1.0 : 1.0);
if (factor != 1.0) {
raft::linalg::unaryOp(
out,
in,
len,
raft::compose_op(raft::mul_const_op<ScoreOutT>{factor}, raft::cast_op<ScoreOutT>{}),
stream);
} else if (needs_cast || needs_copy) {
raft::linalg::unaryOp(out, in, len, raft::cast_op<ScoreOutT>{}, stream);
}
} break;
default: RAFT_FAIL("Unexpected metric.");
}
}
Expand Down
24 changes: 18 additions & 6 deletions cpp/test/neighbors/ann_cagra.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,9 @@ inline std::vector<AnnCagraInputs> generate_inputs()
{0},
{256},
{1},
{cuvs::distance::DistanceType::L2Expanded},
{cuvs::distance::DistanceType::L2Expanded,
cuvs::distance::DistanceType::InnerProduct,
cuvs::distance::DistanceType::CosineExpanded},
{false},
{true},
{0.995});
Expand All @@ -401,7 +403,9 @@ inline std::vector<AnnCagraInputs> generate_inputs()
{0},
{64},
{1},
{cuvs::distance::DistanceType::L2Expanded},
{cuvs::distance::DistanceType::L2Expanded,
cuvs::distance::DistanceType::InnerProduct,
cuvs::distance::DistanceType::CosineExpanded},
{false},
{true},
{0.995});
Expand All @@ -417,7 +421,9 @@ inline std::vector<AnnCagraInputs> generate_inputs()
{0, 4, 8, 16, 32}, // team_size
{64},
{1},
{cuvs::distance::DistanceType::L2Expanded},
{cuvs::distance::DistanceType::L2Expanded,
cuvs::distance::DistanceType::InnerProduct,
cuvs::distance::DistanceType::CosineExpanded},
{false},
{false},
{0.995});
Expand All @@ -434,7 +440,9 @@ inline std::vector<AnnCagraInputs> generate_inputs()
{0}, // team_size
{32, 64, 128, 256, 512, 768},
{1},
{cuvs::distance::DistanceType::L2Expanded},
{cuvs::distance::DistanceType::L2Expanded,
cuvs::distance::DistanceType::InnerProduct,
cuvs::distance::DistanceType::CosineExpanded},
{false},
{true},
{0.995});
Expand Down Expand Up @@ -469,7 +477,9 @@ inline std::vector<AnnCagraInputs> generate_inputs()
{0},
{64},
{1},
{cuvs::distance::DistanceType::L2Expanded},
{cuvs::distance::DistanceType::L2Expanded,
cuvs::distance::DistanceType::InnerProduct,
cuvs::distance::DistanceType::CosineExpanded},
{false},
{true},
{0.6}); // don't demand high recall without refinement
Expand Down Expand Up @@ -497,7 +507,9 @@ inline std::vector<AnnCagraInputs> generate_inputs()
{0}, // team_size
{64},
{1},
{cuvs::distance::DistanceType::L2Expanded},
{cuvs::distance::DistanceType::L2Expanded,
cuvs::distance::DistanceType::InnerProduct,
cuvs::distance::DistanceType::CosineExpanded},
{false, true},
{false},
{0.99},
Expand Down
Loading