Skip to content

Commit

Permalink
WIP: top_k tests
Browse files Browse the repository at this point in the history
The purpose of this PR is to continue several threads of discussion
regarding `top_k`.

This follows roughly the specifications of `top_k` in
data-apis/array-api#722, with slight modifications to the API:

```py
def top_k(
    x: array,
    k: int,
    /,
    axis: Optional[int] = None,
    *,
    largest: bool = True,
) -> Tuple[array, array]:
    ...
```

Modifications:
- `mode: Literal["largest", "smallest"]` is replaced with
`largest: bool`
- `axis` is no longer a kw-only arg. This makes the signature slightly
more compatible with `torch.topk`.

The tests implemented here follow the proposed `top_k`
implementation at numpy/numpy#26666.
  • Loading branch information
JuliaPoo committed Jun 24, 2024
1 parent dbdca7b commit 8994765
Showing 1 changed file with 100 additions and 0 deletions.
100 changes: 100 additions & 0 deletions array_api_tests/test_searching_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest
from hypothesis import given, note
from hypothesis import strategies as st
from hypothesis.control import assume

from . import _array_module as xp
from . import dtype_helpers as dh
Expand Down Expand Up @@ -203,3 +204,102 @@ def test_searchsorted(data):
expected=xp.__array_namespace_info__().default_dtypes()["indexing"],
)
# TODO: shapes and values testing


@pytest.mark.unvectorized
# TODO: Test with signed zeros and NaNs (and ignore them somehow)
@given(
    x=hh.arrays(
        dtype=hh.real_dtypes,
        shape=hh.shapes(min_dims=1, min_side=1),
        elements={"allow_nan": False},
    ),
    data=st.data(),
)
def test_top_k(x, data):
    """Check dtypes, shapes and values of ``xp.top_k(x, k, axis, largest=...)``.

    ``axis`` is drawn from ``None`` or any valid (possibly negative) axis of
    ``x``; ``k`` is drawn between 1 and the number of elements searched.
    When ``axis`` is ``None`` the input is flattened before the results are
    verified, so the remaining checks always operate along one concrete axis.

    For every 1-D slice along the axis the test verifies that:

    * the values selected by ``out_indices`` equal the ``k`` largest (or
      smallest) values of the slice, compared after sorting because the
      result is not required to be ordered;
    * each entry of ``out_values`` equals the input element addressed by the
      matching entry of ``out_indices``.
    """
    if dh.is_float_dtype(x.dtype):
        # -0.0 == +0.0 elementwise, so a single comparison rejects both
        # signed zeros (see the TODO above the decorator).
        assume(not xp.any(x == 0.0))

    # Include None so the flattened-input code path below is reachable.
    axis = data.draw(
        st.none() | st.integers(-x.ndim, x.ndim - 1), label="axis"
    )
    largest = data.draw(st.booleans(), label="largest")
    if axis is None:
        max_k = math.prod(x.shape)
    else:
        max_k = x.shape[axis]
    k = data.draw(st.integers(1, max_k), label="k")

    kw = dict(
        x=x,
        k=k,
        axis=axis,
        largest=largest,
    )

    out_values, out_indices = xp.top_k(x, k, axis, largest=largest)
    if axis is None:
        # Verify against the flattened input along its only axis.
        x = xp.reshape(x, (-1,))
        axis = 0

    ph.assert_dtype("top_k", in_dtype=x.dtype, out_dtype=out_values.dtype)
    # Indices are checked against the namespace's default indexing dtype,
    # the same expectation the searchsorted test in this file uses.
    ph.assert_dtype(
        "top_k",
        in_dtype=x.dtype,
        out_dtype=out_indices.dtype,
        expected=xp.__array_namespace_info__().default_dtypes()["indexing"],
    )

    axes, = sh.normalise_axis(axis, x.ndim)
    expected_shape = x.shape[:axes] + (k,) + x.shape[axes + 1:]
    for arr in [out_values, out_indices]:
        ph.assert_shape(
            "top_k",
            out_shape=arr.shape,
            expected=expected_shape,
            kw=kw,
        )

    scalar_type = dh.get_scalar_type(x.dtype)

    # NOTE(review): each `indices` is used below as the sequence of full
    # indices of one 1-D slice of `x` along `axes`; its first k entries are
    # reused to address the matching slice of the outputs.
    for indices in sh.axes_ndindex(x.shape, (axes,)):

        # Test if the values indexed by out_indices corresponds to
        # the correct top_k values.
        elements = [scalar_type(x[idx]) for idx in indices]
        size = len(elements)
        correct_order = sorted(
            range(size),
            key=elements.__getitem__,
            reverse=largest,
        )
        correct_order = correct_order[:k]
        # Convert to Python ints so the positions can index lists/tuples
        # and serve as sort keys without relying on 0-d array semantics.
        test_order = [int(out_indices[idx]) for idx in indices[:k]]
        # Sort because top_k does not necessarily return the values in
        # sorted order.
        test_sorted_order = sorted(
            test_order,
            key=elements.__getitem__,
            reverse=largest,
        )

        for y_o, x_o in zip(correct_order, test_sorted_order):
            y_idx = indices[y_o]
            x_idx = indices[x_o]
            ph.assert_0d_equals(
                "top_k",
                x_repr=f"x[{x_idx}]",
                x_val=x[x_idx],
                out_repr=f"x[{y_idx}]",
                out_val=x[y_idx],
                kw=kw,
            )

        # Test if the values indexed by out_indices corresponds to out_values.
        for y_o, x_idx in zip(test_order, indices[:k]):
            y_idx = indices[y_o]
            ph.assert_0d_equals(
                "top_k",
                x_repr=f"out_values[{x_idx}]",
                x_val=scalar_type(out_values[x_idx]),
                out_repr=f"x[{y_idx}]",
                out_val=x[y_idx],
                kw=kw,
            )

0 comments on commit 8994765

Please sign in to comment.