diff --git a/numba_dpex/tests/dpjit_tests/dpnp/test_dpnp_linspace.py b/numba_dpex/tests/dpjit_tests/dpnp/test_dpnp_linspace.py index 8658e1f9a0..b1570c1cef 100644 --- a/numba_dpex/tests/dpjit_tests/dpnp/test_dpnp_linspace.py +++ b/numba_dpex/tests/dpjit_tests/dpnp/test_dpnp_linspace.py @@ -11,22 +11,20 @@ from numba import errors from numba_dpex import dpjit -from numba_dpex.tests._helper import get_all_dtypes - - -def get_xfail_test(param, reason): - return pytest.param( - param, - marks=pytest.mark.xfail(reason=reason), - ) - +from numba_dpex.tests._helper import get_all_dtypes, get_xfail_test +# Get all dtypes, except bool, float16 and complex dtypes = get_all_dtypes( no_bool=True, no_float16=True, no_none=False, no_complex=True ) +# Get all dtypes, except bool, float16, None and complex dtypes_no_none = get_all_dtypes( no_bool=True, no_float16=True, no_none=True, no_complex=True ) +# Get all dtypes, except bool, float16, None, int (all) and complex +dtypes_float_only = get_all_dtypes( + no_bool=True, no_float16=True, no_int=True, no_none=True, no_complex=True +) usm_types = ["device", "shared", "host"] endpoints = [True, False] ranges = [ @@ -40,12 +38,16 @@ def get_xfail_test(param, reason): [0, 1, 17], [1, 0, 17], [-1, -1, 10], - [0.0, 0.5, 23], # 10 - # [-0.5, 0.0, 10] # noqa: E800 + [ + 0.0, + 0.5, + 23, + ], # fails dtype=np.int32, dpnp/np results don't make sense # 10 + [-0.5, 0.0, 10], # fails dtype=np.int32, dpnp/np results don't make sense ] -@pytest.mark.parametrize("range", ranges) +@pytest.mark.parametrize("range", ranges[:-2]) @pytest.mark.parametrize("dtype", dtypes) @pytest.mark.parametrize("endpoint", endpoints) def test_dpnp_linspace_default(range, dtype, endpoint): @@ -78,6 +80,39 @@ def func(): ) +@pytest.mark.parametrize("range", ranges) +@pytest.mark.parametrize("dtype", dtypes_float_only) +@pytest.mark.parametrize("endpoint", endpoints) +def test_dpnp_linspace_default_float_only(range, dtype, endpoint): + start, stop, num = range + + @dpjit + def func(): + x = dpnp.linspace(start, stop, num, dtype=dtype, endpoint=endpoint) + return x + + try: + c = func() + except Exception: + pytest.fail("Calling dpnp.linspace() inside dpjit failed.") + + a = dpnp.linspace(start, stop, num, dtype=dtype, endpoint=endpoint) + + assert a.dtype == c.dtype + assert a.shape == c.shape + if a.dtype in [dpnp.float, dpnp.float16, dpnp.float32, dpnp.float64]: + assert np.allclose(a.asnumpy(), c.asnumpy()) + else: + assert np.array_equal(a.asnumpy(), c.asnumpy()) + if c.sycl_queue != a.sycl_queue: + pytest.xfail( + "Returned queue does not have the same queue as in the dummy array." + ) + assert c.sycl_queue == dpctl._sycl_queue_manager.get_device_cached_queue( + a.sycl_device + ) + + @pytest.mark.parametrize("range", ranges) @pytest.mark.parametrize("dtype", dtypes) @pytest.mark.parametrize("endpoint", endpoints) @@ -147,3 +182,40 @@ def func(queue): pytest.xfail( "Returned queue does not have the same queue as the one passed to the dpnp function." ) + + +@pytest.mark.parametrize("range", ranges[0:-2]) +@pytest.mark.parametrize("start_dtype", dtypes_no_none) +@pytest.mark.parametrize("stop_dtype", dtypes_no_none) +@pytest.mark.parametrize("dtype", dtypes) +@pytest.mark.parametrize("endpoint", endpoints) +def test_dpnp_linspace_default_dtype_perm( + range, start_dtype, stop_dtype, dtype, endpoint +): + start, stop, num = start_dtype(range[0]), stop_dtype(range[1]), range[2] + + @dpjit + def func(): + x = dpnp.linspace(start, stop, num, dtype=dtype, endpoint=endpoint) + return x + + try: + c = func() + except Exception: + pytest.fail("Calling dpnp.linspace() inside dpjit failed.") + + a = dpnp.linspace(start, stop, num, dtype=dtype, endpoint=endpoint) + + assert a.dtype == c.dtype + assert a.shape == c.shape + if a.dtype in [dpnp.float, dpnp.float16, dpnp.float32, dpnp.float64]: + assert np.allclose(a.asnumpy(), c.asnumpy()) + else: + assert np.array_equal(a.asnumpy(), c.asnumpy()) + if c.sycl_queue != a.sycl_queue: + pytest.xfail( + "Returned queue does not have the same queue as in the dummy array." + ) + assert c.sycl_queue == dpctl._sycl_queue_manager.get_device_cached_queue( + a.sycl_device + )