Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add all direct from NumPy/CuPy functions #10

Merged
merged 24 commits into from
Jan 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
f40105c
feat: add all direct from NumPy/CuPy functions
jpivarski Jan 9, 2024
31a37a3
empty, eye, from_dlpack
jpivarski Jan 9, 2024
b058476
full, linspace, ones, zeros
jpivarski Jan 9, 2024
bf6b112
Also test cases that create scalars.
jpivarski Jan 9, 2024
ed0e084
can_cast, finfo, iinfo
jpivarski Jan 9, 2024
feffcb2
result_type
jpivarski Jan 9, 2024
5b2d4a9
unnecessary mypy ignore
jpivarski Jan 9, 2024
89ca7e2
test the minimal NumPy version
jpivarski Jan 9, 2024
aeac702
fix YAML
jpivarski Jan 9, 2024
efb5873
fix YAML 2
jpivarski Jan 9, 2024
91e51dc
fix YAML names and PIP_ONLY_BINARY
jpivarski Jan 9, 2024
47753f2
fix YAML 3
jpivarski Jan 9, 2024
d132098
fix YAML 4
jpivarski Jan 9, 2024
4aac97c
Old Ubuntu image should have old NumPy, I hope.
jpivarski Jan 9, 2024
efe63ac
Take a lower minimum Python (3.8) to get NumPy 1.18.0.
jpivarski Jan 9, 2024
1fad602
Don't subscript np.dtype.
jpivarski Jan 9, 2024
5c21ae4
Revert "Don't subscript np.dtype."
jpivarski Jan 9, 2024
ed915b8
Revert "Take a lower minimum Python (3.8) to get NumPy 1.18.0."
jpivarski Jan 9, 2024
5b15de8
Nope. Instead, increase the minimum NumPy to 1.19.3.
jpivarski Jan 9, 2024
42f7fd7
Don't subscript np.dtype.
jpivarski Jan 9, 2024
32786ae
pre-commit
jpivarski Jan 9, 2024
8c92f30
Create a fake numpy.array_api for tests of np 1.19.3.
jpivarski Jan 9, 2024
57f77a7
NumPy started experimental support for array_api with 1.22.0.
jpivarski Jan 9, 2024
9751fcf
Original NumPy Array API support did not include complex.
jpivarski Jan 9, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,17 +34,24 @@ jobs:
pipx run nox -s pylint

checks:
name: Check Python ${{ matrix.python-version }} on ${{ matrix.runs-on }}
name:
"py:${{ matrix.python-version }} np:${{ matrix.numpy-version }} os:${{
matrix.runs-on }}"
runs-on: ${{ matrix.runs-on }}
needs: [pre-commit]
strategy:
fail-fast: false
matrix:
python-version: ["3.9", "3.12"]
numpy-version: ["latest"]
runs-on: [ubuntu-latest, macos-latest, windows-latest]

include:
- python-version: pypy-3.10
- python-version: "pypy-3.10"
numpy-version: "latest"
runs-on: ubuntu-latest
- python-version: "3.9"
numpy-version: "1.22.0"
runs-on: ubuntu-latest

steps:
Expand All @@ -57,9 +64,16 @@ jobs:
python-version: ${{ matrix.python-version }}
allow-prereleases: true

- name: Install old NumPy
if: matrix.numpy-version != 'latest'
run: python -m pip install numpy==${{ matrix.numpy-version }}

- name: Install package
run: python -m pip install .[test]

- name: Print NumPy version
run: python -c 'import numpy as np; print(np.__version__)'

- name: Test package
run: >-
python -m pytest -ra --cov --cov-report=xml --cov-report=term
Expand Down
14 changes: 14 additions & 0 deletions src/ragged/_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,20 @@

from typing import Any

import numpy as np

from ._typing import Device


def device_namespace(device: None | Device = None) -> tuple[Device, Any]:
if device is None or device == "cpu":
return "cpu", np
elif device == "cuda":
return "cuda", cupy()

msg = f"unrecognized device: {device!r}" # type: ignore[unreachable]
raise ValueError(msg)


