Skip to content

Commit

Permalink
feat: make array object usable (#33)
Browse files Browse the repository at this point in the history
* feat: make array object usable

* __getitem__ (with minimal testing, but it's well tested in Awkward Array)

* old NumPy doesn't have 'from_dlpack'

* __contains__, __len__, __iter__

* now the only TODOs in the ragged.array class are unimplemented linear algebra functions

* expand_dims

* squeeze

* take (that was the last one)
  • Loading branch information
jpivarski authored Jan 15, 2024
1 parent 1890589 commit b9ba1fe
Show file tree
Hide file tree
Showing 7 changed files with 333 additions and 25 deletions.
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ ignore = [
"ISC001", # Conflicts with formatter
"RET505", # I like my if (return) elif (return) else (return) pattern
"PLR5501", # I like my if (return) elif (return) else (return) pattern
"RET506", # I like my if (raise) elif ... else ... pattern
]
isort.required-imports = ["from __future__ import annotations"]
# Uncomment if using a _compat.typing backport
Expand All @@ -170,6 +171,7 @@ messages_control.disable = [
"missing-class-docstring",
"missing-function-docstring",
"R1705", # I like my if (return) elif (return) else (return) pattern
"R1720", # I like my if (raise) elif ... else ... pattern
"R0801", # Different files can have similar lines; that's okay
"C0302", # I can have as many lines as I want; what's it with you?
]
121 changes: 108 additions & 13 deletions src/ragged/_spec_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@

from __future__ import annotations

import copy as copy_lib
import enum
import numbers
from collections.abc import Iterator
from typing import TYPE_CHECKING, Any, Union

import awkward as ak
Expand Down Expand Up @@ -81,6 +83,10 @@ class array: # pylint: disable=C0103
"""
Ragged array class and constructor.
In addition to satisfying the Array API, a `ragged.array` is a `Collection`,
meaning that it is `Sized`, `Iterable`, and a `Container`. The `len` and
`iter` functions, as well as the `x in array` syntax all work.
https://data-apis.org/array-api/latest/API_specification/array_object.html
"""

Expand All @@ -93,7 +99,9 @@ class array: # pylint: disable=C0103
_device: Device

@classmethod
def _new(cls, impl: ak.Array, shape: Shape, dtype: Dtype, device: Device) -> array:
def _new(
cls, impl: ak.Array | SupportsDLPack, shape: Shape, dtype: Dtype, device: Device
) -> array:
"""
Simple/fast array constructor for internal code.
"""
Expand Down Expand Up @@ -225,8 +233,8 @@ def __init__(
else:
self._device = "cuda"

if copy is not None:
raise NotImplementedError("TODO 1") # noqa: EM101
if copy and isinstance(self._impl, ak.Array):
self._impl = copy_lib.deepcopy(self._impl)

def __str__(self) -> str:
"""
Expand Down Expand Up @@ -258,6 +266,49 @@ def __repr__(self) -> str:
)
return f"ragged.array([\n {prep}\n])"

# Typical properties and methods for an array, but not part of the Array API

def __contains__(self, other: bool | int | float | complex) -> bool:
if isinstance(self._impl, ak.Array):
flat = ak.flatten(self._impl, axis=None)
assert isinstance(flat.layout, NumpyArray) # pylint: disable=E1101
return other in flat.layout.data # pylint: disable=E1101
else:
return other in self._impl # type: ignore[operator]

def __len__(self) -> int:
if isinstance(self._impl, ak.Array):
return len(self._impl)
else:
msg = "len() of unsized object"
raise TypeError(msg)

def __iter__(self) -> Iterator[array]:
if isinstance(self._impl, ak.Array):
t = type(self)
sh = self._shape[1:]
dt = self._dtype
dev = self._device
if sh == ():
for x in self._impl:
yield t._new(x, (), dt, dev)
else:
for x in self._impl:
yield t._new(x, (len(x), *sh), dt, dev)
else:
msg = "iteration over a 0-d array"
raise TypeError(msg)

def item(self) -> bool | int | float | complex:
if self.size == 1:
if isinstance(self._impl, ak.Array):
return ak.flatten(self._impl, axis=None)[0].item() # type: ignore[no-any-return]
else:
return self._impl.item() # type: ignore[no-any-return,union-attr]
else:
msg = "can only convert an array of size 1 to a Python scalar"
raise ValueError(msg)

def tolist(
self,
) -> bool | int | float | complex | NestedSequence[bool | int | float | complex]:
Expand Down Expand Up @@ -460,7 +511,11 @@ def __dlpack__(self, *, stream: None | int | Any = None) -> PyCapsule:
https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__dlpack__.html
"""

raise NotImplementedError("TODO 4") # noqa: EM101
buf = self._impl
if isinstance(buf, ak.Array):
buf = ak.to_numpy(buf) if ak.backend(buf) == "cpu" else ak.to_cupy(buf)

return buf.__dlpack__(stream=stream) # type: ignore[arg-type]

def __dlpack_device__(self) -> tuple[enum.Enum, int]:
"""
Expand All @@ -472,7 +527,15 @@ def __dlpack_device__(self) -> tuple[enum.Enum, int]:
https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__dlpack_device__.html
"""

raise NotImplementedError("TODO 10") # noqa: EM101
buf = self._impl
if isinstance(buf, ak.Array):
buf = buf.layout
while isinstance(buf, (ListArray, ListOffsetArray, RegularArray)):
buf = buf.content
assert isinstance(buf, NumpyArray)
buf = buf.data

return buf.__dlpack_device__()

def __eq__(self, other: int | float | bool | array, /) -> array: # type: ignore[override]
"""
Expand Down Expand Up @@ -541,7 +604,22 @@ def __getitem__(self, key: GetSliceKey, /) -> array:
https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__getitem__.html
"""

raise NotImplementedError("TODO 15") # noqa: EM101
if isinstance(key, tuple):
for item in key:
if not isinstance(
item, (numbers.Integral, slice, type(...), type(None))
):
msg = f"ragged.array sliced as arr[item1, item2, ...] can only have int, slice, ellipsis, None (np.newaxis) as items, not {item!r}"
raise TypeError(msg)
elif not isinstance(
key, (numbers.Integral, slice, type(...), type(None), array)
):
key = array(key) # attempt to cast unknown key type as ragged.array

if isinstance(key, array):
key = key._impl # type: ignore[assignment]

return _box(type(self), self._impl[key]) # type: ignore[index]

def __gt__(self, other: int | float | array, /) -> array:
"""
Expand Down Expand Up @@ -790,7 +868,8 @@ def __setitem__(
https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__setitem__.html
"""

raise NotImplementedError("TODO 31") # noqa: EM101
msg = "ragged.array is an immutable type; its values cannot be assigned to"
raise TypeError(msg)

def __sub__(self, other: int | float | array, /) -> array:
"""
Expand Down Expand Up @@ -853,26 +932,42 @@ def to_device(self, device: Device, /, *, stream: None | int | Any = None) -> ar
main memory; if `"cuda"`, the array is backed by CuPy and
resides in CUDA global memory.
stream: CuPy Stream object (https://docs.cupy.dev/en/stable/reference/generated/cupy.cuda.Stream.html)
for `device="cuda"`.
for `device="cuda"`. Ignored if output `device` is `"cpu"`. If
this argument is an integer, it is interpreted as the pointer
address of a `cudaStream_t` object.
https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.to_device.html
"""

if isinstance(stream, numbers.Integral):
cp = _import.cupy()
stream = cp.cuda.ExternalStream(stream)

if stream is not None:
t = type(stream)
if not t.__module__.startswith("cupy.") or t.__name__ != "Stream":
msg = f"stream object must be a cupy.cuda.Stream, not {stream!r}"
raise TypeError(msg)

if isinstance(self._impl, ak.Array):
if device != ak.backend(self._impl):
if stream is not None:
raise NotImplementedError("TODO 124") # noqa: EM101
impl = ak.to_backend(self._impl, device)
with stream:
impl = ak.to_backend(self._impl, device)
else:
impl = ak.to_backend(self._impl, device)
else:
impl = self._impl

elif isinstance(self._impl, np.ndarray):
# self._impl is a NumPy 0-dimensional array
if device == "cuda":
if stream is not None:
raise NotImplementedError("TODO 125") # noqa: EM101
cp = _import.cupy()
impl = cp.array(self._impl)
if stream is not None:
with stream:
impl = cp.array(self._impl)
else:
impl = cp.array(self._impl)
else:
impl = self._impl

Expand Down
33 changes: 28 additions & 5 deletions src/ragged/_spec_indexing_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@

from __future__ import annotations

from ._spec_array_object import array
import awkward as ak
import numpy as np

from ._spec_array_object import _box, array


def take(x: array, indices: array, /, *, axis: None | int = None) -> array:
Expand Down Expand Up @@ -37,7 +40,27 @@ def take(x: array, indices: array, /, *, axis: None | int = None) -> array:
https://data-apis.org/array-api/latest/API_specification/generated/array_api.take.html
"""

x # noqa: B018, pylint: disable=W0104
indices # noqa: B018, pylint: disable=W0104
axis # noqa: B018, pylint: disable=W0104
raise NotImplementedError("TODO 109") # noqa: EM101
if axis is None:
if x.ndim <= 1:
axis = 0
else:
msg = f"for an {x.ndim}-dimensional array (greater than 1-dimensional), the 'axis' argument is required"
raise TypeError(msg)

original_axis = axis
if axis < 0:
axis += x.ndim + 1
if not 0 <= axis < x.ndim:
msg = f"axis {original_axis} is out of bounds for array of dimension {x.ndim}"
raise ak.errors.AxisError(msg)

toslice = x._impl # pylint: disable=W0212
if not isinstance(toslice, ak.Array):
toslice = ak.Array(toslice[np.newaxis]) # type: ignore[index]

if not isinstance(indices, array):
indices = array(indices) # type: ignore[unreachable]
indexarray = indices._impl # pylint: disable=W0212

slicer = (slice(None),) * axis + (indexarray,)
return _box(type(x), toslice[slicer])
65 changes: 59 additions & 6 deletions src/ragged/_spec_manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@

from __future__ import annotations

import numbers

import awkward as ak
import numpy as np

from ._spec_array_object import _box, _unbox, array

Expand Down Expand Up @@ -123,12 +126,28 @@ def expand_dims(x: array, /, *, axis: int = 0) -> array:
Returns:
An expanded output array having the same data type as `x`.
This is the opposite of `ragged.squeeze`.
https://data-apis.org/array-api/latest/API_specification/generated/array_api.expand_dims.html
"""

x # noqa: B018, pylint: disable=W0104
axis # noqa: B018, pylint: disable=W0104
raise NotImplementedError("TODO 117") # noqa: EM101
original_axis = axis
if axis < 0:
axis += x.ndim + 1
if not 0 <= axis <= x.ndim:
msg = (
f"axis {original_axis} is out of bounds for array of dimension {x.ndim + 1}"
)
raise ak.errors.AxisError(msg)

slicer = (slice(None),) * axis + (np.newaxis,)
shape = x.shape[:axis] + (1,) + x.shape[axis:]

out = x._impl[slicer] # type: ignore[index] # pylint: disable=W0212
if not isinstance(out, ak.Array):
out = ak.Array(out)

return x._new(out, shape, x.dtype, x.device) # pylint: disable=W0212


def flip(x: array, /, *, axis: None | int | tuple[int, ...] = None) -> array:
Expand Down Expand Up @@ -257,12 +276,46 @@ def squeeze(x: array, /, axis: int | tuple[int, ...]) -> array:
Returns:
An output array having the same data type and elements as `x`.
This is the opposite of `ragged.expand_dims`.
https://data-apis.org/array-api/latest/API_specification/generated/array_api.squeeze.html
"""

x # noqa: B018, pylint: disable=W0104
axis # noqa: B018, pylint: disable=W0104
raise NotImplementedError("TODO 122") # noqa: EM101
if isinstance(axis, numbers.Integral):
axis = (axis,) # type: ignore[assignment]

posaxis = []
for axisitem in axis: # type: ignore[union-attr]
posaxisitem = axisitem + x.ndim if axisitem < 0 else axisitem
if not 0 <= posaxisitem < x.ndim and not posaxisitem == x.ndim == 0:
msg = f"axis {axisitem} is out of bounds for array of dimension {x.ndim}"
raise ak.errors.AxisError(msg)
posaxis.append(posaxisitem)

if not isinstance(x._impl, ak.Array): # pylint: disable=W0212
return x._new(x._impl, x._shape, x._dtype, x._device) # pylint: disable=W0212

out = x._impl # pylint: disable=W0212
shape = list(x.shape)
for i, shapeitem in reversed(list(enumerate(x.shape))):
if i in posaxis:
if shapeitem is None:
if not np.all(ak.num(out, axis=i) == 1):
msg = "cannot select an axis to squeeze out which has size not equal to one"
raise ValueError(msg)
else:
out = out[(slice(None),) * i + (0,)]
del shape[i]

elif shapeitem == 1:
out = out[(slice(None),) * i + (0,)]
del shape[i]

else:
msg = "cannot select an axis to squeeze out which has size not equal to one"
raise ValueError(msg)

return x._new(out, tuple(shape), x.dtype, x.device) # pylint: disable=W0212


def stack(arrays: tuple[array, ...] | list[array], /, *, axis: int = 0) -> array:
Expand Down
Loading

0 comments on commit b9ba1fe

Please sign in to comment.