FEA kneighbors driver #111

Draft
wants to merge 7 commits into base: enh/reuse_topk_buffers
140 changes: 93 additions & 47 deletions sklearn_numba_dpex/common/topk.py
@@ -137,16 +137,17 @@ def topk(array_in, k, group_sizes=None):
The output is not deterministic: the order of the output is undefined. Successive
calls can return the same items in a different order.
"""
_get_topk_kernel = _make_get_topk_kernel(
_initialize_result, _get_topk_kernel = _make_get_topk_kernel(
k,
array_in.shape,
array_in.dtype.type,
array_in.device.sycl_device,
group_sizes,
output="values",
group_sizes=group_sizes,
)

return _get_topk_kernel(array_in)
result = _initialize_result()
return _get_topk_kernel(array_in, result)


def topk_idx(array_in, k, group_sizes=None):
@@ -184,25 +185,36 @@ def topk_idx(array_in, k, group_sizes=None):
for this value can be different between two successive calls.

"""
_get_topk_kernel = _make_get_topk_kernel(
_initialize_result, _get_topk_kernel = _make_get_topk_kernel(
k,
array_in.shape,
array_in.dtype.type,
array_in.device.sycl_device,
group_sizes,
output="idx",
group_sizes=group_sizes,
)
result = _initialize_result()
return _get_topk_kernel(array_in, result)

return _get_topk_kernel(array_in)



@lru_cache
def _make_get_topk_kernel(
k, shape, dtype, device, group_sizes, output, reuse_result_buffer=False
k,
shape,
dtype,
device,
output,
group_sizes=None,
return_result_initializer=True,
):
"""Returns a `_get_topk_kernel` closure.

The closure can be passed an array with attributes `shape`, `dtype` and `device`
and will perform a TopK search, returning requested top-k items.
The closure can be passed an array with attributes `shape`, `dtype` and `device`,
along with a result array, and will perform a TopK search, storing the requested
top-k items in the result array.

As long as a closure is referenced, it keeps in cache pre-allocated buffers and
pre-defined kernel functions. Thus, it is more efficient to perform sequential
@@ -213,10 +225,9 @@ def _make_get_topk_kernel(
instead. They include definition of kernels, allocation of buffers, and
cleaning of said allocations afterwards.

By default, the memory allocation for the result array is not reused. This is to
avoid a previously computed result to be erased by a subsequent call to the same
closure without the user noticing. Reusing the same buffer can still be enforced by
setting `reuse_result_buffer=True`.
`_make_get_topk_kernel` also returns an initializer closure for the result array.
This is optional and can be deactivated by setting the `return_result_initializer`
parameter to `False`, in which case the user can pass a preferred buffer instead.
"""
# TODO: it seems a kernel specialized for 1d arrays would show 10-20% better
# performance. If this case becomes specifically relevant, consider implementing
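For reference, here is a minimal usage sketch of the two-closure pattern introduced in this hunk, mirroring the updated `topk` helper above (the input array and its sizes are hypothetical): the initializer allocates the result buffer once, and the kernel closure can then be called repeatedly against that buffer.

import dpctl.tensor as dpt
import numpy as np

from sklearn_numba_dpex.common.topk import _make_get_topk_kernel

array_in = dpt.asarray(np.random.rand(4, 1000).astype(np.float32))

_initialize_result, _get_topk_kernel = _make_get_topk_kernel(
    3,
    array_in.shape,
    array_in.dtype.type,
    array_in.device.sycl_device,
    output="values",
)

# Allocate the result buffer once, then reuse it across successive calls.
result = _initialize_result()
top3_values = _get_topk_kernel(array_in, result)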
@@ -237,21 +248,28 @@ def _make_get_topk_kernel(
_initialize_result_col_idx,
gather_results_kernel,
) = _get_gather_results_kernels(
n_rows, n_cols, k, work_group_size, dtype, device, output, reuse_result_buffer
n_rows,
n_cols,
k,
work_group_size,
dtype,
device,
output,
return_result_initializer,
)

def _get_topk(array_in):
def _get_topk(array_in, result, offset=0):
if is_1d:
array_in = dpt.reshape(array_in, (1, -1))

(
threshold,
n_threshold_occurences_in_topk,
n_threshold_occurences_in_data,
) = get_topk_threshold(array_in)
) = get_topk_threshold(array_in, offset)

# TODO: can be optimized if array_in.shape[0] < n_rows
result_col_idx = _initialize_result_col_idx()
result = _initialize_result(array_in.dtype.type)

gather_results_kernel(
array_in,
@@ -268,47 +286,36 @@ def _get_topk(array_in):

return result

return _get_topk
return _initialize_result, _get_topk


@lru_cache
def _get_gather_results_kernels(
n_rows, n_cols, k, work_group_size, dtype, device, output, reuse_result_buffer
n_rows, n_cols, k, work_group_size, dtype, device, output, return_result_initializer
):
_initialize_result = None

if output == "values":
gather_results_kernel = _make_gather_topk_kernel(
n_rows,
n_cols,
k,
work_group_size,
)
if reuse_result_buffer:
result = dpt.empty((n_rows, k), dtype=dtype, device=device)

def _initialize_result(dtype):
return result
if return_result_initializer:

else:

def _initialize_result(dtype):
def _initialize_result():
return dpt.empty((n_rows, k), dtype=dtype, device=device)

elif output == "idx":
gather_results_kernel = _make_gather_topk_idx_kernel(
n_rows,
n_cols,
k,
work_group_size,
)
if reuse_result_buffer:
result = dpt.empty((n_rows, k), dtype=np.int64, device=device)

def _initialize_result(dtype):
return result
if return_result_initializer:

else:

def _initialize_result(dtype):
def _initialize_result():
return dpt.empty((n_rows, k), dtype=np.int64, device=device)

elif output == "values+idx":
@@ -551,7 +558,7 @@ def initialize_radix_mask():
fill_value=0, shape=(n_rows,), work_group_size=work_group_size, dtype=np.int64
)

def _get_topk_threshold(array_in):
def _get_topk_threshold(array_in, offset=0):
# Use variables that are local to the closure, so they can be manipulated more
# easily in the main loop
k_in_subset, n_active_rows, new_n_active_rows, threshold_count = (
@@ -563,17 +570,20 @@ def _get_topk_threshold(array_in):

# Initialize all buffers
initialize_k_in_subset_kernel(k_in_subset)
n_active_rows[0] = n_active_rows_scalar = n_rows

# TODO: a few things can be optimized if effective_n_rows < n_rows
n_active_rows[0] = n_active_rows_scalar = effective_n_rows = array_in.shape[0]
initialize_threshold_count_kernel(threshold_count)
active_rows_mapping, new_active_rows_mapping = initialize_active_rows_mapping()
desired_masked_value = initialize_desired_masked_value()
radix_position, mask_for_desired_value = initialize_radix_mask()

# Reinterpret input as uint so we can use bitwise compute
array_in_uint = dpt.usm_ndarray(
shape=(n_rows, n_cols),
shape=(effective_n_rows, n_cols),
dtype=uint_type,
buffer=array_in,
buffer=array_in.usm_data,
offset=offset * n_cols,
)
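# NB (illustrative aside, not part of this PR): for non-negative IEEE-754
# floats, the unsigned-integer bit pattern grows monotonically with the float
# value, which is what makes bitwise (radix) comparisons on the uint view
# meaningful (negative values need extra care in the radix logic). E.g. in
# NumPy:
#     values = np.array([0.0, 0.25, 1.0, 3.5], dtype=np.float32)
#     (np.argsort(values.view(np.uint32)) == np.argsort(values)).all()  # True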

# The main loop: each iteration consists in sorting partially the data on the
@@ -1275,7 +1285,6 @@ def _check_radix_histogram(

@lru_cache
def _make_gather_topk_kernel(
n_rows,
n_cols,
k,
work_group_size,
@@ -1288,7 +1297,6 @@ def gather_topk(
"""
n_work_groups_per_row = math.ceil(n_cols / work_group_size)
work_group_shape = (1, work_group_size)
global_shape = (n_rows, n_work_groups_per_row * work_group_size)

@dpex.kernel
# fmt: off
@@ -1297,7 +1305,7 @@ def gather_topk(
threshold, # IN (n_rows,)
n_threshold_occurences_in_topk, # IN (n_rows,)
n_threshold_occurences_in_data, # IN (n_rows,)
result_col_idx, # BUFFER (n_rows,)
result_col_idx, # BUFFER (n_rows,)
result, # OUT (n_rows, k)
):
# fmt: on
@@ -1393,20 +1401,38 @@ def gather_topk_generic(
result_col_idx_ = dpex.atomic.add(result_col_idx, row_idx, count_one_as_an_int)
result[row_idx, result_col_idx_] = item

return gather_topk[global_shape, work_group_shape]
# TODO: write decorator instead
def _gather_topk(
array_in,
threshold,
n_threshold_occurences_in_topk,
n_threshold_occurences_in_data,
result_col_idx,
result,
):
n_rows = array_in.shape[0]
global_shape = (n_rows, n_work_groups_per_row * work_group_size)
gather_topk[global_shape, work_group_shape](
array_in,
threshold,
n_threshold_occurences_in_topk,
n_threshold_occurences_in_data,
result_col_idx,
result,
)

return _gather_topk
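Since the launch geometry is now computed inside the wrapper from the number of rows of the input, a quick worked example of the sizes involved may help (hypothetical numbers, assuming the same rounding as in the code above; the padded work items beyond `n_cols` would have to be masked out inside the kernel, as is usual for rounded-up launches):

import math

n_cols, work_group_size = 1000, 256
n_work_groups_per_row = math.ceil(n_cols / work_group_size)   # 4 work groups per row
global_width = n_work_groups_per_row * work_group_size        # 1024 work items per row
# With n_rows = 7 rows in the array passed at call time, the kernel is
# launched on a (7, 1024) global grid with (1, 256) work groups.
assert (n_work_groups_per_row, global_width) == (4, 1024)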


@lru_cache
def _make_gather_topk_idx_kernel(
n_rows,
n_cols,
k,
work_group_size,
):
"""Same than gather_topk kernel but return top-k indices rather than top-k values"""
n_work_groups_per_row = math.ceil(n_cols / work_group_size)
work_group_shape = (1, work_group_size)
global_shape = (n_rows, n_work_groups_per_row * work_group_size)

@dpex.kernel
# fmt: off
Expand Down Expand Up @@ -1512,4 +1538,24 @@ def gather_topk_idx_generic(
)
result[row_idx, result_col_idx_] = col_idx

return gather_topk_idx[global_shape, work_group_shape]
# TODO: write decorator instead
def _gather_topk_idx(
array_in,
threshold,
n_threshold_occurences_in_topk,
n_threshold_occurences_in_data,
result_col_idx,
result,
):
n_rows = array_in.shape[0]
global_shape = (n_rows, n_work_groups_per_row * work_group_size)
gather_topk_idx[global_shape, work_group_shape](
array_in,
threshold,
n_threshold_occurences_in_topk,
n_threshold_occurences_in_data,
result_col_idx,
result,
)

return _gather_topk_idx
Empty file.
103 changes: 103 additions & 0 deletions sklearn_numba_dpex/knn/drivers.py
@@ -0,0 +1,103 @@
import math

import dpctl.tensor as dpt
import numpy as np

from sklearn_numba_dpex.common.matmul import make_matmul_2d_kernel
from sklearn_numba_dpex.common.topk import _make_get_topk_kernel


def kneighbors(
query,
data,
n_neighbors,
metric="euclidean",
return_distance=False,
maximum_compute_buffer_size=1073741824, # 1 GiB
):
n_queries, n_features = query.shape
n_samples = data.shape[0]
compute_dtype = query.dtype.type
compute_dtype_itemsize = np.dtype(compute_dtype).itemsize
device = query.device.sycl_device

index_itemsize = np.dtype(np.int64).itemsize

pairwise_distance_required_bytes_per_query = n_samples * compute_dtype_itemsize

# TODO: is there a better way to ensure this remains synchronized with future
# changes to TopK?
topk_required_bytes_per_query = (
(n_neighbors + 1 + 1 + 1) * compute_dtype_itemsize
) + ((1 + 1 + 2 + 4) * index_itemsize)
if return_distance:
topk_required_bytes_per_query += n_neighbors * index_itemsize

total_required_bytes_per_query = (
pairwise_distance_required_bytes_per_query + topk_required_bytes_per_query
)

max_slice_size = maximum_compute_buffer_size / total_required_bytes_per_query

if max_slice_size < 1:
raise RuntimeError("Buffer size is too small")

slice_size = min(math.floor(max_slice_size), n_queries)

n_slices = math.ceil(n_queries / slice_size)
n_full_slices = n_slices - 1
last_slice_size = ((n_queries - 1) % slice_size) + 1
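# Worked example (hypothetical numbers, for illustration only): with
# n_samples=1_000_000 float32 samples, the pairwise-distance buffer costs
# 4_000_000 bytes per query; with n_neighbors=5 and return_distance=False the
# TopK term adds (5 + 3) * 4 + 8 * 8 = 96 bytes, so the default 1 GiB budget
# allows floor(1_073_741_824 / 4_000_096) = 268 queries per slice. 10_000
# queries would then be processed in ceil(10_000 / 268) = 38 slices, the last
# one holding ((10_000 - 1) % 268) + 1 = 84 queries.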

def pairwise_distance_multiply_fn(x, y):
diff = x - y
return diff * diff

def negative_value(x):
return -x

squared_pairwise_distance_kernel = make_matmul_2d_kernel(
slice_size,
n_samples,
n_features,
compute_dtype,
device,
multiply_fn=pairwise_distance_multiply_fn,
out_fused_elementwise_fn=negative_value,
)
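# NB: because of the fused negation, the buffer below holds *negated* squared
# distances, so the subsequent top-k search over it selects the k smallest
# distances, i.e. the indices of the nearest neighbors.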
squared_pairwise_distance_buffer = dpt.empty(
(slice_size, n_samples), dtype=compute_dtype, device=device
)

_, get_topk_kernel = _make_get_topk_kernel(
n_neighbors, (slice_size, n_samples), compute_dtype, device, output="idx"
)

result = dpt.empty((n_queries, n_neighbors), dtype=compute_dtype, device=device)

slice_sample_idx = 0
for _ in range(n_full_slices):
query_slice = query[slice_sample_idx : (slice_sample_idx + slice_size)]
squared_pairwise_distance_kernel(
query_slice, data, squared_pairwise_distance_buffer
)

result_slice = result[slice_sample_idx : (slice_sample_idx + slice_size)]
get_topk_kernel(squared_pairwise_distance_buffer, result_slice)

slice_sample_idx += slice_size

# NB: some pairwise distances are computed twice for no reason, but it's cheaper
# than re-compiling a kernel specifically for the last slice.
query_slice = query[-slice_size:]
squared_pairwise_distance_kernel(
query_slice, data, squared_pairwise_distance_buffer
)
result_slice = result[-last_slice_size:]

get_topk_kernel(
squared_pairwise_distance_buffer[-last_slice_size:],
result_slice,
offset=slice_size - last_slice_size,
)

return result
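Finally, a minimal end-to-end sketch of how this draft driver would be called (hypothetical data and sizes; as written above, `kneighbors` returns, for each query, the indices of its `n_neighbors` nearest samples, the ranking being done on squared euclidean distances):

import dpctl.tensor as dpt
import numpy as np

from sklearn_numba_dpex.knn.drivers import kneighbors

rng = np.random.default_rng(0)
data = dpt.asarray(rng.random((10_000, 16), dtype=np.float32))
query = dpt.asarray(rng.random((100, 16), dtype=np.float32))

# Indices of the 5 nearest samples for each of the 100 queries.
neighbors_idx = kneighbors(query, data, n_neighbors=5)
print(dpt.asnumpy(neighbors_idx).shape)   # (100, 5)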