def cupy() -> Any:
try:
Expand Down
2 changes: 2 additions & 0 deletions src/ragged/_spec_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -1144,5 +1144,7 @@ def _box(
else:
dtype = dtype_observed
device = "cpu" if isinstance(output, np.ndarray) else "cuda"
if shape != ():
impl = ak.Array(impl)

return cls._new(impl, shape, dtype, device) # pylint: disable=W0212
76 changes: 37 additions & 39 deletions src/ragged/_spec_creation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,14 @@

from __future__ import annotations

import enum

import awkward as ak
import numpy as np

from ._spec_array_object import array
from . import _import
from ._import import device_namespace
from ._spec_array_object import _box, array
from ._typing import (
Device,
Dtype,
Expand Down Expand Up @@ -53,12 +58,8 @@ def arange(
https://data-apis.org/array-api/latest/API_specification/generated/array_api.arange.html
"""

start # noqa: B018, pylint: disable=W0104
stop # noqa: B018, pylint: disable=W0104
step # noqa: B018, pylint: disable=W0104
dtype # noqa: B018, pylint: disable=W0104
device # noqa: B018, pylint: disable=W0104
raise NotImplementedError("TODO 35") # noqa: EM101
device, ns = device_namespace(device)
return _box(array, ns.arange(start, stop, step, dtype=dtype))


def asarray(
Expand Down Expand Up @@ -137,10 +138,8 @@ def empty(
https://data-apis.org/array-api/latest/API_specification/generated/array_api.empty.html
"""

shape # noqa: B018, pylint: disable=W0104
dtype # noqa: B018, pylint: disable=W0104
device # noqa: B018, pylint: disable=W0104
raise NotImplementedError("TODO 36") # noqa: EM101
device, ns = device_namespace(device)
return _box(array, ns.empty(shape, dtype=dtype))


def empty_like(
Expand Down Expand Up @@ -197,12 +196,8 @@ def eye(
https://data-apis.org/array-api/latest/API_specification/generated/array_api.eye.html
"""

n_rows # noqa: B018, pylint: disable=W0104
n_cols # noqa: B018, pylint: disable=W0104
k # noqa: B018, pylint: disable=W0104
dtype # noqa: B018, pylint: disable=W0104
device # noqa: B018, pylint: disable=W0104
raise NotImplementedError("TODO 38") # noqa: EM101
device, ns = device_namespace(device)
return _box(array, ns.eye(n_rows, n_cols, k, dtype=dtype))


def from_dlpack(x: object, /) -> array:
Expand All @@ -218,8 +213,21 @@ def from_dlpack(x: object, /) -> array:
https://data-apis.org/array-api/latest/API_specification/generated/array_api.from_dlpack.html
"""

x # noqa: B018, pylint: disable=W0104
raise NotImplementedError("TODO 39") # noqa: EM101
device_type, _ = x.__dlpack_device__() # type: ignore[attr-defined]
if (
isinstance(device_type, enum.Enum) and device_type.value == 1
) or device_type == 1:
y = np.from_dlpack(x)
elif (
isinstance(device_type, enum.Enum) and device_type.value == 2
) or device_type == 2:
cp = _import.cupy()
y = cp.from_dlpack(x)
else:
msg = f"unsupported __dlpack_device__ type: {device_type}"
raise TypeError(msg)

return _box(array, y)


def full(
Expand Down Expand Up @@ -254,11 +262,8 @@ def full(
https://data-apis.org/array-api/latest/API_specification/generated/array_api.full.html
"""

shape # noqa: B018, pylint: disable=W0104
fill_value # noqa: B018, pylint: disable=W0104
dtype # noqa: B018, pylint: disable=W0104
device # noqa: B018, pylint: disable=W0104
raise NotImplementedError("TODO 40") # noqa: EM101
device, ns = device_namespace(device)
return _box(array, ns.full(shape, fill_value, dtype=dtype))


def full_like(
Expand Down Expand Up @@ -344,13 +349,10 @@ def linspace(
https://data-apis.org/array-api/latest/API_specification/generated/array_api.linspace.html
"""

start # noqa: B018, pylint: disable=W0104
stop # noqa: B018, pylint: disable=W0104
num # noqa: B018, pylint: disable=W0104
dtype # noqa: B018, pylint: disable=W0104
device # noqa: B018, pylint: disable=W0104
endpoint # noqa: B018, pylint: disable=W0104
raise NotImplementedError("TODO 42") # noqa: EM101
device, ns = device_namespace(device)
return _box(
array, ns.linspace(start, stop, num=num, endpoint=endpoint, dtype=dtype)
)


def meshgrid(*arrays: array, indexing: str = "xy") -> list[array]:
Expand Down Expand Up @@ -415,10 +417,8 @@ def ones(
https://data-apis.org/array-api/latest/API_specification/generated/array_api.ones.html
"""

shape # noqa: B018, pylint: disable=W0104
dtype # noqa: B018, pylint: disable=W0104
device # noqa: B018, pylint: disable=W0104
raise NotImplementedError("TODO 44") # noqa: EM101
device, ns = device_namespace(device)
return _box(array, ns.ones(shape, dtype=dtype))


def ones_like(
Expand Down Expand Up @@ -518,10 +518,8 @@ def zeros(
https://data-apis.org/array-api/latest/API_specification/generated/array_api.zeros.html
"""

shape # noqa: B018, pylint: disable=W0104
dtype # noqa: B018, pylint: disable=W0104
device # noqa: B018, pylint: disable=W0104
raise NotImplementedError("TODO 48") # noqa: EM101
device, ns = device_namespace(device)
return _box(array, ns.zeros(shape, dtype=dtype))


def zeros_like(
Expand Down
31 changes: 22 additions & 9 deletions src/ragged/_spec_data_type_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
from ._spec_array_object import array
from ._typing import Dtype

_type = type


def astype(x: array, dtype: Dtype, /, *, copy: bool = True) -> array:
"""
Expand Down Expand Up @@ -56,9 +58,7 @@ def can_cast(from_: Dtype | array, to: Dtype, /) -> bool:
https://data-apis.org/array-api/latest/API_specification/generated/array_api.can_cast.html
"""

from_ # noqa: B018, pylint: disable=W0104
to # noqa: B018, pylint: disable=W0104
raise NotImplementedError("TODO 51") # noqa: EM101
return bool(np.can_cast(from_, to))


@dataclass
Expand Down Expand Up @@ -114,8 +114,16 @@ def finfo(type: Dtype | array, /) -> finfo_object: # pylint: disable=W0622
https://data-apis.org/array-api/latest/API_specification/generated/array_api.finfo.html
"""

type # noqa: B018, pylint: disable=W0104
raise NotImplementedError("TODO 52") # noqa: EM101
if not isinstance(type, np.dtype):
if not isinstance(type, _type) and hasattr(type, "dtype"):
out = np.finfo(type.dtype)
else:
out = np.finfo(np.dtype(type))
else:
out = np.finfo(type)
return finfo_object(
out.bits, out.eps, out.max, out.min, out.smallest_normal, out.dtype
)


@dataclass
Expand Down Expand Up @@ -155,8 +163,14 @@ def iinfo(type: Dtype | array, /) -> iinfo_object: # pylint: disable=W0622
https://data-apis.org/array-api/latest/API_specification/generated/array_api.iinfo.html
"""

type # noqa: B018, pylint: disable=W0104
raise NotImplementedError("TODO 53") # noqa: EM101
if not isinstance(type, np.dtype):
if not isinstance(type, _type) and hasattr(type, "dtype"):
out = np.iinfo(type.dtype)
else:
out = np.iinfo(np.dtype(type))
else:
out = np.iinfo(type)
return iinfo_object(out.bits, out.max, out.min, out.dtype)


def isdtype(dtype: Dtype, kind: Dtype | str | tuple[Dtype | str, ...]) -> bool:
Expand Down Expand Up @@ -218,5 +232,4 @@ def result_type(*arrays_and_dtypes: array | Dtype) -> Dtype:
https://data-apis.org/array-api/latest/API_specification/generated/array_api.result_type.html
"""

arrays_and_dtypes # noqa: B018, pylint: disable=W0104
raise NotImplementedError("TODO 55") # noqa: EM101
return np.result_type(*arrays_and_dtypes)
108 changes: 108 additions & 0 deletions tests/test_spec_creation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,21 @@

from __future__ import annotations

import numpy as np
import pytest

import ragged

devices = ["cpu"]
ns = {"cpu": np}
try:
import cupy as cp

devices.append("cuda")
ns["cuda"] = cp
except ModuleNotFoundError:
cp = None


def test_existence():
assert ragged.arange is not None
Expand All @@ -26,3 +39,98 @@ def test_existence():
assert ragged.triu is not None
assert ragged.zeros is not None
assert ragged.zeros_like is not None


@pytest.mark.parametrize("device", devices)
def test_arange(device):
a = ragged.arange(5, 10, 2, device=device)
assert a.tolist() == [5, 7, 9]
assert isinstance(a._impl.layout.data, ns[device].ndarray) # type: ignore[union-attr]


@pytest.mark.parametrize("device", devices)
def test_empty(device):
a = ragged.empty((2, 3, 5), device=device)
assert a.shape == (2, 3, 5)
assert isinstance(a._impl.layout.data, ns[device].ndarray) # type: ignore[union-attr]


@pytest.mark.parametrize("device", devices)
def test_empty_ndim0(device):
a = ragged.empty((), device=device)
assert a.ndim == 0
assert a.shape == ()
assert isinstance(a._impl, ns[device].ndarray)


@pytest.mark.parametrize("device", devices)
def test_eye(device):
a = ragged.eye(3, 5, k=1, device=device)
assert a.tolist() == [[0, 1, 0, 0, 0], [0, 0, 1, 0, 0], [0, 0, 0, 1, 0]]
assert isinstance(a._impl.layout.data, ns[device].ndarray) # type: ignore[union-attr]


@pytest.mark.skipif(
not hasattr(np, "from_dlpack"), reason=f"np.from_dlpack not in {np.__version__}"
)
@pytest.mark.parametrize("device", devices)
def test_from_dlpack(device):
a = ns[device].array([1, 2, 3, 4, 5])
b = ragged.from_dlpack(a)
assert b.tolist() == [1, 2, 3, 4, 5]
assert isinstance(b._impl.layout.data, ns[device].ndarray) # type: ignore[union-attr]


@pytest.mark.parametrize("device", devices)
def test_full(device):
a = ragged.full(5, 3, device=device)
assert a.tolist() == [3, 3, 3, 3, 3]
assert isinstance(a._impl.layout.data, ns[device].ndarray) # type: ignore[union-attr]


@pytest.mark.parametrize("device", devices)
def test_full_ndim0(device):
a = ragged.full((), 3, device=device)
assert a.ndim == 0
assert a.shape == ()
assert a == 3
assert isinstance(a._impl, ns[device].ndarray)


@pytest.mark.parametrize("device", devices)
def test_linspace(device):
a = ragged.linspace(5, 8, 5, device=device)
assert a.tolist() == [5, 5.75, 6.5, 7.25, 8]
assert isinstance(a._impl.layout.data, ns[device].ndarray) # type: ignore[union-attr]


@pytest.mark.parametrize("device", devices)
def test_ones(device):
a = ragged.ones(5, device=device)
assert a.tolist() == [1, 1, 1, 1, 1]
assert isinstance(a._impl.layout.data, ns[device].ndarray) # type: ignore[union-attr]


@pytest.mark.parametrize("device", devices)
def test_ones_ndim0(device):
a = ragged.ones((), device=device)
assert a.ndim == 0
assert a.shape == ()
assert a == 1
assert isinstance(a._impl, ns[device].ndarray)


@pytest.mark.parametrize("device", devices)
def test_zeros(device):
a = ragged.zeros(5, device=device)
assert a.tolist() == [0, 0, 0, 0, 0]
assert isinstance(a._impl.layout.data, ns[device].ndarray) # type: ignore[union-attr]


@pytest.mark.parametrize("device", devices)
def test_zeros_ndim0(device):
a = ragged.zeros((), device=device)
assert a.ndim == 0
assert a.shape == ()
assert a == 0
assert isinstance(a._impl, ns[device].ndarray)
Loading