Skip to content

Commit

Permalink
Fix cuvs::neighbors::nn_descent::build
Browse files Browse the repository at this point in the history
  • Loading branch information
enp1s0 committed Jan 21, 2025
1 parent 238ff18 commit 6600a33
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 8 deletions.
6 changes: 4 additions & 2 deletions cpp/src/neighbors/nn_descent_float.cu
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ namespace cuvs::neighbors::nn_descent {
} else { \
std::optional<raft::device_matrix_view<float, int64_t, raft::row_major>> distances = \
std::nullopt; \
cuvs::neighbors::nn_descent::index<IdxT> idx{handle, graph.value(), distances}; \
cuvs::neighbors::nn_descent::index<IdxT> idx{ \
handle, graph.value(), distances, params.metric}; \
cuvs::neighbors::nn_descent::build<T, IdxT>(handle, params, dataset, idx); \
return idx; \
}; \
Expand All @@ -47,7 +48,8 @@ namespace cuvs::neighbors::nn_descent {
} else { \
std::optional<raft::device_matrix_view<float, int64_t, raft::row_major>> distances = \
std::nullopt; \
cuvs::neighbors::nn_descent::index<IdxT> idx{handle, graph.value(), distances}; \
cuvs::neighbors::nn_descent::index<IdxT> idx{ \
handle, graph.value(), distances, params.metric}; \
cuvs::neighbors::nn_descent::build<T, IdxT>(handle, params, dataset, idx); \
return idx; \
} \
Expand Down
6 changes: 4 additions & 2 deletions cpp/src/neighbors/nn_descent_half.cu
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ namespace cuvs::neighbors::nn_descent {
} else { \
std::optional<raft::device_matrix_view<float, int64_t, raft::row_major>> distances = \
std::nullopt; \
cuvs::neighbors::nn_descent::index<IdxT> idx{handle, graph.value(), distances}; \
cuvs::neighbors::nn_descent::index<IdxT> idx{ \
handle, graph.value(), distances, params.metric}; \
cuvs::neighbors::nn_descent::build<T, IdxT>(handle, params, dataset, idx); \
return idx; \
} \
Expand All @@ -48,7 +49,8 @@ namespace cuvs::neighbors::nn_descent {
} else { \
std::optional<raft::device_matrix_view<float, int64_t, raft::row_major>> distances = \
std::nullopt; \
cuvs::neighbors::nn_descent::index<IdxT> idx{handle, graph.value(), distances}; \
cuvs::neighbors::nn_descent::index<IdxT> idx{ \
handle, graph.value(), distances, params.metric}; \
cuvs::neighbors::nn_descent::build<T, IdxT>(handle, params, dataset, idx); \
return idx; \
} \
Expand Down
6 changes: 4 additions & 2 deletions cpp/src/neighbors/nn_descent_int8.cu
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ namespace cuvs::neighbors::nn_descent {
} else { \
std::optional<raft::device_matrix_view<float, int64_t, raft::row_major>> distances = \
std::nullopt; \
cuvs::neighbors::nn_descent::index<IdxT> idx{handle, graph.value(), distances}; \
cuvs::neighbors::nn_descent::index<IdxT> idx{ \
handle, graph.value(), distances, params.metric}; \
cuvs::neighbors::nn_descent::build<T, IdxT>(handle, params, dataset, idx); \
return idx; \
} \
Expand All @@ -48,7 +49,8 @@ namespace cuvs::neighbors::nn_descent {
} else { \
std::optional<raft::device_matrix_view<float, int64_t, raft::row_major>> distances = \
std::nullopt; \
cuvs::neighbors::nn_descent::index<IdxT> idx{handle, graph.value(), distances}; \
cuvs::neighbors::nn_descent::index<IdxT> idx{ \
handle, graph.value(), distances, params.metric}; \
cuvs::neighbors::nn_descent::build<T, IdxT>(handle, params, dataset, idx); \
return idx; \
} \
Expand Down
6 changes: 4 additions & 2 deletions cpp/src/neighbors/nn_descent_uint8.cu
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ namespace cuvs::neighbors::nn_descent {
} else { \
std::optional<raft::device_matrix_view<float, int64_t, raft::row_major>> distances = \
std::nullopt; \
cuvs::neighbors::nn_descent::index<IdxT> idx{handle, graph.value(), distances}; \
cuvs::neighbors::nn_descent::index<IdxT> idx{ \
handle, graph.value(), distances, params.metric}; \
cuvs::neighbors::nn_descent::build<T, IdxT>(handle, params, dataset, idx); \
return idx; \
} \
Expand All @@ -48,7 +49,8 @@ namespace cuvs::neighbors::nn_descent {
} else { \
std::optional<raft::device_matrix_view<float, int64_t, raft::row_major>> distances = \
std::nullopt; \
cuvs::neighbors::nn_descent::index<IdxT> idx{handle, graph.value(), distances}; \
cuvs::neighbors::nn_descent::index<IdxT> idx{ \
handle, graph.value(), distances, params.metric}; \
cuvs::neighbors::nn_descent::build<T, IdxT>(handle, params, dataset, idx); \
return idx; \
} \
Expand Down

0 comments on commit 6600a33

Please sign in to comment.