Skip to content

Commit

Permalink
Added all docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
chudur-budur committed Dec 20, 2023
1 parent cec1aee commit 5f5de6b
Show file tree
Hide file tree
Showing 3 changed files with 139 additions and 39 deletions.
98 changes: 66 additions & 32 deletions numba_dpex/dpnp_iface/array_interval_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,12 +676,12 @@ def impl_dpnp_linspace(
Args:
ty_context (numba.core.typing.context.Context): The typing context
for the codegen.
ty_start (numba.core.types.scalars.Integer): Numba type for the start
of the interval.
ty_stop (numba.core.types.scalars.Integer): Numba type for the end
of the interval.
ty_step (numba.core.types.scalars.Integer): Numba type for the step
of the interval.
ty_start (numba.core.types.scalars.*): Numba type for the start of the
interval.
ty_stop (numba.core.types.scalars.*): Numba type for the end of the
interval.
ty_num (numba.core.types.scalars.Integer): Numba type for the number of
items in the interval.
ty_dtype (numba.core.types.functions.NumberClass): Numba type for
dtype.
ty_device (numba.core.types.misc.UnicodeType): UnicodeType
Expand Down Expand Up @@ -749,10 +749,18 @@ def codegen(context, builder, sig, args):

# Extend or truncate input values w.r.t. destination array type
start_ir = _normalize(
builder, start_ir, start_arg_type, dtype_arg_type.dtype
builder,
start_ir,
start_arg_type,
dtype_arg_type.dtype,
float_to_int_rounding=True,
)
stop_ir = _normalize(
builder, stop_ir, stop_arg_type, dtype_arg_type.dtype
builder,
stop_ir,
stop_arg_type,
dtype_arg_type.dtype,
float_to_int_rounding=True,
)

# After normalization, their arg_types will change
Expand Down Expand Up @@ -826,26 +834,42 @@ def ol_dpnp_linspace(
axis=0,
):
"""Implementation of an overload to support dpnp.linspace() inside
a dpjit function. Returns evenly spaced values within the half-open interval
[start, stop) as a one-dimensional array.
a dpjit function. Returns evenly spaced numbers over a specified interval.
Let $N$ be the number of generated values (which is either `num` or `num+1`
depending on whether endpoint is `True` or `False`, respectively). For
real-valued output arrays, the spacing between values is given by
$$
\\Delta_{\\text{real}} = \\frac{stop - start}{N-1}
$$
For complex output arrays, let `a = real(start)`, `b = imag(start)`,
`c = real(stop)`, and `d = imag(stop)`. The spacing between complex values
is given by
$$
\\Delta_{\\text{complex}} = \\frac{c-a}{N-1} + \\frac{d-b}{N-1}j
$$
Args:
start (numba.core.types.scalars.*): The start of the interval. If `stop`
is specified, the start of interval (inclusive); otherwise, the end
of the interval (exclusive). If `stop` is not specified, the default
starting value is 0.
stop (numba.core.types.scalars.*, optional): The end of the interval.
Default: `None`.
step (numba.core.types.scalars.*, optional): The distance between two
adjacent elements (`out[i+1] - out[i]`). Must not be 0; may be
negative, this results in an empty array if `stop >= start`.
Default: 1.
dtype (numba.core.types.scalars.*, optional): The output array data
type. If `dtype` is `None`, the output array data type must be
inferred from `start`, `stop` and `step`. If those are all integers,
the output array `dtype` must be the default integer `dtype`; if
one or more have type `float`, then the output array dtype must be
the default real-valued floating-point data type. Default: `None`.
start (numba.core.types.scalars.*): The start of the interval.
stop (numba.core.types.scalars.*): The end of the interval. If `endpoint`
is `False`, the function must generate a sequence of `num+1` evenly
spaced numbers starting with `start` and ending with `stop` and
exclude the `stop` from the returned array such that the returned
array consists of evenly spaced numbers over the half-open interval
`[start, stop)`. If `endpoint` is `True`, the output array must
consist of evenly spaced numbers over the closed interval
`[start, stop]`. Default: `True`.
num (numba.core.types.scalars.Integer): number of samples. Must be a
nonnegative integer value.
dtype (numba.core.types.scalars.*, optional): Output array data type.
Should be a floating-point data type. If `dtype` is `None`, if
either `start` or `stop` is a complex number, the output data type
must be the default complex floating-point data type. If both `start`
and `stop` are real-valued, the output data type must be the
default real-valued floating-point data type. Default: `None`.
device (numba.core.types.misc.StringLiteral, optional): array API
concept of device where the output array is created. `device`
can be `None`, a oneAPI filter selector string, an instance of
Expand All @@ -863,13 +887,23 @@ def ol_dpnp_linspace(
one or another. If both are specified, a TypeError is raised. If
both are None, a cached queue targeting default-selected device
is used for allocation and copying. Default: `None`.
endpoint (numba.core.types.scalars.Boolean, optional): Boolean
indicating whether to include `stop` in the interval. Default: `True`
retstep (numba.core.types.scalars.Boolean, optional): Boolean
indicating whether to return the step of the interval. This is not
being used. Default: `False`
axis (numba.core.types.scalars.Integer, optional): An int specifying
which axis the intervals will be populated in the case of more than
1D array. This is not being used. Default: 0
Raises:
errors.NumbaNotImplementedError: If `start` is
`numba.core.types.scalars.Complex` type
errors.NumbaTypeError: If `start` is `numba.core.types.scalars.Boolean`
errors.NumbaNotImplementedError: If `start` or `stop` is neither of
`numba.core.types.scalars.Integer` or
`numba.core.types.scalars.Float` type. Also when `retstep` is not
`False` or `axis` is not 0.
errors.NumbaTypeError: If `num` is not `numba.core.types.scalars.Integer`
type
errors.TypingError: If couldn't parse input types to dpnp.arange().
errors.TypingError: If couldn't parse input types to dpnp.linspace().
Returns:
function: Local function `impl_dpnp_linspace()`.
Expand All @@ -885,11 +919,11 @@ def ol_dpnp_linspace(
)
):
msg = (
"Input data type is not supported."
"Input data type is not supported yet."
+ " Please convert the input to"
+ " a scalar data type (int or float)."
)
raise errors.NumbaTypeError(msg)
raise errors.NumbaNotImplementedError(msg)

if not _match_type([num], Integer, int):
msg = "'num' must be an int."
Expand Down
78 changes: 73 additions & 5 deletions numba_dpex/tests/dpjit_tests/dpnp/test_dpnp_linspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,28 @@
0.0,
0.5,
23,
], # fails dtype=np.int32, dpnp/np results don't make sense # 10
[-0.5, 0.0, 10], # fails dtype=np.int32, dpnp/np results don't make sense
], # fails at dtype=np.int, dpnp/np results don't make sense # 10
[-0.5, 0.0, 10], # fails at dtype=np.int, dpnp/np results don't make sense
]


@pytest.mark.parametrize("range", ranges[:-2])
@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("endpoint", endpoints)
def test_dpnp_linspace_default(range, dtype, endpoint):
"""Tests `dpnp.linspace()` overload with default setting.
Test over all ranges and dtypes with default settings for `dpnp.linspace()`,
except the last two. Those two fail with `dtype=np.int`.
Args:
range (list): A `list` containing `start` and `stop` of the interval.
dtype (type): A `type` to be used in the `dtype` parameter.
endpoint (bool): A boolean value to include/exclude endpoint.
Returns:
None: Nothing.
"""
start, stop, num = range

@dpjit
Expand Down Expand Up @@ -84,6 +97,19 @@ def func():
@pytest.mark.parametrize("dtype", dtypes_float_only)
@pytest.mark.parametrize("endpoint", endpoints)
def test_dpnp_linspace_default_float_only(range, dtype, endpoint):
"""Tests `dpnp.linspace()` overload with default setting.
Test over all ranges with default settings for `dpnp.linspace()`.
The `dtype` exclude all `int` types.
Args:
range (list): A `list` containing `start` and `stop` of the interval.
dtype (type): A `type` to be used in the `dtype` parameter.
endpoint (bool): A boolean value to include/exclude endpoint.
Returns:
None: Nothing.
"""
start, stop, num = range

@dpjit
Expand Down Expand Up @@ -113,11 +139,25 @@ def func():
)


@pytest.mark.parametrize("range", ranges)
@pytest.mark.parametrize("range", ranges[:-2])
@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("endpoint", endpoints)
@pytest.mark.parametrize("usm_type", usm_types)
def test_dpnp_linspace_from_device(range, dtype, endpoint, usm_type):
"""Test device only `dpnp.linspace()` overload with parameterized `usm_type`.
We are skipping the last two since they fail on `dtype=np.int` types.
Args:
range (list): A `list` containing `start` and `stop` of the interval.
dtype (type): A `type` to be used in the `dtype` parameter.
endpoint (bool): A boolean value to include/exclude endpoint.
usm_type (str): A `str` value to denote the type of USM one of
`["device", "shared", "host"]`.
Returns:
None: Nothing.
"""
device = dpctl.SyclDevice().filter_string

start, stop, num = range
Expand Down Expand Up @@ -150,11 +190,26 @@ def func():
)


@pytest.mark.parametrize("range", ranges)
@pytest.mark.parametrize("range", ranges[:-2])
@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("endpoint", endpoints)
@pytest.mark.parametrize("usm_type", usm_types)
def test_dpnp_linspace_from_queue(range, dtype, usm_type, endpoint):
def test_dpnp_linspace_from_queue(range, dtype, endpoint, usm_type):
"""Test `dpnp.linspace()` overload with specied queue and parameterized
`usm_type`.
We are skipping the last two since they fail on `dtype=np.int` types.
Args:
range (list): A `list` containing `start` and `stop` of the interval.
dtype (type): A `type` to be used in the `dtype` parameter.
endpoint (bool): A boolean value to include/exclude endpoint.
usm_type (str): A `str` value to denote the type of USM one of
`["device", "shared", "host"]`.
Returns:
None: Nothing.
"""
start, stop, num = range

@dpjit
Expand Down Expand Up @@ -192,6 +247,19 @@ def func(queue):
def test_dpnp_linspace_default_dtype_perm(
range, start_dtype, stop_dtype, dtype, endpoint
):
"""Tests `dpnp.linspace()` overload with default setting and permutations of
different `dtype`s for `start` and `stop`.
Args:
range (list): A `list` containing `start` and `stop` of the interval.
start_dtype (type): The `dtype` for `start` value.
stop_dtype (type): The `dtype` for `stop` value.
dtype (type): A `type` to be used in the `dtype` parameter.
endpoint (bool): A boolean value to include/exclude endpoint.
Returns:
None: Nothing.
"""
start, stop, num = start_dtype(range[0]), stop_dtype(range[1]), range[2]

@dpjit
Expand Down
2 changes: 0 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,5 @@ def to_cmake_format(version: str):
),
"-DIS_INSTALL:BOOL={0:s}".format("TRUE" if is_install else "FALSE"),
"-DIS_DEVELOP:BOOL={0:s}".format("TRUE" if is_develop else "FALSE"),
"-DCMAKE_C_COMPILER=icx",
"-DCMAKE_CXX_COMPILER={0:s}".format("icx" if is_windows else "icpx"),
],
)

0 comments on commit 5f5de6b

Please sign in to comment.