Skip to content

Commit

Permalink
expose col-major bfknn to python (#575)
Browse files Browse the repository at this point in the history
Follow-on to #572.

Authors:
  - Ben Frederickson (https://github.com/benfred)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #575
  • Loading branch information
benfred authored Jan 16, 2025
1 parent 47d71c3 commit c49ba7b
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 15 deletions.
2 changes: 1 addition & 1 deletion cpp/src/distance/pairwise_distance_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ extern "C" cuvsError_t cuvsPairwiseDistance(cuvsResources_t res,

if ((x_row_major != y_row_major) || (x_row_major != distances_row_major)) {
RAFT_FAIL(
"Inputs to cuvsPairwiseDistance must all have the same layout (row-major or col-major");
"Inputs to cuvsPairwiseDistance must all have the same layout (row-major or col-major)");
}

if (x_row_major) {
Expand Down
28 changes: 21 additions & 7 deletions cpp/src/neighbors/brute_force_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,15 @@

namespace {

template <typename T>
template <typename T, typename LayoutT = raft::row_major>
void* _build(cuvsResources_t res,
DLManagedTensor* dataset_tensor,
cuvsDistanceType metric,
T metric_arg)
{
auto res_ptr = reinterpret_cast<raft::resources*>(res);

using mdspan_type = raft::device_matrix_view<T const, int64_t, raft::row_major>;
using mdspan_type = raft::device_matrix_view<T const, int64_t, LayoutT>;
auto mds = cuvs::core::from_dlpack<mdspan_type>(dataset_tensor);

cuvs::neighbors::brute_force::index_params params;
Expand All @@ -53,7 +53,7 @@ void* _build(cuvsResources_t res,
return index_on_heap;
}

template <typename T>
template <typename T, typename QueriesLayoutT = raft::row_major>
void _search(cuvsResources_t res,
cuvsBruteForceIndex index,
DLManagedTensor* queries_tensor,
Expand All @@ -64,7 +64,7 @@ void _search(cuvsResources_t res,
auto res_ptr = reinterpret_cast<raft::resources*>(res);
auto index_ptr = reinterpret_cast<cuvs::neighbors::brute_force::index<T>*>(index.addr);

using queries_mdspan_type = raft::device_matrix_view<T const, int64_t, raft::row_major>;
using queries_mdspan_type = raft::device_matrix_view<T const, int64_t, QueriesLayoutT>;
using neighbors_mdspan_type = raft::device_matrix_view<int64_t, int64_t, raft::row_major>;
using distances_mdspan_type = raft::device_matrix_view<float, int64_t, raft::row_major>;
using prefilter_mds_type = raft::device_vector_view<const uint32_t, int64_t>;
Expand Down Expand Up @@ -150,8 +150,15 @@ extern "C" cuvsError_t cuvsBruteForceBuild(cuvsResources_t res,
auto dataset = dataset_tensor->dl_tensor;

if (dataset.dtype.code == kDLFloat && dataset.dtype.bits == 32) {
index->addr =
reinterpret_cast<uintptr_t>(_build<float>(res, dataset_tensor, metric, metric_arg));
if (cuvs::core::is_c_contiguous(dataset_tensor)) {
index->addr =
reinterpret_cast<uintptr_t>(_build<float>(res, dataset_tensor, metric, metric_arg));
} else if (cuvs::core::is_f_contiguous(dataset_tensor)) {
index->addr = reinterpret_cast<uintptr_t>(
_build<float, raft::col_major>(res, dataset_tensor, metric, metric_arg));
} else {
RAFT_FAIL("dataset input to cuvsBruteForceBuild must be contiguous (non-strided)");
}
index->dtype = dataset.dtype;
} else {
RAFT_FAIL("Unsupported dataset DLtensor dtype: %d and bits: %d",
Expand Down Expand Up @@ -189,7 +196,14 @@ extern "C" cuvsError_t cuvsBruteForceSearch(cuvsResources_t res,
RAFT_EXPECTS(queries.dtype.code == index.dtype.code, "type mismatch between index and queries");

if (queries.dtype.code == kDLFloat && queries.dtype.bits == 32) {
_search<float>(res, index, queries_tensor, neighbors_tensor, distances_tensor, prefilter);
if (cuvs::core::is_c_contiguous(queries_tensor)) {
_search<float>(res, index, queries_tensor, neighbors_tensor, distances_tensor, prefilter);
} else if (cuvs::core::is_f_contiguous(queries_tensor)) {
_search<float, raft::col_major>(
res, index, queries_tensor, neighbors_tensor, distances_tensor, prefilter);
} else {
RAFT_FAIL("queries input to cuvsBruteForceSearch must be contiguous (non-strided)");
}
} else {
RAFT_FAIL("Unsupported queries DLtensor dtype: %d and bits: %d",
queries.dtype.code,
Expand Down
4 changes: 2 additions & 2 deletions python/cuvs/cuvs/neighbors/brute_force/brute_force.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def build(dataset, metric="sqeuclidean", metric_arg=2.0, resources=None):
"""

dataset_ai = wrap_array(dataset)
_check_input_array(dataset_ai, [np.dtype('float32')])
_check_input_array(dataset_ai, [np.dtype('float32')], exp_row_major=False)

cdef cuvsResources_t res = <cuvsResources_t>resources.get_c_obj()

Expand Down Expand Up @@ -218,7 +218,7 @@ def search(Index index,
cdef cuvsResources_t res = <cuvsResources_t>resources.get_c_obj()

queries_cai = wrap_array(queries)
_check_input_array(queries_cai, [np.dtype('float32')])
_check_input_array(queries_cai, [np.dtype('float32')], exp_row_major=False)

cdef uint32_t n_queries = queries_cai.shape[0]

Expand Down
6 changes: 4 additions & 2 deletions python/cuvs/cuvs/neighbors/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@
# limitations under the License.


def _check_input_array(cai, exp_dt, exp_rows=None, exp_cols=None):
def _check_input_array(
cai, exp_dt, exp_rows=None, exp_cols=None, exp_row_major=True
):
if cai.dtype not in exp_dt:
raise TypeError("dtype %s not supported" % cai.dtype)

if not cai.c_contiguous:
if exp_row_major and not cai.c_contiguous:
raise ValueError("Row major input is expected")

if exp_cols is not None and cai.shape[1] != exp_cols:
Expand Down
9 changes: 6 additions & 3 deletions python/cuvs/cuvs/test/test_brute_force.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,15 @@
],
)
@pytest.mark.parametrize("inplace", [True, False])
@pytest.mark.parametrize("order", ["F", "C"])
@pytest.mark.parametrize("dtype", [np.float32])
def test_brute_force_knn(
n_index_rows, n_query_rows, n_cols, k, inplace, metric, dtype
n_index_rows, n_query_rows, n_cols, k, inplace, order, metric, dtype
):
index = np.random.random_sample((n_index_rows, n_cols)).astype(dtype)
queries = np.random.random_sample((n_query_rows, n_cols)).astype(dtype)
index = np.random.random_sample((n_index_rows, n_cols))
index = np.asarray(index, order=order).astype(dtype)
queries = np.random.random_sample((n_query_rows, n_cols))
queries = np.asarray(queries, order=order).astype(dtype)

# RussellRao expects boolean arrays
if metric == "russellrao":
Expand Down

0 comments on commit c49ba7b

Please sign in to comment.