From 33698a56cfe02693e90fb4774dc949331fbc2924 Mon Sep 17 00:00:00 2001 From: rhdong Date: Mon, 29 Jul 2024 13:39:07 -0700 Subject: [PATCH] [Opt] introduce the `masked_matmul` to prefiltered brute force. (#251) Authors: - rhdong (https://github.com/rhdong) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/cuvs/pull/251 --- cpp/src/neighbors/detail/knn_brute_force.cuh | 39 ++++---------------- 1 file changed, 8 insertions(+), 31 deletions(-) diff --git a/cpp/src/neighbors/detail/knn_brute_force.cuh b/cpp/src/neighbors/detail/knn_brute_force.cuh index f05bebf3f..559d33cc2 100644 --- a/cpp/src/neighbors/detail/knn_brute_force.cuh +++ b/cpp/src/neighbors/detail/knn_brute_force.cuh @@ -42,7 +42,7 @@ #include #include #include -#include +#include #include #include #include @@ -636,36 +636,13 @@ void brute_force_search_filtered( rows.data(), compressed_csr_view.get_nnz(), stream); - if (n_queries > 10) { - auto csr_view = raft::make_device_csr_matrix_view( - csr.get_elements().data(), compressed_csr_view); - - // create dataset view - auto dataset_view = raft::make_device_matrix_view( - idx.dataset().data_handle(), dim, n_dataset); - - // calc dot - T alpha = static_cast(1.0f); - T beta = static_cast(0.0f); - raft::sparse::linalg::sddmm(res, - queries, - dataset_view, - csr_view, - raft::linalg::Operation::NON_TRANSPOSE, - raft::linalg::Operation::NON_TRANSPOSE, - raft::make_host_scalar_view(&alpha), - raft::make_host_scalar_view(&beta)); - } else { - raft::sparse::distance::detail::faster_dot_on_csr(res, - csr.get_elements().data(), - compressed_csr_view.get_nnz(), - compressed_csr_view.get_indptr().data(), - compressed_csr_view.get_indices().data(), - queries.data_handle(), - idx.dataset().data_handle(), - compressed_csr_view.get_n_rows(), - dim); - } + auto dataset_view = raft::make_device_matrix_view( + idx.dataset().data_handle(), n_dataset, dim); + + auto csr_view = raft::make_device_csr_matrix_view( + csr.get_elements().data(), compressed_csr_view); + + raft::sparse::linalg::masked_matmul(res, queries, dataset_view, filter, csr_view); // post process std::optional> query_norms_;