diff --git a/sklearn_numba_dpex/common/topk.py b/sklearn_numba_dpex/common/topk.py
index fb7d984..0b4bd49 100644
--- a/sklearn_numba_dpex/common/topk.py
+++ b/sklearn_numba_dpex/common/topk.py
@@ -137,16 +137,17 @@ def topk(array_in, k, group_sizes=None):
     The output is not deterministic: the order of the output is undefined.
     Successive calls can return the same items in different order.
     """
-    _get_topk_kernel = _make_get_topk_kernel(
+    _initialize_result, _get_topk_kernel = _make_get_topk_kernel(
         k,
         array_in.shape,
         array_in.dtype.type,
         array_in.device.sycl_device,
-        group_sizes,
         output="values",
+        group_sizes=group_sizes,
     )
-    return _get_topk_kernel(array_in)
+    result = _initialize_result()
+    return _get_topk_kernel(array_in, result)
 
 
 def topk_idx(array_in, k, group_sizes=None):
@@ -184,25 +185,36 @@ def topk_idx(array_in, k, group_sizes=None):
     for this value can be different between two successive calls.
 
     """
-    _get_topk_kernel = _make_get_topk_kernel(
+    _initialize_result, _get_topk_kernel = _make_get_topk_kernel(
         k,
         array_in.shape,
         array_in.dtype.type,
         array_in.device.sycl_device,
-        group_sizes,
         output="idx",
+        group_sizes=group_sizes,
    )
+    result = _initialize_result()
+    return _get_topk_kernel(array_in, result)
 
-    return _get_topk_kernel(array_in)
-
 
+@lru_cache
 def _make_get_topk_kernel(
-    k, shape, dtype, device, group_sizes, output, reuse_result_buffer=False
+    k,
+    shape,
+    dtype,
+    device,
+    output,
+    group_sizes=None,
+    return_result_initializer=True,
 ):
     """Returns a `_get_topk_kernel` closure.
 
-    The closure can be passed an array with attributes `shape`, `dtype` and `device`
-    and will perform a TopK search, returning requested top-k items.
+    The closure can be passed an array with attributes `shape`, `dtype` and `device`,
+    along with a result array, and will perform a TopK search, storing the requested
+    top-k items in the result array.
 
     As long as a closure is referenced, it keeps in cache pre-allocated buffers and
     pre-defined kernel functions. Thus, it is more efficient to perform sequential
@@ -213,10 +225,9 @@ def _make_get_topk_kernel(
     instead. They include definition of kernels, allocation of buffers, and cleaning
     of said allocations afterwards.
 
-    By default, the memory allocation for the result array is not reused. This is to
-    avoid a previously computed result to be erased by a subsequent call to the same
-    closure without the user noticing. Reusing the same buffer can still be enforced by
-    setting `reuse_result_buffer=True`.
+    `_make_get_topk_kernel` also returns an initializer closure for the result array.
+    This is optional and can be deactivated by setting the `return_result_initializer`
+    parameter to `False`, in which case the user can pass a preferred buffer instead.
     """
     # TODO: it seems a kernel specialized for 1d arrays would show 10-20% better
     # performance. If this case becomes specifically relevant, consider implementing
@@ -237,10 +248,17 @@ def _make_get_topk_kernel(
         _initialize_result_col_idx,
         gather_results_kernel,
     ) = _get_gather_results_kernels(
-        n_rows, n_cols, k, work_group_size, dtype, device, output, reuse_result_buffer
+        n_rows,
+        n_cols,
+        k,
+        work_group_size,
+        dtype,
+        device,
+        output,
+        return_result_initializer,
     )
 
-    def _get_topk(array_in):
+    def _get_topk(array_in, result, offset=0):
         if is_1d:
             array_in = dpt.reshape(array_in, (1, -1))
 
@@ -248,10 +266,10 @@ def _get_topk(array_in):
             threshold,
             n_threshold_occurences_in_topk,
             n_threshold_occurences_in_data,
-        ) = get_topk_threshold(array_in)
+        ) = get_topk_threshold(array_in, offset)
 
+        # TODO: can be optimized if array_in.shape[0] < n_rows
         result_col_idx = _initialize_result_col_idx()
-        result = _initialize_result(array_in.dtype.type)
 
         gather_results_kernel(
             array_in,
@@ -268,47 +286,36 @@ def _get_topk(array_in):
 
         return result
 
-    return _get_topk
+    return _initialize_result, _get_topk
 
 
 @lru_cache
 def _get_gather_results_kernels(
-    n_rows, n_cols, k, work_group_size, dtype, device, output, reuse_result_buffer
+    n_rows, n_cols, k, work_group_size, dtype, device, output, return_result_initializer
 ):
+    _initialize_result = None
+
     if output == "values":
         gather_results_kernel = _make_gather_topk_kernel(
-            n_rows,
             n_cols,
             k,
             work_group_size,
         )
 
-        if reuse_result_buffer:
-            result = dpt.empty((n_rows, k), dtype=dtype, device=device)
-
-            def _initialize_result(dtype):
-                return result
+        if return_result_initializer:
 
-        else:
-
-            def _initialize_result(dtype):
+            def _initialize_result():
                 return dpt.empty((n_rows, k), dtype=dtype, device=device)
 
     elif output == "idx":
         gather_results_kernel = _make_gather_topk_idx_kernel(
-            n_rows,
             n_cols,
             k,
             work_group_size,
         )
 
-        if reuse_result_buffer:
-            result = dpt.empty((n_rows, k), dtype=np.int64, device=device)
-
-            def _initialize_result(dtype):
-                return result
+        if return_result_initializer:
 
-        else:
-
-            def _initialize_result(dtype):
+            def _initialize_result():
                 return dpt.empty((n_rows, k), dtype=np.int64, device=device)
 
     elif output == "values+idx":
@@ -551,7 +558,7 @@ def initialize_radix_mask():
         fill_value=0, shape=(n_rows,), work_group_size=work_group_size, dtype=np.int64
     )
 
-    def _get_topk_threshold(array_in):
+    def _get_topk_threshold(array_in, offset=0):
         # Use variables that are local to the closure, so it can be manipulated more
         # easily in the main loop
         k_in_subset, n_active_rows, new_n_active_rows, threshold_count = (
@@ -563,7 +570,9 @@ def _get_topk_threshold(array_in):
 
         # Initialize all buffers
         initialize_k_in_subset_kernel(k_in_subset)
-        n_active_rows[0] = n_active_rows_scalar = n_rows
+
+        # TODO: a few things can be optimized if effective_n_rows < n_rows
+        n_active_rows[0] = n_active_rows_scalar = effective_n_rows = array_in.shape[0]
         initialize_threshold_count_kernel(threshold_count)
         active_rows_mapping, new_active_rows_mapping = initialize_active_rows_mapping()
         desired_masked_value = initialize_desired_masked_value()
@@ -571,9 +580,10 @@ def _get_topk_threshold(array_in):
 
         # Reinterpret input as uint so we can use bitwise compute
         array_in_uint = dpt.usm_ndarray(
-            shape=(n_rows, n_cols),
+            shape=(effective_n_rows, n_cols),
             dtype=uint_type,
-            buffer=array_in,
+            buffer=array_in.usm_data,
+            offset=offset * n_cols,
         )
 
         # The main loop: each iteration consists in sorting partially the data on the
@@ -1275,7 +1285,6 @@ def _check_radix_histogram(
 
 @lru_cache
 def _make_gather_topk_kernel(
-    n_rows,
     n_cols,
     k,
     work_group_size,
@@ -1288,7 +1297,6 @@ def _make_gather_topk_kernel(
     """
     n_work_groups_per_row = math.ceil(n_cols / work_group_size)
     work_group_shape = (1, work_group_size)
-    global_shape = (n_rows, n_work_groups_per_row * work_group_size)
 
     @dpex.kernel
     # fmt: off
@@ -1297,7 +1305,7 @@ def gather_topk(
         threshold,                          # IN             (n_rows,)
         n_threshold_occurences_in_topk,     # IN             (n_rows,)
         n_threshold_occurences_in_data,     # IN             (n_rows,)
-        result_col_idx,                    # BUFFER          (n_rows,)
+        result_col_idx,                     # BUFFER         (n_rows,)
         result,                             # OUT            (n_rows, k)
     ):
         # fmt: on
@@ -1393,12 +1401,31 @@ def gather_topk_generic(
         result_col_idx_ = dpex.atomic.add(result_col_idx, row_idx, count_one_as_an_int)
         result[row_idx, result_col_idx_] = item
 
-    return gather_topk[global_shape, work_group_shape]
+    # TODO: write decorator instead
+    def _gather_topk(
+        array_in,
+        threshold,
+        n_threshold_occurences_in_topk,
+        n_threshold_occurences_in_data,
+        result_col_idx,
+        result,
+    ):
+        n_rows = array_in.shape[0]
+        global_shape = (n_rows, n_work_groups_per_row * work_group_size)
+        gather_topk[global_shape, work_group_shape](
+            array_in,
+            threshold,
+            n_threshold_occurences_in_topk,
+            n_threshold_occurences_in_data,
+            result_col_idx,
+            result,
+        )
+
+    return _gather_topk
 
 
 @lru_cache
 def _make_gather_topk_idx_kernel(
-    n_rows,
     n_cols,
     k,
     work_group_size,
@@ -1406,7 +1433,6 @@ def _make_gather_topk_idx_kernel(
     """Same than gather_topk kernel but return top-k indices rather than top-k values"""
     n_work_groups_per_row = math.ceil(n_cols / work_group_size)
     work_group_shape = (1, work_group_size)
-    global_shape = (n_rows, n_work_groups_per_row * work_group_size)
 
     @dpex.kernel
     # fmt: off
@@ -1512,4 +1538,24 @@ def gather_topk_idx_generic(
         )
         result[row_idx, result_col_idx_] = col_idx
 
-    return gather_topk_idx[global_shape, work_group_shape]
+    # TODO: write decorator instead
+    def _gather_topk_idx(
+        array_in,
+        threshold,
+        n_threshold_occurences_in_topk,
+        n_threshold_occurences_in_data,
+        result_col_idx,
+        result,
+    ):
+        n_rows = array_in.shape[0]
+        global_shape = (n_rows, n_work_groups_per_row * work_group_size)
+        gather_topk_idx[global_shape, work_group_shape](
+            array_in,
+            threshold,
+            n_threshold_occurences_in_topk,
+            n_threshold_occurences_in_data,
+            result_col_idx,
+            result,
+        )
+
+    return _gather_topk_idx
diff --git a/sklearn_numba_dpex/knn/__init__.py b/sklearn_numba_dpex/knn/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/sklearn_numba_dpex/knn/drivers.py b/sklearn_numba_dpex/knn/drivers.py
new file mode 100644
index 0000000..366b533
--- /dev/null
+++ b/sklearn_numba_dpex/knn/drivers.py
@@ -0,0 +1,103 @@
+import math
+
+import dpctl.tensor as dpt
+import numpy as np
+
+from sklearn_numba_dpex.common.matmul import make_matmul_2d_kernel
+from sklearn_numba_dpex.common.topk import _make_get_topk_kernel
+
+
+def kneighbors(
+    query,
+    data,
+    n_neighbors,
+    metric="euclidean",
+    return_distance=False,
+    maximum_compute_buffer_size=1073741824,  # 1 GiB
+):
+    n_queries, n_features = query.shape
+    n_samples = data.shape[0]
+    compute_dtype = query.dtype.type
+    compute_dtype_itemsize = np.dtype(compute_dtype).itemsize
+    device = query.device.sycl_device
+
+    index_itemsize = np.dtype(np.int64).itemsize
+
+    pairwise_distance_required_bytes_per_query = n_samples * compute_dtype_itemsize
+
+    # TODO: better way to ensure this remains synchronized with future changes to
+    # TopK?
+    topk_required_bytes_per_query = (
+        (n_neighbors + 1 + 1 + 1) * compute_dtype_itemsize
+    ) + ((1 + 1 + 2 + 4) * index_itemsize)
+    if return_distance:
+        topk_required_bytes_per_query += n_neighbors * index_itemsize
+
+    total_required_bytes_per_query = (
+        pairwise_distance_required_bytes_per_query + topk_required_bytes_per_query
+    )
+
+    max_slice_size = maximum_compute_buffer_size / total_required_bytes_per_query
+
+    if max_slice_size < 1:
+        raise RuntimeError(
+            "maximum_compute_buffer_size is too small to process even a single query."
+        )
+
+    slice_size = min(math.floor(max_slice_size), n_queries)
+
+    n_slices = math.ceil(n_queries / slice_size)
+    n_full_slices = n_slices - 1
+    last_slice_size = ((n_queries - 1) % slice_size) + 1
+
+    def pairwise_distance_multiply_fn(x, y):
+        diff = x - y
+        return diff * diff
+
+    def negative_value(x):
+        return -x
+
+    squared_pairwise_distance_kernel = make_matmul_2d_kernel(
+        slice_size,
+        n_samples,
+        n_features,
+        compute_dtype,
+        device,
+        multiply_fn=pairwise_distance_multiply_fn,
+        out_fused_elementwise_fn=negative_value,
+    )
+    squared_pairwise_distance_buffer = dpt.empty(
+        (slice_size, n_samples), dtype=compute_dtype, device=device
+    )
+
+    _, get_topk_kernel = _make_get_topk_kernel(
+        n_neighbors, (slice_size, n_samples), compute_dtype, device, output="idx"
+    )
+
+    result = dpt.empty((n_queries, n_neighbors), dtype=np.int64, device=device)
+
+    slice_sample_idx = 0
+    for _ in range(n_full_slices):
+        query_slice = query[slice_sample_idx : (slice_sample_idx + slice_size)]
+        squared_pairwise_distance_kernel(
+            query_slice, data, squared_pairwise_distance_buffer
+        )
+
+        result_slice = result[slice_sample_idx : (slice_sample_idx + slice_size)]
+        get_topk_kernel(squared_pairwise_distance_buffer, result_slice)
+
+        slice_sample_idx += slice_size
+
+    # NB: some pairwise distances are computed twice for no reason, but it is cheaper
+    # than re-compiling a kernel specifically for the last slice.
+    query_slice = query[-slice_size:]
+    squared_pairwise_distance_kernel(
+        query_slice, data, squared_pairwise_distance_buffer
+    )
+    result_slice = result[-last_slice_size:]
+
+    get_topk_kernel(
+        squared_pairwise_distance_buffer[-last_slice_size:],
+        result_slice,
+        offset=slice_size - last_slice_size,
+    )
+
+    return result
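A usage sketch (not part of the diff): with this refactor, `_make_get_topk_kernel` returns a `(_initialize_result, _get_topk_kernel)` pair instead of a single closure, so a caller can either allocate the result through the returned initializer or pass its own pre-allocated buffer, which is what `kneighbors` does with row slices of its output array. The snippet below follows the signatures in this diff; the input values are made up.

import dpctl.tensor as dpt
import numpy as np

from sklearn_numba_dpex.common.topk import _make_get_topk_kernel

# Made-up input: 4 rows of 1000 float32 values each.
array_in = dpt.asarray(np.random.rand(4, 1000).astype(np.float32))

# Build the closures once; @lru_cache reuses them across calls with the
# same arguments.
_initialize_result, _get_topk_kernel = _make_get_topk_kernel(
    10,                           # k
    array_in.shape,
    array_in.dtype.type,
    array_in.device.sycl_device,
    output="idx",
)

# Either allocate the (n_rows, k) result with the returned initializer...
result = _initialize_result()
_get_topk_kernel(array_in, result)

# ...or write into a caller-managed buffer, e.g. a row slice of a larger array.
own_buffer = dpt.empty((4, 10), dtype=np.int64, device=array_in.device)
_get_topk_kernel(array_in, own_buffer)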
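A standalone sketch of the `offset` mechanism (not part of the diff): `_get_topk_threshold` rebuilds a view over the input's USM allocation at an element offset. The shapes below are made up, and the example assumes, as the diff does, that `dpctl.tensor.usm_ndarray` counts `offset` in elements of the requested dtype rather than in bytes.

import dpctl.tensor as dpt
import numpy as np

# Made-up parent buffer: a (4, 5) float32 array.
array_in = dpt.zeros((4, 5), dtype=np.float32)

offset, n_cols = 2, 5  # skip the first two rows

# View the last two rows, reinterpreted bitwise as uint32 (same itemsize as
# float32), like the `array_in_uint` construction in this diff.
array_in_uint = dpt.usm_ndarray(
    shape=(2, n_cols),
    dtype=np.uint32,
    buffer=array_in.usm_data,
    offset=offset * n_cols,
)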
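A worked example of the slicing arithmetic in `kneighbors` (not part of the diff), with made-up workload sizes: each query in a slice needs one `(n_samples,)` row of the pairwise-distance buffer plus the per-query TopK footprint, and `slice_size` is chosen so the total fits in `maximum_compute_buffer_size`.

import math

# Made-up workload: 100,000 float32 queries against 50,000 samples, k = 5.
n_queries, n_samples, n_neighbors = 100_000, 50_000, 5
compute_dtype_itemsize, index_itemsize = 4, 8
maximum_compute_buffer_size = 1073741824  # 1 GiB

# One (n_samples,) row of the distance buffer per query in the slice...
pairwise_bytes = n_samples * compute_dtype_itemsize  # 200_000 bytes
# ...plus the per-query TopK buffers (same formula as in the driver).
topk_bytes = ((n_neighbors + 1 + 1 + 1) * compute_dtype_itemsize) + (
    (1 + 1 + 2 + 4) * index_itemsize
)  # 96 bytes

max_slice_size = maximum_compute_buffer_size / (pairwise_bytes + topk_bytes)
slice_size = min(math.floor(max_slice_size), n_queries)  # 5366 queries per slice

n_slices = math.ceil(n_queries / slice_size)  # 19
n_full_slices = n_slices - 1  # 18 full slices...
# ...and a last slice whose size is mapped into [1, slice_size].
last_slice_size = ((n_queries - 1) % slice_size) + 1  # 3412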