Skip to content

Commit

Permalink
Fix handling of BlockArray input
Browse files Browse the repository at this point in the history
  • Loading branch information
bwohlberg committed Feb 26, 2024
1 parent 19f8668 commit e6c174c
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 31 deletions.
23 changes: 14 additions & 9 deletions scico/functional/_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,34 +194,37 @@ class L21Norm(Functional):
The norm generalizes to more dimensions by first computing the
:math:`\ell_2` norm along one or more (user-specified) axes,
followed by a sum over all remaining axes.
For `BlockArray` inputs, the :math:`\ell_2` norm follows the
reduction rules described in :class:`BlockArray`.
followed by a sum over all remaining axes. :class:`BlockArray` inputs
require parameter `l2_axis` to be ``None``, in which case the
:math:`\ell_2` norm is computed over each block.
A typical use case is computing the isotropic total variation norm.
"""

has_eval = True
has_prox = True

def __init__(self, l2_axis: Union[int, Tuple] = 0):
def __init__(self, l2_axis: Union[None, int, Tuple] = 0):
r"""
Args:
l2_axis: Axis/axes over which to take the l2 norm. Default: 0.
l2_axis: Axis/axes over which to take the l2 norm. Required
to be ``None`` for :class:`BlockArray` inputs to be
supported.
"""
self.l2_axis = l2_axis

@staticmethod
def _l2norm(
x: Union[Array, BlockArray], axis: Union[int, Tuple], keepdims: Optional[bool] = False
x: Union[Array, BlockArray], axis: Union[None, int, Tuple], keepdims: Optional[bool] = False
):
r"""Return the :math:`\ell_2` norm of an array."""
return snp.sqrt(snp.sum(snp.abs(x) ** 2, axis=axis, keepdims=keepdims))
return snp.sqrt((snp.abs(x) ** 2).sum(axis=axis, keepdims=keepdims))

def __call__(self, x: Union[Array, BlockArray]) -> float:
if isinstance(x, snp.BlockArray) and self.l2_axis is not None:
raise ValueError("Initializer parameter l2_axis must be None for BlockArray input.")
l2 = L21Norm._l2norm(x, axis=self.l2_axis)
return snp.abs(l2).sum()
return snp.sum(snp.abs(l2))

def prox(
self, v: Union[Array, BlockArray], lam: float = 1.0, **kwargs
Expand Down Expand Up @@ -249,6 +252,8 @@ def prox(
kwargs: Additional arguments that may be used by derived
classes.
"""
if isinstance(v, snp.BlockArray) and self.l2_axis is not None:
raise ValueError("Initializer parameter l2_axis must be None for BlockArray input.")
length = L21Norm._l2norm(v, axis=self.l2_axis, keepdims=True)
direction = no_nan_divide(v, length)

Expand Down
22 changes: 0 additions & 22 deletions scico/test/functional/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,28 +110,6 @@ def foo(c):
np.testing.assert_allclose(non_pmap, pmapped)


@pytest.mark.parametrize("axis", [0, 1, (0, 2)])
def test_l21norm(axis):
x = np.ones((3, 4, 5))
if isinstance(axis, int):
l2axis = (axis,)
else:
l2axis = axis
l2shape = [x.shape[k] for k in l2axis]
l1axis = tuple(set(range(len(x))) - set(l2axis))
l1shape = [x.shape[k] for k in l1axis]

l21ana = np.sqrt(np.prod(l2shape)) * np.prod(l1shape)
F = functional.L21Norm(l2_axis=axis)
l21num = F(x)
np.testing.assert_allclose(l21ana, l21num, rtol=1e-5)

l2ana = np.sqrt(np.prod(l2shape))
prxana = (l2ana - 1.0) / l2ana * x
prxnum = F.prox(x, 1.0)
np.testing.assert_allclose(prxana, prxnum, rtol=1e-5)


def test_scalar_aggregation():
f = functional.L2Norm()
g = 2.0 * f
Expand Down

0 comments on commit e6c174c

Please sign in to comment.