Skip to content

Commit

Permalink
Added more extensive tests
Browse files Browse the repository at this point in the history
  • Loading branch information
chudur-budur committed Dec 12, 2023
1 parent d500123 commit f6ea929
Showing 1 changed file with 84 additions and 12 deletions.
96 changes: 84 additions & 12 deletions numba_dpex/tests/dpjit_tests/dpnp/test_dpnp_linspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,20 @@
from numba import errors

from numba_dpex import dpjit
from numba_dpex.tests._helper import get_all_dtypes


def get_xfail_test(param, reason):
return pytest.param(
param,
marks=pytest.mark.xfail(reason=reason),
)

from numba_dpex.tests._helper import get_all_dtypes, get_xfail_test

# Get all dtypes, except bool, float16 and complex
dtypes = get_all_dtypes(
no_bool=True, no_float16=True, no_none=False, no_complex=True
)
# Get all dtypes, except bool, float16, None and complex
dtypes_no_none = get_all_dtypes(
no_bool=True, no_float16=True, no_none=True, no_complex=True
)
# Get all dtypes, except bool, float16, None, int (all) and complex
dtypes_float_only = get_all_dtypes(
no_bool=True, no_float16=True, no_int=True, no_none=True, no_complex=True
)
usm_types = ["device", "shared", "host"]
endpoints = [True, False]
ranges = [
Expand All @@ -40,12 +38,16 @@ def get_xfail_test(param, reason):
[0, 1, 17],
[1, 0, 17],
[-1, -1, 10],
[0.0, 0.5, 23], # 10
# [-0.5, 0.0, 10] # noqa: E800
[
0.0,
0.5,
23,
], # fails dtype=np.int32, dpnp/np results don't make sense # 10
[-0.5, 0.0, 10], # fails dtype=np.int32, dpnp/np results don't make sense
]


@pytest.mark.parametrize("range", ranges)
@pytest.mark.parametrize("range", ranges[:-2])
@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("endpoint", endpoints)
def test_dpnp_linspace_default(range, dtype, endpoint):
Expand Down Expand Up @@ -78,6 +80,39 @@ def func():
)


@pytest.mark.parametrize("range", ranges)
@pytest.mark.parametrize("dtype", dtypes_float_only)
@pytest.mark.parametrize("endpoint", endpoints)
def test_dpnp_linspace_default_float_only(range, dtype, endpoint):
start, stop, num = range

@dpjit
def func():
x = dpnp.linspace(start, stop, num, dtype=dtype, endpoint=endpoint)
return x

try:
c = func()
except Exception:
pytest.fail("Calling dpnp.linspace() inside dpjit failed.")

a = dpnp.linspace(start, stop, num, dtype=dtype, endpoint=endpoint)

assert a.dtype == c.dtype
assert a.shape == c.shape
if a.dtype in [dpnp.float, dpnp.float16, dpnp.float32, dpnp.float64]:
assert np.allclose(a.asnumpy(), c.asnumpy())
else:
assert np.array_equal(a.asnumpy(), c.asnumpy())
if c.sycl_queue != a.sycl_queue:
pytest.xfail(
"Returned queue does not have the same queue as in the dummy array."
)
assert c.sycl_queue == dpctl._sycl_queue_manager.get_device_cached_queue(
a.sycl_device
)


@pytest.mark.parametrize("range", ranges)
@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("endpoint", endpoints)
Expand Down Expand Up @@ -147,3 +182,40 @@ def func(queue):
pytest.xfail(
"Returned queue does not have the same queue as the one passed to the dpnp function."
)


@pytest.mark.parametrize("range", ranges[0:-2])
@pytest.mark.parametrize("start_dtype", dtypes_no_none)
@pytest.mark.parametrize("stop_dtype", dtypes_no_none)
@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("endpoint", endpoints)
def test_dpnp_linspace_default_dtype_perm(
range, start_dtype, stop_dtype, dtype, endpoint
):
start, stop, num = start_dtype(range[0]), stop_dtype(range[1]), range[2]

@dpjit
def func():
x = dpnp.linspace(start, stop, num, dtype=dtype, endpoint=endpoint)
return x

try:
c = func()
except Exception:
pytest.fail("Calling dpnp.linspace() inside dpjit failed.")

a = dpnp.linspace(start, stop, num, dtype=dtype, endpoint=endpoint)

assert a.dtype == c.dtype
assert a.shape == c.shape
if a.dtype in [dpnp.float, dpnp.float16, dpnp.float32, dpnp.float64]:
assert np.allclose(a.asnumpy(), c.asnumpy())
else:
assert np.array_equal(a.asnumpy(), c.asnumpy())
if c.sycl_queue != a.sycl_queue:
pytest.xfail(
"Returned queue does not have the same queue as in the dummy array."
)
assert c.sycl_queue == dpctl._sycl_queue_manager.get_device_cached_queue(
a.sycl_device
)

0 comments on commit f6ea929

Please sign in to comment.