From f40105c9c7e460681adcdf18d53e8e20134eac6d Mon Sep 17 00:00:00 2001 From: Jim Pivarski Date: Tue, 9 Jan 2024 14:12:41 -0600 Subject: [PATCH 01/24] feat: add all direct from NumPy/CuPy functions --- src/ragged/_import.py | 14 ++++++++++++++ src/ragged/_spec_array_object.py | 2 ++ src/ragged/_spec_creation_functions.py | 11 ++++------- tests/test_spec_creation_functions.py | 20 ++++++++++++++++++++ 4 files changed, 40 insertions(+), 7 deletions(-) diff --git a/src/ragged/_import.py b/src/ragged/_import.py index 05c73ef..9c56827 100644 --- a/src/ragged/_import.py +++ b/src/ragged/_import.py @@ -4,6 +4,20 @@ from typing import Any +import numpy as np + +from ._typing import Device + + +def device_namespace(device: None | Device = None) -> tuple[Device, Any]: + if device is None or device == "cpu": + return "cpu", np + elif device == "cuda": + return "cuda", cupy() + + msg = f"unrecognized device: {device!r}" # type: ignore[unreachable] + raise ValueError(msg) + def cupy() -> Any: try: diff --git a/src/ragged/_spec_array_object.py b/src/ragged/_spec_array_object.py index fd19660..b300fc5 100644 --- a/src/ragged/_spec_array_object.py +++ b/src/ragged/_spec_array_object.py @@ -1144,5 +1144,7 @@ def _box( else: dtype = dtype_observed device = "cpu" if isinstance(output, np.ndarray) else "cuda" + if shape != (): + impl = ak.Array(impl) return cls._new(impl, shape, dtype, device) # pylint: disable=W0212 diff --git a/src/ragged/_spec_creation_functions.py b/src/ragged/_spec_creation_functions.py index c6895c0..8014045 100644 --- a/src/ragged/_spec_creation_functions.py +++ b/src/ragged/_spec_creation_functions.py @@ -8,7 +8,8 @@ import awkward as ak -from ._spec_array_object import array +from ._import import device_namespace +from ._spec_array_object import _box, array from ._typing import ( Device, Dtype, @@ -53,12 +54,8 @@ def arange( https://data-apis.org/array-api/latest/API_specification/generated/array_api.arange.html """ - start # noqa: B018, pylint: disable=W0104 - stop # noqa: B018, pylint: disable=W0104 - step # noqa: B018, pylint: disable=W0104 - dtype # noqa: B018, pylint: disable=W0104 - device # noqa: B018, pylint: disable=W0104 - raise NotImplementedError("TODO 35") # noqa: EM101 + device, ns = device_namespace(device) + return _box(array, ns.arange(start, stop, step, dtype=dtype)) def asarray( diff --git a/tests/test_spec_creation_functions.py b/tests/test_spec_creation_functions.py index 45c1b57..51853cd 100644 --- a/tests/test_spec_creation_functions.py +++ b/tests/test_spec_creation_functions.py @@ -6,8 +6,21 @@ from __future__ import annotations +import numpy as np +import pytest + import ragged +devices = ["cpu"] +ns = {"cpu": np} +try: + import cupy as cp + + devices.append("cuda") + ns["cuda"] = cp +except ModuleNotFoundError: + cp = None + def test_existence(): assert ragged.arange is not None @@ -26,3 +39,10 @@ def test_existence(): assert ragged.triu is not None assert ragged.zeros is not None assert ragged.zeros_like is not None + + +@pytest.mark.parametrize("device", devices) +def test_arange(device): + a = ragged.arange(5, 10, 2, device=device) + assert a.tolist() == [5, 7, 9] + assert isinstance(a._impl.layout.data, ns[device].ndarray) # type: ignore[union-attr] From 31a37a3db4e860842ee2363bdca64052a20972ec Mon Sep 17 00:00:00 2001 From: Jim Pivarski Date: Tue, 9 Jan 2024 14:32:00 -0600 Subject: [PATCH 02/24] empty, eye, from_dlpack --- src/ragged/_spec_creation_functions.py | 35 +++++++++++++++++--------- tests/test_spec_creation_functions.py | 22 ++++++++++++++++ 2 files changed, 45 insertions(+), 12 deletions(-) diff --git a/src/ragged/_spec_creation_functions.py b/src/ragged/_spec_creation_functions.py index 8014045..8b1499b 100644 --- a/src/ragged/_spec_creation_functions.py +++ b/src/ragged/_spec_creation_functions.py @@ -6,8 +6,12 @@ from __future__ import annotations +import enum + import awkward as ak +import numpy as np +from . import _import from ._import import device_namespace from ._spec_array_object import _box, array from ._typing import ( @@ -134,10 +138,8 @@ def empty( https://data-apis.org/array-api/latest/API_specification/generated/array_api.empty.html """ - shape # noqa: B018, pylint: disable=W0104 - dtype # noqa: B018, pylint: disable=W0104 - device # noqa: B018, pylint: disable=W0104 - raise NotImplementedError("TODO 36") # noqa: EM101 + device, ns = device_namespace(device) + return _box(array, ns.empty(shape, dtype=dtype)) def empty_like( @@ -194,12 +196,8 @@ def eye( https://data-apis.org/array-api/latest/API_specification/generated/array_api.eye.html """ - n_rows # noqa: B018, pylint: disable=W0104 - n_cols # noqa: B018, pylint: disable=W0104 - k # noqa: B018, pylint: disable=W0104 - dtype # noqa: B018, pylint: disable=W0104 - device # noqa: B018, pylint: disable=W0104 - raise NotImplementedError("TODO 38") # noqa: EM101 + device, ns = device_namespace(device) + return _box(array, ns.eye(n_rows, n_cols, k, dtype=dtype)) def from_dlpack(x: object, /) -> array: @@ -215,8 +213,21 @@ def from_dlpack(x: object, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.from_dlpack.html """ - x # noqa: B018, pylint: disable=W0104 - raise NotImplementedError("TODO 39") # noqa: EM101 + device_type, _ = x.__dlpack_device__() # type: ignore[attr-defined] + if ( + isinstance(device_type, enum.Enum) and device_type.value == 1 + ) or device_type == 1: + y = np.from_dlpack(x) + elif ( + isinstance(device_type, enum.Enum) and device_type.value == 2 + ) or device_type == 2: + cp = _import.cupy() + y = cp.from_dlpack(x) + else: + msg = f"unsupported __dlpack_device__ type: {device_type}" + raise TypeError(msg) + + return _box(array, y) def full( diff --git a/tests/test_spec_creation_functions.py b/tests/test_spec_creation_functions.py index 51853cd..d63664b 100644 --- a/tests/test_spec_creation_functions.py +++ b/tests/test_spec_creation_functions.py @@ -46,3 +46,25 @@ def test_arange(device): a = ragged.arange(5, 10, 2, device=device) assert a.tolist() == [5, 7, 9] assert isinstance(a._impl.layout.data, ns[device].ndarray) # type: ignore[union-attr] + + +@pytest.mark.parametrize("device", devices) +def test_empty(device): + a = ragged.empty((2, 3, 5), device=device) + assert a.shape == (2, 3, 5) + assert isinstance(a._impl.layout.data, ns[device].ndarray) # type: ignore[union-attr] + + +@pytest.mark.parametrize("device", devices) +def test_eye(device): + a = ragged.eye(3, 5, k=1, device=device) + assert a.tolist() == [[0, 1, 0, 0, 0], [0, 0, 1, 0, 0], [0, 0, 0, 1, 0]] + assert isinstance(a._impl.layout.data, ns[device].ndarray) # type: ignore[union-attr] + + +@pytest.mark.parametrize("device", devices) +def test_from_dlpack(device): + a = ns[device].array([1, 2, 3, 4, 5]) + b = ragged.from_dlpack(a) + assert b.tolist() == [1, 2, 3, 4, 5] + assert isinstance(b._impl.layout.data, ns[device].ndarray) # type: ignore[union-attr] From b058476b2f5535f6165e4bc31b65a7b6a31436b1 Mon Sep 17 00:00:00 2001 From: Jim Pivarski Date: Tue, 9 Jan 2024 14:41:55 -0600 Subject: [PATCH 03/24] full, linspace, ones, zeros --- src/ragged/_spec_creation_functions.py | 30 +++++++++----------------- tests/test_spec_creation_functions.py | 28 ++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 20 deletions(-) diff --git a/src/ragged/_spec_creation_functions.py b/src/ragged/_spec_creation_functions.py index 8b1499b..189a787 100644 --- a/src/ragged/_spec_creation_functions.py +++ b/src/ragged/_spec_creation_functions.py @@ -262,11 +262,8 @@ def full( https://data-apis.org/array-api/latest/API_specification/generated/array_api.full.html """ - shape # noqa: B018, pylint: disable=W0104 - fill_value # noqa: B018, pylint: disable=W0104 - dtype # noqa: B018, pylint: disable=W0104 - device # noqa: B018, pylint: disable=W0104 - raise NotImplementedError("TODO 40") # noqa: EM101 + device, ns = device_namespace(device) + return _box(array, ns.full(shape, fill_value, dtype=dtype)) def full_like( @@ -352,13 +349,10 @@ def linspace( https://data-apis.org/array-api/latest/API_specification/generated/array_api.linspace.html """ - start # noqa: B018, pylint: disable=W0104 - stop # noqa: B018, pylint: disable=W0104 - num # noqa: B018, pylint: disable=W0104 - dtype # noqa: B018, pylint: disable=W0104 - device # noqa: B018, pylint: disable=W0104 - endpoint # noqa: B018, pylint: disable=W0104 - raise NotImplementedError("TODO 42") # noqa: EM101 + device, ns = device_namespace(device) + return _box( + array, ns.linspace(start, stop, num=num, endpoint=endpoint, dtype=dtype) + ) def meshgrid(*arrays: array, indexing: str = "xy") -> list[array]: @@ -423,10 +417,8 @@ def ones( https://data-apis.org/array-api/latest/API_specification/generated/array_api.ones.html """ - shape # noqa: B018, pylint: disable=W0104 - dtype # noqa: B018, pylint: disable=W0104 - device # noqa: B018, pylint: disable=W0104 - raise NotImplementedError("TODO 44") # noqa: EM101 + device, ns = device_namespace(device) + return _box(array, ns.ones(shape, dtype=dtype)) def ones_like( @@ -526,10 +518,8 @@ def zeros( https://data-apis.org/array-api/latest/API_specification/generated/array_api.zeros.html """ - shape # noqa: B018, pylint: disable=W0104 - dtype # noqa: B018, pylint: disable=W0104 - device # noqa: B018, pylint: disable=W0104 - raise NotImplementedError("TODO 48") # noqa: EM101 + device, ns = device_namespace(device) + return _box(array, ns.zeros(shape, dtype=dtype)) def zeros_like( diff --git a/tests/test_spec_creation_functions.py b/tests/test_spec_creation_functions.py index d63664b..32db912 100644 --- a/tests/test_spec_creation_functions.py +++ b/tests/test_spec_creation_functions.py @@ -68,3 +68,31 @@ def test_from_dlpack(device): b = ragged.from_dlpack(a) assert b.tolist() == [1, 2, 3, 4, 5] assert isinstance(b._impl.layout.data, ns[device].ndarray) # type: ignore[union-attr] + + +@pytest.mark.parametrize("device", devices) +def test_full(device): + a = ragged.full(5, 3, device=device) + assert a.tolist() == [3, 3, 3, 3, 3] + assert isinstance(a._impl.layout.data, ns[device].ndarray) # type: ignore[union-attr] + + +@pytest.mark.parametrize("device", devices) +def test_linspace(device): + a = ragged.linspace(5, 8, 5, device=device) + assert a.tolist() == [5, 5.75, 6.5, 7.25, 8] + assert isinstance(a._impl.layout.data, ns[device].ndarray) # type: ignore[union-attr] + + +@pytest.mark.parametrize("device", devices) +def test_ones(device): + a = ragged.ones(5, device=device) + assert a.tolist() == [1, 1, 1, 1, 1] + assert isinstance(a._impl.layout.data, ns[device].ndarray) # type: ignore[union-attr] + + +@pytest.mark.parametrize("device", devices) +def test_zeros(device): + a = ragged.zeros(5, device=device) + assert a.tolist() == [0, 0, 0, 0, 0] + assert isinstance(a._impl.layout.data, ns[device].ndarray) # type: ignore[union-attr] From bf6b11224334e8c14bf63d091d1e42c05be9e066 Mon Sep 17 00:00:00 2001 From: Jim Pivarski Date: Tue, 9 Jan 2024 14:45:05 -0600 Subject: [PATCH 04/24] Also test cases that create scalars. --- tests/test_spec_creation_functions.py | 35 +++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/tests/test_spec_creation_functions.py b/tests/test_spec_creation_functions.py index 32db912..7b83798 100644 --- a/tests/test_spec_creation_functions.py +++ b/tests/test_spec_creation_functions.py @@ -55,6 +55,14 @@ def test_empty(device): assert isinstance(a._impl.layout.data, ns[device].ndarray) # type: ignore[union-attr] +@pytest.mark.parametrize("device", devices) +def test_empty_ndim0(device): + a = ragged.empty((), device=device) + assert a.ndim == 0 + assert a.shape == () + assert isinstance(a._impl, ns[device].ndarray) + + @pytest.mark.parametrize("device", devices) def test_eye(device): a = ragged.eye(3, 5, k=1, device=device) @@ -77,6 +85,15 @@ def test_full(device): assert isinstance(a._impl.layout.data, ns[device].ndarray) # type: ignore[union-attr] +@pytest.mark.parametrize("device", devices) +def test_full_ndim0(device): + a = ragged.full((), 3, device=device) + assert a.ndim == 0 + assert a.shape == () + assert a == 3 + assert isinstance(a._impl, ns[device].ndarray) + + @pytest.mark.parametrize("device", devices) def test_linspace(device): a = ragged.linspace(5, 8, 5, device=device) @@ -91,8 +108,26 @@ def test_ones(device): assert isinstance(a._impl.layout.data, ns[device].ndarray) # type: ignore[union-attr] +@pytest.mark.parametrize("device", devices) +def test_ones_ndim0(device): + a = ragged.ones((), device=device) + assert a.ndim == 0 + assert a.shape == () + assert a == 1 + assert isinstance(a._impl, ns[device].ndarray) + + @pytest.mark.parametrize("device", devices) def test_zeros(device): a = ragged.zeros(5, device=device) assert a.tolist() == [0, 0, 0, 0, 0] assert isinstance(a._impl.layout.data, ns[device].ndarray) # type: ignore[union-attr] + + +@pytest.mark.parametrize("device", devices) +def test_zeros_ndim0(device): + a = ragged.zeros((), device=device) + assert a.ndim == 0 + assert a.shape == () + assert a == 0 + assert isinstance(a._impl, ns[device].ndarray) From ed0e08487da72b190c3e9f9a1c0712bf396a236e Mon Sep 17 00:00:00 2001 From: Jim Pivarski Date: Tue, 9 Jan 2024 15:07:02 -0600 Subject: [PATCH 05/24] can_cast, finfo, iinfo --- src/ragged/_spec_data_type_functions.py | 26 +++++++++++---- tests/test_spec_data_type_functions.py | 43 +++++++++++++++++++++++++ 2 files changed, 62 insertions(+), 7 deletions(-) diff --git a/src/ragged/_spec_data_type_functions.py b/src/ragged/_spec_data_type_functions.py index fb61fdf..2c12edf 100644 --- a/src/ragged/_spec_data_type_functions.py +++ b/src/ragged/_spec_data_type_functions.py @@ -56,9 +56,7 @@ def can_cast(from_: Dtype | array, to: Dtype, /) -> bool: https://data-apis.org/array-api/latest/API_specification/generated/array_api.can_cast.html """ - from_ # noqa: B018, pylint: disable=W0104 - to # noqa: B018, pylint: disable=W0104 - raise NotImplementedError("TODO 51") # noqa: EM101 + return bool(np.can_cast(from_, to)) @dataclass @@ -114,8 +112,16 @@ def finfo(type: Dtype | array, /) -> finfo_object: # pylint: disable=W0622 https://data-apis.org/array-api/latest/API_specification/generated/array_api.finfo.html """ - type # noqa: B018, pylint: disable=W0104 - raise NotImplementedError("TODO 52") # noqa: EM101 + if not isinstance(type, np.dtype): + if not isinstance(type, __builtins__["type"]) and hasattr(type, "dtype"): # type: ignore[index] + out = np.finfo(type.dtype) + else: + out = np.finfo(np.dtype(type)) + else: + out = np.finfo(type) + return finfo_object( + out.bits, out.eps, out.max, out.min, out.smallest_normal, out.dtype + ) @dataclass @@ -155,8 +161,14 @@ def iinfo(type: Dtype | array, /) -> iinfo_object: # pylint: disable=W0622 https://data-apis.org/array-api/latest/API_specification/generated/array_api.iinfo.html """ - type # noqa: B018, pylint: disable=W0104 - raise NotImplementedError("TODO 53") # noqa: EM101 + if not isinstance(type, np.dtype): + if not isinstance(type, __builtins__["type"]) and hasattr(type, "dtype"): # type: ignore[index] + out = np.iinfo(type.dtype) + else: + out = np.iinfo(np.dtype(type)) + else: + out = np.iinfo(type) + return iinfo_object(out.bits, out.max, out.min, out.dtype) def isdtype(dtype: Dtype, kind: Dtype | str | tuple[Dtype | str, ...]) -> bool: diff --git a/tests/test_spec_data_type_functions.py b/tests/test_spec_data_type_functions.py index b31bc52..3e0e785 100644 --- a/tests/test_spec_data_type_functions.py +++ b/tests/test_spec_data_type_functions.py @@ -6,6 +6,8 @@ from __future__ import annotations +import numpy as np + import ragged @@ -16,3 +18,44 @@ def test_existence(): assert ragged.iinfo is not None assert ragged.isdtype is not None assert ragged.result_type is not None + + +def test_can_cast(): + assert ragged.can_cast(np.float32, np.complex128) + assert not ragged.can_cast(np.complex128, np.float32) + + +def test_finfo(): + f = ragged.finfo(np.float64) + assert f.bits == 64 + assert f.eps == 2.220446049250313e-16 + assert f.max == 1.7976931348623157e308 + assert f.min == -1.7976931348623157e308 + assert f.smallest_normal == 2.2250738585072014e-308 + assert f.dtype == np.dtype(np.float64) + + +def test_finfo_array(): + f = ragged.finfo(ragged.array([1.1, 2.2, 3.3])) + assert f.bits == 64 + assert f.eps == 2.220446049250313e-16 + assert f.max == 1.7976931348623157e308 + assert f.min == -1.7976931348623157e308 + assert f.smallest_normal == 2.2250738585072014e-308 + assert f.dtype == np.dtype(np.float64) + + +def test_iinfo(): + f = ragged.iinfo(np.int16) + assert f.bits == 16 + assert f.max == 32767 + assert f.min == -32768 + assert f.dtype == np.dtype(np.int16) + + +def test_iinfo_array(): + f = ragged.iinfo(np.array([1, 2, 3], np.int16)) + assert f.bits == 16 + assert f.max == 32767 + assert f.min == -32768 + assert f.dtype == np.dtype(np.int16) From feffcb2fe9cbbf7c834996c7ac1cd4b4e0d114e6 Mon Sep 17 00:00:00 2001 From: Jim Pivarski Date: Tue, 9 Jan 2024 15:16:44 -0600 Subject: [PATCH 06/24] result_type --- src/ragged/_spec_data_type_functions.py | 9 +++++---- tests/test_spec_data_type_functions.py | 23 +++++++++++++++++------ 2 files changed, 22 insertions(+), 10 deletions(-) diff --git a/src/ragged/_spec_data_type_functions.py b/src/ragged/_spec_data_type_functions.py index 2c12edf..ac67da4 100644 --- a/src/ragged/_spec_data_type_functions.py +++ b/src/ragged/_spec_data_type_functions.py @@ -13,6 +13,8 @@ from ._spec_array_object import array from ._typing import Dtype +_type = type + def astype(x: array, dtype: Dtype, /, *, copy: bool = True) -> array: """ @@ -113,7 +115,7 @@ def finfo(type: Dtype | array, /) -> finfo_object: # pylint: disable=W0622 """ if not isinstance(type, np.dtype): - if not isinstance(type, __builtins__["type"]) and hasattr(type, "dtype"): # type: ignore[index] + if not isinstance(type, _type) and hasattr(type, "dtype"): # type: ignore[index] out = np.finfo(type.dtype) else: out = np.finfo(np.dtype(type)) @@ -162,7 +164,7 @@ def iinfo(type: Dtype | array, /) -> iinfo_object: # pylint: disable=W0622 """ if not isinstance(type, np.dtype): - if not isinstance(type, __builtins__["type"]) and hasattr(type, "dtype"): # type: ignore[index] + if not isinstance(type, _type) and hasattr(type, "dtype"): # type: ignore[index] out = np.iinfo(type.dtype) else: out = np.iinfo(np.dtype(type)) @@ -230,5 +232,4 @@ def result_type(*arrays_and_dtypes: array | Dtype) -> Dtype: https://data-apis.org/array-api/latest/API_specification/generated/array_api.result_type.html """ - arrays_and_dtypes # noqa: B018, pylint: disable=W0104 - raise NotImplementedError("TODO 55") # noqa: EM101 + return np.result_type(*arrays_and_dtypes) diff --git a/tests/test_spec_data_type_functions.py b/tests/test_spec_data_type_functions.py index 3e0e785..526181f 100644 --- a/tests/test_spec_data_type_functions.py +++ b/tests/test_spec_data_type_functions.py @@ -36,12 +36,14 @@ def test_finfo(): def test_finfo_array(): + f = ragged.finfo(np.array([1.1, 2.2, 3.3])) + assert f.bits == 64 + assert f.dtype == np.dtype(np.float64) + + +def test_finfo_array2(): f = ragged.finfo(ragged.array([1.1, 2.2, 3.3])) assert f.bits == 64 - assert f.eps == 2.220446049250313e-16 - assert f.max == 1.7976931348623157e308 - assert f.min == -1.7976931348623157e308 - assert f.smallest_normal == 2.2250738585072014e-308 assert f.dtype == np.dtype(np.float64) @@ -56,6 +58,15 @@ def test_iinfo(): def test_iinfo_array(): f = ragged.iinfo(np.array([1, 2, 3], np.int16)) assert f.bits == 16 - assert f.max == 32767 - assert f.min == -32768 assert f.dtype == np.dtype(np.int16) + + +def test_iinfo_array2(): + f = ragged.iinfo(ragged.array([1, 2, 3], np.int16)) + assert f.bits == 16 + assert f.dtype == np.dtype(np.int16) + + +def test_result_type(): + dt = ragged.result_type(ragged.array([1, 2, 3]), ragged.array([1.1, 2.2, 3.3])) + assert dt == np.dtype(np.float64) From 5b2d4a9cb0b8cd253eb5bb83b942375be265c10e Mon Sep 17 00:00:00 2001 From: Jim Pivarski Date: Tue, 9 Jan 2024 15:23:28 -0600 Subject: [PATCH 07/24] unnecessary mypy ignore --- src/ragged/_spec_data_type_functions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/ragged/_spec_data_type_functions.py b/src/ragged/_spec_data_type_functions.py index ac67da4..4106ef8 100644 --- a/src/ragged/_spec_data_type_functions.py +++ b/src/ragged/_spec_data_type_functions.py @@ -115,7 +115,7 @@ def finfo(type: Dtype | array, /) -> finfo_object: # pylint: disable=W0622 """ if not isinstance(type, np.dtype): - if not isinstance(type, _type) and hasattr(type, "dtype"): # type: ignore[index] + if not isinstance(type, _type) and hasattr(type, "dtype"): out = np.finfo(type.dtype) else: out = np.finfo(np.dtype(type)) @@ -164,7 +164,7 @@ def iinfo(type: Dtype | array, /) -> iinfo_object: # pylint: disable=W0622 """ if not isinstance(type, np.dtype): - if not isinstance(type, _type) and hasattr(type, "dtype"): # type: ignore[index] + if not isinstance(type, _type) and hasattr(type, "dtype"): out = np.iinfo(type.dtype) else: out = np.iinfo(np.dtype(type)) From 89ca7e2b5445daab79b0f9dfb9aa62a53ab3cd2c Mon Sep 17 00:00:00 2001 From: Jim Pivarski Date: Tue, 9 Jan 2024 15:23:47 -0600 Subject: [PATCH 08/24] test the minimal NumPy version --- .github/workflows/ci.yml | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 1fc948f..16a8430 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -41,10 +41,17 @@ jobs: fail-fast: false matrix: python-version: ["3.9", "3.12"] + numpy-version: ["latest"] runs-on: [ubuntu-latest, macos-latest, windows-latest] include: - python-version: pypy-3.10 + numpy-version: latest + runs-on: ubuntu-latest + + include: + - python-version: 3.9 + numpy-version: 1.18.0 runs-on: ubuntu-latest steps: @@ -57,9 +64,16 @@ jobs: python-version: ${{ matrix.python-version }} allow-prereleases: true + - name: Install old NumPy + if: matrix.numpy-version != 'latest' + run: python -m pip install numpy==${{ matrix.numpy-version }} + - name: Install package run: python -m pip install .[test] + - name: Print NumPy version + run: python -c 'import numpy as np; print(np.__version__)' + - name: Test package run: >- python -m pytest -ra --cov --cov-report=xml --cov-report=term From aeac702b612b6c6355356c433194765e95f23a3d Mon Sep 17 00:00:00 2001 From: Jim Pivarski Date: Tue, 9 Jan 2024 15:29:17 -0600 Subject: [PATCH 09/24] fix YAML --- .github/workflows/ci.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 16a8430..cb7e2c4 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -45,13 +45,13 @@ jobs: runs-on: [ubuntu-latest, macos-latest, windows-latest] include: - - python-version: pypy-3.10 - numpy-version: latest + - python-version: "pypy-3.10" + numpy-version: "latest" runs-on: ubuntu-latest include: - - python-version: 3.9 - numpy-version: 1.18.0 + - python-version: "3.9" + numpy-version: "1.18.0" runs-on: ubuntu-latest steps: From efb5873b46130e63db399c97ccffc1d3149839c0 Mon Sep 17 00:00:00 2001 From: Jim Pivarski Date: Tue, 9 Jan 2024 15:31:13 -0600 Subject: [PATCH 10/24] fix YAML 2 --- .github/workflows/ci.yml | 2 -- 1 file changed, 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index cb7e2c4..ee15053 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -48,8 +48,6 @@ jobs: - python-version: "pypy-3.10" numpy-version: "latest" runs-on: ubuntu-latest - - include: - python-version: "3.9" numpy-version: "1.18.0" runs-on: ubuntu-latest From 91e51dcda4a15ccc8f39ec78aa29a5e2d140e3cc Mon Sep 17 00:00:00 2001 From: Jim Pivarski Date: Tue, 9 Jan 2024 15:35:52 -0600 Subject: [PATCH 11/24] fix YAML names and PIP_ONLY_BINARY --- .github/workflows/ci.yml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ee15053..2768d52 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -34,7 +34,7 @@ jobs: pipx run nox -s pylint checks: - name: Check Python ${{ matrix.python-version }} on ${{ matrix.runs-on }} + name: py:${{ matrix.python-version }} np:${{ matrix.numpy-version }} os:${{ matrix.runs-on }} runs-on: ${{ matrix.runs-on }} needs: [pre-commit] strategy: @@ -52,6 +52,9 @@ jobs: numpy-version: "1.18.0" runs-on: ubuntu-latest + env: + PIP_ONLY_BINARY: numpy + steps: - uses: actions/checkout@v4 with: From 47753f276eeffc978e436120d5630a0ed90391f8 Mon Sep 17 00:00:00 2001 From: Jim Pivarski Date: Tue, 9 Jan 2024 15:36:49 -0600 Subject: [PATCH 12/24] fix YAML 3 --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2768d52..21b197b 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -34,7 +34,7 @@ jobs: pipx run nox -s pylint checks: - name: py:${{ matrix.python-version }} np:${{ matrix.numpy-version }} os:${{ matrix.runs-on }} + name: "py:${{ matrix.python-version }} np:${{ matrix.numpy-version }} os:${{ matrix.runs-on }}" runs-on: ${{ matrix.runs-on }} needs: [pre-commit] strategy: From d132098898c83c04f035819031eab7fcde72b8f2 Mon Sep 17 00:00:00 2001 From: Jim Pivarski Date: Tue, 9 Jan 2024 15:38:09 -0600 Subject: [PATCH 13/24] fix YAML 4 --- .github/workflows/ci.yml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 21b197b..8e5bc0c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -34,7 +34,9 @@ jobs: pipx run nox -s pylint checks: - name: "py:${{ matrix.python-version }} np:${{ matrix.numpy-version }} os:${{ matrix.runs-on }}" + name: + "py:${{ matrix.python-version }} np:${{ matrix.numpy-version }} os:${{ + matrix.runs-on }}" runs-on: ${{ matrix.runs-on }} needs: [pre-commit] strategy: From 4aac97c4f322886e580193a875953bb60abd8f5d Mon Sep 17 00:00:00 2001 From: Jim Pivarski Date: Tue, 9 Jan 2024 15:42:24 -0600 Subject: [PATCH 14/24] Old Ubuntu image should have old NumPy, I hope. --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 8e5bc0c..6bbed3c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -52,7 +52,7 @@ jobs: runs-on: ubuntu-latest - python-version: "3.9" numpy-version: "1.18.0" - runs-on: ubuntu-latest + runs-on: ubuntu-20.04 env: PIP_ONLY_BINARY: numpy From efe63ac3f1b49392b61f8b46dfd59e57df4b8cb3 Mon Sep 17 00:00:00 2001 From: Jim Pivarski Date: Tue, 9 Jan 2024 15:51:20 -0600 Subject: [PATCH 15/24] Take a lower minimum Python (3.8) to get NumPy 1.18.0. --- .github/workflows/ci.yml | 4 ++-- pyproject.toml | 7 ++++--- src/ragged/_spec_array_object.py | 6 +++--- src/ragged/_typing.py | 4 ++-- 4 files changed, 11 insertions(+), 10 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 6bbed3c..11ac542 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -42,7 +42,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.9", "3.12"] + python-version: ["3.8", "3.12"] numpy-version: ["latest"] runs-on: [ubuntu-latest, macos-latest, windows-latest] @@ -50,7 +50,7 @@ jobs: - python-version: "pypy-3.10" numpy-version: "latest" runs-on: ubuntu-latest - - python-version: "3.9" + - python-version: "3.8" numpy-version: "1.18.0" runs-on: ubuntu-20.04 diff --git a/pyproject.toml b/pyproject.toml index 1470efa..ee7f0d2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ authors = [ description = "Ragged array library, complying with Python API specification." readme = "README.md" license.file = "LICENSE" -requires-python = ">=3.9" +requires-python = ">=3.8" classifiers = [ "Development Status :: 1 - Planning", "Intended Audience :: Science/Research", @@ -21,6 +21,7 @@ classifiers = [ "Programming Language :: Python", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", @@ -88,7 +89,7 @@ report.exclude_also = [ [tool.mypy] files = ["src", "tests"] -python_version = "3.9" +python_version = "3.8" warn_unused_configs = true strict = true enable_error_code = ["ignore-without-code", "redundant-expr", "truthy-bool"] @@ -157,7 +158,7 @@ isort.required-imports = ["from __future__ import annotations"] [tool.pylint] -py-version = "3.9" +py-version = "3.8" ignore-paths = [".*/_version.py"] reports.output-format = "colorized" similarities.ignore-imports = "yes" diff --git a/src/ragged/_spec_array_object.py b/src/ragged/_spec_array_object.py index b300fc5..b7f5894 100644 --- a/src/ragged/_spec_array_object.py +++ b/src/ragged/_spec_array_object.py @@ -8,7 +8,7 @@ import enum import numbers -from typing import TYPE_CHECKING, Any, Union +from typing import TYPE_CHECKING, Any, Tuple, Union import awkward as ak import numpy as np @@ -68,12 +68,12 @@ class ellipsis(Enum): # pylint: disable=C0103 slice, ellipsis, None, - tuple[Union[int, slice, ellipsis, None], ...], + Tuple[Union[int, slice, ellipsis, None], ...], "array", ] SetSliceKey = Union[ - int, slice, ellipsis, tuple[Union[int, slice, ellipsis], ...], "array" + int, slice, ellipsis, Tuple[Union[int, slice, ellipsis], ...], "array" ] diff --git a/src/ragged/_typing.py b/src/ragged/_typing.py index 0b2c675..49cce30 100644 --- a/src/ragged/_typing.py +++ b/src/ragged/_typing.py @@ -8,7 +8,7 @@ import enum import sys -from typing import Any, Literal, Optional, Protocol, TypeVar, Union +from typing import Any, Literal, Optional, Protocol, Tuple, TypeVar, Union import numpy as np @@ -43,7 +43,7 @@ def __dlpack_device__(self, /) -> tuple[enum.Enum, int]: ... -Shape = tuple[Optional[int], ...] +Shape = Tuple[Optional[int], ...] Dtype = np.dtype[ Union[ From 1fad6020178fa2c1d33af8a32af4a2fe505ffa5a Mon Sep 17 00:00:00 2001 From: Jim Pivarski Date: Tue, 9 Jan 2024 15:54:22 -0600 Subject: [PATCH 16/24] Don't subscript np.dtype. --- src/ragged/_typing.py | 20 ++------------------ 1 file changed, 2 insertions(+), 18 deletions(-) diff --git a/src/ragged/_typing.py b/src/ragged/_typing.py index 49cce30..7bd97d9 100644 --- a/src/ragged/_typing.py +++ b/src/ragged/_typing.py @@ -8,7 +8,7 @@ import enum import sys -from typing import Any, Literal, Optional, Protocol, Tuple, TypeVar, Union +from typing import Any, Literal, Optional, Protocol, Tuple, TypeVar import numpy as np @@ -45,23 +45,7 @@ def __dlpack_device__(self, /) -> tuple[enum.Enum, int]: Shape = Tuple[Optional[int], ...] -Dtype = np.dtype[ - Union[ - np.bool_, - np.int8, - np.int16, - np.int32, - np.int64, - np.uint8, - np.uint16, - np.uint32, - np.uint64, - np.float32, - np.float64, - np.complex64, - np.complex128, - ] -] +Dtype = np.dtype numeric_types = ( np.bool_, From 5c21ae470ebc6b3f54305d19091c5ac27fd3781e Mon Sep 17 00:00:00 2001 From: Jim Pivarski Date: Tue, 9 Jan 2024 15:59:33 -0600 Subject: [PATCH 17/24] Revert "Don't subscript np.dtype." This reverts commit 1fad6020178fa2c1d33af8a32af4a2fe505ffa5a. --- src/ragged/_typing.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/src/ragged/_typing.py b/src/ragged/_typing.py index 7bd97d9..49cce30 100644 --- a/src/ragged/_typing.py +++ b/src/ragged/_typing.py @@ -8,7 +8,7 @@ import enum import sys -from typing import Any, Literal, Optional, Protocol, Tuple, TypeVar +from typing import Any, Literal, Optional, Protocol, Tuple, TypeVar, Union import numpy as np @@ -45,7 +45,23 @@ def __dlpack_device__(self, /) -> tuple[enum.Enum, int]: Shape = Tuple[Optional[int], ...] -Dtype = np.dtype +Dtype = np.dtype[ + Union[ + np.bool_, + np.int8, + np.int16, + np.int32, + np.int64, + np.uint8, + np.uint16, + np.uint32, + np.uint64, + np.float32, + np.float64, + np.complex64, + np.complex128, + ] +] numeric_types = ( np.bool_, From ed915b8c291f46d97618568e21068edffca6de04 Mon Sep 17 00:00:00 2001 From: Jim Pivarski Date: Tue, 9 Jan 2024 15:59:45 -0600 Subject: [PATCH 18/24] Revert "Take a lower minimum Python (3.8) to get NumPy 1.18.0." This reverts commit efe63ac3f1b49392b61f8b46dfd59e57df4b8cb3. --- .github/workflows/ci.yml | 4 ++-- pyproject.toml | 7 +++---- src/ragged/_spec_array_object.py | 6 +++--- src/ragged/_typing.py | 4 ++-- 4 files changed, 10 insertions(+), 11 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 11ac542..6bbed3c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -42,7 +42,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.8", "3.12"] + python-version: ["3.9", "3.12"] numpy-version: ["latest"] runs-on: [ubuntu-latest, macos-latest, windows-latest] @@ -50,7 +50,7 @@ jobs: - python-version: "pypy-3.10" numpy-version: "latest" runs-on: ubuntu-latest - - python-version: "3.8" + - python-version: "3.9" numpy-version: "1.18.0" runs-on: ubuntu-20.04 diff --git a/pyproject.toml b/pyproject.toml index ee7f0d2..1470efa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ authors = [ description = "Ragged array library, complying with Python API specification." readme = "README.md" license.file = "LICENSE" -requires-python = ">=3.8" +requires-python = ">=3.9" classifiers = [ "Development Status :: 1 - Planning", "Intended Audience :: Science/Research", @@ -21,7 +21,6 @@ classifiers = [ "Programming Language :: Python", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3 :: Only", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", @@ -89,7 +88,7 @@ report.exclude_also = [ [tool.mypy] files = ["src", "tests"] -python_version = "3.8" +python_version = "3.9" warn_unused_configs = true strict = true enable_error_code = ["ignore-without-code", "redundant-expr", "truthy-bool"] @@ -158,7 +157,7 @@ isort.required-imports = ["from __future__ import annotations"] [tool.pylint] -py-version = "3.8" +py-version = "3.9" ignore-paths = [".*/_version.py"] reports.output-format = "colorized" similarities.ignore-imports = "yes" diff --git a/src/ragged/_spec_array_object.py b/src/ragged/_spec_array_object.py index b7f5894..b300fc5 100644 --- a/src/ragged/_spec_array_object.py +++ b/src/ragged/_spec_array_object.py @@ -8,7 +8,7 @@ import enum import numbers -from typing import TYPE_CHECKING, Any, Tuple, Union +from typing import TYPE_CHECKING, Any, Union import awkward as ak import numpy as np @@ -68,12 +68,12 @@ class ellipsis(Enum): # pylint: disable=C0103 slice, ellipsis, None, - Tuple[Union[int, slice, ellipsis, None], ...], + tuple[Union[int, slice, ellipsis, None], ...], "array", ] SetSliceKey = Union[ - int, slice, ellipsis, Tuple[Union[int, slice, ellipsis], ...], "array" + int, slice, ellipsis, tuple[Union[int, slice, ellipsis], ...], "array" ] diff --git a/src/ragged/_typing.py b/src/ragged/_typing.py index 49cce30..0b2c675 100644 --- a/src/ragged/_typing.py +++ b/src/ragged/_typing.py @@ -8,7 +8,7 @@ import enum import sys -from typing import Any, Literal, Optional, Protocol, Tuple, TypeVar, Union +from typing import Any, Literal, Optional, Protocol, TypeVar, Union import numpy as np @@ -43,7 +43,7 @@ def __dlpack_device__(self, /) -> tuple[enum.Enum, int]: ... -Shape = Tuple[Optional[int], ...] +Shape = tuple[Optional[int], ...] Dtype = np.dtype[ Union[ From 5b15de8559ac7f17e5e7758224e42c04d569a5fb Mon Sep 17 00:00:00 2001 From: Jim Pivarski Date: Tue, 9 Jan 2024 16:02:51 -0600 Subject: [PATCH 19/24] Nope. Instead, increase the minimum NumPy to 1.19.3. --- .github/workflows/ci.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 6bbed3c..4bc24e8 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -51,8 +51,8 @@ jobs: numpy-version: "latest" runs-on: ubuntu-latest - python-version: "3.9" - numpy-version: "1.18.0" - runs-on: ubuntu-20.04 + numpy-version: "1.19.3" + runs-on: ubuntu-latest env: PIP_ONLY_BINARY: numpy From 42f7fd7b98e6049202b87576d243c83496e5f69d Mon Sep 17 00:00:00 2001 From: Jim Pivarski Date: Tue, 9 Jan 2024 16:04:41 -0600 Subject: [PATCH 20/24] Don't subscript np.dtype. --- src/ragged/_typing.py | 18 +----------------- 1 file changed, 1 insertion(+), 17 deletions(-) diff --git a/src/ragged/_typing.py b/src/ragged/_typing.py index 0b2c675..c40eb4b 100644 --- a/src/ragged/_typing.py +++ b/src/ragged/_typing.py @@ -45,23 +45,7 @@ def __dlpack_device__(self, /) -> tuple[enum.Enum, int]: Shape = tuple[Optional[int], ...] -Dtype = np.dtype[ - Union[ - np.bool_, - np.int8, - np.int16, - np.int32, - np.int64, - np.uint8, - np.uint16, - np.uint32, - np.uint64, - np.float32, - np.float64, - np.complex64, - np.complex128, - ] -] +Dtype = np.dtype numeric_types = ( np.bool_, From 32786ae250035914e1ef056b65443c19a7093804 Mon Sep 17 00:00:00 2001 From: Jim Pivarski Date: Tue, 9 Jan 2024 16:05:21 -0600 Subject: [PATCH 21/24] pre-commit --- src/ragged/_typing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ragged/_typing.py b/src/ragged/_typing.py index c40eb4b..eb245c1 100644 --- a/src/ragged/_typing.py +++ b/src/ragged/_typing.py @@ -8,7 +8,7 @@ import enum import sys -from typing import Any, Literal, Optional, Protocol, TypeVar, Union +from typing import Any, Literal, Optional, Protocol, TypeVar import numpy as np From 8c92f30d9c6b5ab5650149a9cd7b53b4821c1904 Mon Sep 17 00:00:00 2001 From: Jim Pivarski Date: Tue, 9 Jan 2024 16:19:23 -0600 Subject: [PATCH 22/24] Create a fake numpy.array_api for tests of np 1.19.3. --- .github/workflows/ci.yml | 3 - tests/test_spec_elementwise_functions.py | 75 +++++++++++++++++++++++- 2 files changed, 72 insertions(+), 6 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 4bc24e8..9178c3c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -54,9 +54,6 @@ jobs: numpy-version: "1.19.3" runs-on: ubuntu-latest - env: - PIP_ONLY_BINARY: numpy - steps: - uses: actions/checkout@v4 with: diff --git a/tests/test_spec_elementwise_functions.py b/tests/test_spec_elementwise_functions.py index 089443e..d8bde1b 100644 --- a/tests/test_spec_elementwise_functions.py +++ b/tests/test_spec_elementwise_functions.py @@ -6,6 +6,7 @@ from __future__ import annotations +import types import warnings from typing import Any @@ -14,7 +15,73 @@ with warnings.catch_warnings(): warnings.simplefilter("ignore") - import numpy.array_api as xp + try: + import numpy.array_api as xp + + real_xp = True + except ModuleNotFoundError: + real_xp = False + xp = types.ModuleType("array_api") + xp.asarray = np.asarray + xp.abs = np.abs + xp.acos = np.arccos + xp.acosh = np.arccosh + xp.add = np.add + xp.asin = np.arcsin + xp.asinh = np.arcsinh + xp.atan = np.arctan + xp.atan2 = np.arctan2 + xp.atanh = np.arctanh + xp.bitwise_and = np.bitwise_and + xp.bitwise_invert = np.invert + xp.bitwise_left_shift = np.left_shift + xp.bitwise_or = np.bitwise_or + xp.bitwise_right_shift = np.right_shift + xp.bitwise_xor = np.bitwise_xor + xp.ceil = np.ceil + xp.conj = np.conj + xp.cos = np.cos + xp.cosh = np.cosh + xp.divide = np.divide + xp.equal = np.equal + xp.exp = np.exp + xp.expm1 = np.expm1 + xp.floor = np.floor + xp.floor_divide = np.floor_divide + xp.greater = np.greater + xp.greater_equal = np.greater_equal + xp.imag = np.imag + xp.isfinite = np.isfinite + xp.isinf = np.isinf + xp.isnan = np.isnan + xp.less = np.less + xp.less_equal = np.less_equal + xp.log = np.log + xp.log1p = np.log1p + xp.log2 = np.log2 + xp.log10 = np.log10 + xp.logaddexp = np.logaddexp + xp.logical_and = np.logical_and + xp.logical_not = np.logical_not + xp.logical_or = np.logical_or + xp.logical_xor = np.logical_xor + xp.multiply = np.multiply + xp.negative = np.negative + xp.not_equal = np.not_equal + xp.positive = np.positive + xp.pow = np.power + xp.real = np.real + xp.remainder = np.remainder + xp.round = np.round + xp.sign = np.sign + xp.sin = np.sin + xp.sinh = np.sinh + xp.square = np.square + xp.sqrt = np.sqrt + xp.subtract = np.subtract + xp.tan = np.tan + xp.tanh = np.tanh + xp.trunc = np.trunc import pytest @@ -437,7 +504,8 @@ def test_ceil_int(device, x_int): assert type(result) is type(x_int) assert result.shape == x_int.shape assert xp.ceil(first(x_int)) == first(result) - assert xp.ceil(first(x_int)).dtype == result.dtype + if real_xp: + assert xp.ceil(first(x_int)).dtype == result.dtype @pytest.mark.parametrize("device", devices) @@ -546,7 +614,8 @@ def test_floor_int(device, x_int): assert type(result) is type(x_int) assert result.shape == x_int.shape assert xp.floor(first(x_int)) == first(result) - assert xp.floor(first(x_int)).dtype == result.dtype + if real_xp: + assert xp.floor(first(x_int)).dtype == result.dtype @pytest.mark.parametrize("device", devices) From 57f77a7e5a1447721b41deaa239ba58002033d75 Mon Sep 17 00:00:00 2001 From: Jim Pivarski Date: Tue, 9 Jan 2024 16:33:46 -0600 Subject: [PATCH 23/24] NumPy started experimental support for array_api with 1.22.0. --- .github/workflows/ci.yml | 2 +- src/ragged/_typing.py | 20 ++++++- tests/test_spec_elementwise_functions.py | 75 +----------------------- 3 files changed, 22 insertions(+), 75 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 9178c3c..2636bcd 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -51,7 +51,7 @@ jobs: numpy-version: "latest" runs-on: ubuntu-latest - python-version: "3.9" - numpy-version: "1.19.3" + numpy-version: "1.22.0" runs-on: ubuntu-latest steps: diff --git a/src/ragged/_typing.py b/src/ragged/_typing.py index eb245c1..0b2c675 100644 --- a/src/ragged/_typing.py +++ b/src/ragged/_typing.py @@ -8,7 +8,7 @@ import enum import sys -from typing import Any, Literal, Optional, Protocol, TypeVar +from typing import Any, Literal, Optional, Protocol, TypeVar, Union import numpy as np @@ -45,7 +45,23 @@ def __dlpack_device__(self, /) -> tuple[enum.Enum, int]: Shape = tuple[Optional[int], ...] -Dtype = np.dtype +Dtype = np.dtype[ + Union[ + np.bool_, + np.int8, + np.int16, + np.int32, + np.int64, + np.uint8, + np.uint16, + np.uint32, + np.uint64, + np.float32, + np.float64, + np.complex64, + np.complex128, + ] +] numeric_types = ( np.bool_, diff --git a/tests/test_spec_elementwise_functions.py b/tests/test_spec_elementwise_functions.py index d8bde1b..089443e 100644 --- a/tests/test_spec_elementwise_functions.py +++ b/tests/test_spec_elementwise_functions.py @@ -6,7 +6,6 @@ from __future__ import annotations -import types import warnings from typing import Any @@ -15,73 +14,7 @@ with warnings.catch_warnings(): warnings.simplefilter("ignore") - try: - import numpy.array_api as xp - - real_xp = True - except ModuleNotFoundError: - real_xp = False - xp = types.ModuleType("array_api") - xp.asarray = np.asarray - xp.abs = np.abs - xp.acos = np.arccos - xp.acosh = np.arccosh - xp.add = np.add - xp.asin = np.arcsin - xp.asinh = np.arcsinh - xp.atan = np.arctan - xp.atan2 = np.arctan2 - xp.atanh = np.arctanh - xp.bitwise_and = np.bitwise_and - xp.bitwise_invert = np.invert - xp.bitwise_left_shift = np.left_shift - xp.bitwise_or = np.bitwise_or - xp.bitwise_right_shift = np.right_shift - xp.bitwise_xor = np.bitwise_xor - xp.ceil = np.ceil - xp.conj = np.conj - xp.cos = np.cos - xp.cosh = np.cosh - xp.divide = np.divide - xp.equal = np.equal - xp.exp = np.exp - xp.expm1 = np.expm1 - xp.floor = np.floor - xp.floor_divide = np.floor_divide - xp.greater = np.greater - xp.greater_equal = np.greater_equal - xp.imag = np.imag - xp.isfinite = np.isfinite - xp.isinf = np.isinf - xp.isnan = np.isnan - xp.less = np.less - xp.less_equal = np.less_equal - xp.log = np.log - xp.log1p = np.log1p - xp.log2 = np.log2 - xp.log10 = np.log10 - xp.logaddexp = np.logaddexp - xp.logical_and = np.logical_and - xp.logical_not = np.logical_not - xp.logical_or = np.logical_or - xp.logical_xor = np.logical_xor - xp.multiply = np.multiply - xp.negative = np.negative - xp.not_equal = np.not_equal - xp.positive = np.positive - xp.pow = np.power - xp.real = np.real - xp.remainder = np.remainder - xp.round = np.round - xp.sign = np.sign - xp.sin = np.sin - xp.sinh = np.sinh - xp.square = np.square - xp.sqrt = np.sqrt - xp.subtract = np.subtract - xp.tan = np.tan - xp.tanh = np.tanh - xp.trunc = np.trunc + import numpy.array_api as xp import pytest @@ -504,8 +437,7 @@ def test_ceil_int(device, x_int): assert type(result) is type(x_int) assert result.shape == x_int.shape assert xp.ceil(first(x_int)) == first(result) - if real_xp: - assert xp.ceil(first(x_int)).dtype == result.dtype + assert xp.ceil(first(x_int)).dtype == result.dtype @pytest.mark.parametrize("device", devices) @@ -614,8 +546,7 @@ def test_floor_int(device, x_int): assert type(result) is type(x_int) assert result.shape == x_int.shape assert xp.floor(first(x_int)) == first(result) - if real_xp: - assert xp.floor(first(x_int)).dtype == result.dtype + assert xp.floor(first(x_int)).dtype == result.dtype @pytest.mark.parametrize("device", devices) From 9751fcf19cd70077890b23a79101aec7cd06dfa2 Mon Sep 17 00:00:00 2001 From: Jim Pivarski Date: Tue, 9 Jan 2024 16:57:05 -0600 Subject: [PATCH 24/24] Original NumPy Array API support did not include complex. --- tests/test_spec_creation_functions.py | 3 +++ tests/test_spec_elementwise_functions.py | 16 ++++++++++++++++ 2 files changed, 19 insertions(+) diff --git a/tests/test_spec_creation_functions.py b/tests/test_spec_creation_functions.py index 7b83798..a1c08b6 100644 --- a/tests/test_spec_creation_functions.py +++ b/tests/test_spec_creation_functions.py @@ -70,6 +70,9 @@ def test_eye(device): assert isinstance(a._impl.layout.data, ns[device].ndarray) # type: ignore[union-attr] +@pytest.mark.skipif( + not hasattr(np, "from_dlpack"), reason=f"np.from_dlpack not in {np.__version__}" +) @pytest.mark.parametrize("device", devices) def test_from_dlpack(device): a = ns[device].array([1, 2, 3, 4, 5]) diff --git a/tests/test_spec_elementwise_functions.py b/tests/test_spec_elementwise_functions.py index 089443e..3838778 100644 --- a/tests/test_spec_elementwise_functions.py +++ b/tests/test_spec_elementwise_functions.py @@ -440,6 +440,10 @@ def test_ceil_int(device, x_int): assert xp.ceil(first(x_int)).dtype == result.dtype +@pytest.mark.skipif( + np.dtype("complex128") not in xp._dtypes._all_dtypes, + reason=f"complex not allowed in np.array_api version {np.__version__}", +) @pytest.mark.parametrize("device", devices) def test_conj(device, x_complex): result = ragged.conj(x_complex.to_device(device)) @@ -623,6 +627,10 @@ def test_greater_equal_method(device, x, y): assert xp.greater_equal(first(x), first(y)).dtype == result.dtype +@pytest.mark.skipif( + np.dtype("complex128") not in xp._dtypes._all_dtypes, + reason=f"complex not allowed in np.array_api version {np.__version__}", +) @pytest.mark.parametrize("device", devices) def test_imag(device, x_complex): result = ragged.imag(x_complex.to_device(device)) @@ -886,6 +894,10 @@ def test_pow_inplace_method(device, x, y): assert x.dtype == z.dtype +@pytest.mark.skipif( + np.dtype("complex128") not in xp._dtypes._all_dtypes, + reason=f"complex not allowed in np.array_api version {np.__version__}", +) @pytest.mark.parametrize("device", devices) def test_real(device, x_complex): result = ragged.real(x_complex.to_device(device)) @@ -932,6 +944,10 @@ def test_round(device, x): assert xp.round(first(x)).dtype == result.dtype +@pytest.mark.skipif( + np.dtype("complex128") not in xp._dtypes._all_dtypes, + reason=f"complex not allowed in np.array_api version {np.__version__}", +) @pytest.mark.parametrize("device", devices) def test_round_complex(device, x_complex): result = ragged.round(x_complex.to_device(device))