diff --git a/scico/numpy/util.py b/scico/numpy/util.py index 54a1c497d..a31a7fe4b 100644 --- a/scico/numpy/util.py +++ b/scico/numpy/util.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright (C) 2022-2023 by SCICO Developers +# Copyright (C) 2022-2024 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SPORCO package. Details of the copyright # and user license can be found in the 'LICENSE.txt' file distributed @@ -87,7 +87,8 @@ def slice_length(length: int, idx: AxisIndex) -> Optional[int]: Raises: ValueError: If `idx` is an integer index that is out bounds for - the axis length. + the axis length or if the type of `idx` is not one of + `Ellipsis`, `int`, or `slice`. """ if idx is Ellipsis: return length @@ -95,6 +96,8 @@ def slice_length(length: int, idx: AxisIndex) -> Optional[int]: if idx < -length or idx > length - 1: raise ValueError(f"Index {idx} out of bounds for axis of length {length}.") return None + if not isinstance(idx, slice): + raise ValueError(f"Index expression {idx} is of an unrecognized type.") start, stop, stride = idx.indices(length) if start > stop: start = stop @@ -112,19 +115,24 @@ def indexed_shape(shape: Shape, idx: ArrayIndex) -> Tuple[int, ...]: Shape of indexed/sliced array. Raises: - ValueError: If `idx` is longer than `shape`. + ValueError: If any element of `idx` is not one of `Ellipsis`, + `int`, `slice`, or ``None`` (`np.newaxis`), or if an integer + index is out bounds for the corresponding axis length. """ if not isinstance(idx, tuple): idx = (idx,) - if len(idx) > len(shape): - raise ValueError(f"Slice {idx} has more dimensions than shape {shape}.") idx_shape: List[Optional[int]] = list(shape) offset = 0 + newaxis = 0 for axis, ax_idx in enumerate(idx): + if ax_idx is None: + idx_shape.insert(axis, 1) + newaxis += 1 + continue if ax_idx is Ellipsis: offset = len(shape) - len(idx) continue - idx_shape[axis + offset] = slice_length(shape[axis + offset], ax_idx) + idx_shape[axis + offset + newaxis] = slice_length(shape[axis + offset], ax_idx) return tuple(filter(lambda x: x is not None, idx_shape)) # type: ignore diff --git a/scico/test/linop/test_func.py b/scico/test/linop/test_func.py index 2376ca2f5..a6741ba47 100644 --- a/scico/test/linop/test_func.py +++ b/scico/test/linop/test_func.py @@ -91,6 +91,8 @@ def slicetestobj(request): np.s_[1:, :-3], np.s_[1:, :, :3], np.s_[1:, ..., 2:], + np.s_[np.newaxis], + np.s_[:, np.newaxis], ] diff --git a/scico/test/numpy/test_numpy_util.py b/scico/test/numpy/test_numpy_util.py index faab01dc5..c781b56e1 100644 --- a/scico/test/numpy/test_numpy_util.py +++ b/scico/test/numpy/test_numpy_util.py @@ -99,6 +99,10 @@ def test_slice_length_other(length, slc): np.s_[..., 2:], np.s_[..., 2:, :], np.s_[1:, ..., 2:], + np.s_[np.newaxis], + np.s_[:, np.newaxis], + np.s_[np.newaxis, :, np.newaxis], + np.s_[np.newaxis, ..., 0:2, :], ), ) def test_indexed_shape(shape, slc):