From 7b21f3b5948d5e50a6a3ea056070c8071cc98d6d Mon Sep 17 00:00:00 2001
From: Brendt Wohlberg
Date: Fri, 10 Nov 2023 15:26:51 -0700
Subject: [PATCH 1/7] Update change log

---
 CHANGES.rst | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/CHANGES.rst b/CHANGES.rst
index 265b9dd38..f7a33d44d 100644
--- a/CHANGES.rst
+++ b/CHANGES.rst
@@ -6,15 +6,15 @@ SCICO Release Notes
 
 Version 0.0.5   (unreleased)
 ----------------------------
 
-• New functional ``functional.AnisotropicTVNorm`` with proximal operator
-  approximation.
+• New functionals ``functional.AnisotropicTVNorm`` and
+  ``functional.ProximalAverage`` with proximal operator approximations.
 • New integrated Radon/X-ray transform ``linop.XRayTransform``.
 • Rename modules ``radon_astra`` and ``radon_svmbir`` to ``xray.astra`` and
   ``xray.svmbir`` respectively, and rename ``TomographicProjector`` classes
   to ``XRayTransform``.
 • Rename ``AbelProjector`` to ``AbelTransform``.
 • Rename ``solver.ATADSolver`` to ``solver.MatrixATADSolver``.
-• Support ``jaxlib`` and ``jax`` versions 0.4.3 to 0.4.19.
+• Support ``jaxlib`` and ``jax`` versions 0.4.3 to 0.4.20.


From 6faf02cf28253b5b4a78af72e14157ebbad8f39b Mon Sep 17 00:00:00 2001
From: Brendt Wohlberg
Date: Tue, 14 Nov 2023 07:02:16 -0700
Subject: [PATCH 2/7] Resolve #468 and add corresponding test

---
 scico/loss.py                      |  3 +++
 scico/test/functional/test_loss.py | 13 +++++++++++--
 2 files changed, 14 insertions(+), 2 deletions(-)

diff --git a/scico/loss.py b/scico/loss.py
index 1a5970f1a..41e4b3131 100644
--- a/scico/loss.py
+++ b/scico/loss.py
@@ -15,6 +15,7 @@
 
 import jax
 
+import scico
 import scico.numpy as snp
 from scico import functional, linop, operator
 from scico.numpy import Array, BlockArray
@@ -125,6 +126,7 @@ def prox(
     @_loss_mul_div_wrapper
     def __mul__(self, other):
         new_loss = copy(self)
+        new_loss._grad = scico.grad(new_loss.__call__)
         new_loss.set_scale(self.scale * other)
         return new_loss
 
@@ -134,6 +136,7 @@ def __rmul__(self, other):
     @_loss_mul_div_wrapper
     def __truediv__(self, other):
         new_loss = copy(self)
+        new_loss._grad = scico.grad(new_loss.__call__)
         new_loss.set_scale(self.scale / other)
         return new_loss
 
diff --git a/scico/test/functional/test_loss.py b/scico/test/functional/test_loss.py
index 420872aa3..f5b4a3dd9 100644
--- a/scico/test/functional/test_loss.py
+++ b/scico/test/functional/test_loss.py
@@ -84,6 +84,16 @@ def test_squared_l2(self):
         pf = prox_test(self.v, L_d, L_d.prox, 0.75)
         pf = prox_test(self.v, L, L.prox, 0.75)
 
+    def test_squared_l2_grad(self):
+        La = loss.SquaredL2Loss(y=self.y)
+        Lb = loss.SquaredL2Loss(y=self.y, scale=5e0)
+        Lc = 1e1 * La
+        ga = La.grad(self.v)
+        gb = Lb.grad(self.v)
+        gc = Lc.grad(self.v)
+        np.testing.assert_allclose(1e1 * ga, gb)
+        np.testing.assert_allclose(gb, gc)
+
     def test_weighted_squared_l2(self):
         L = loss.SquaredL2Loss(y=self.y, A=self.Ao, W=self.W)
         assert L.has_eval
@@ -119,7 +129,6 @@ def test_poisson(self):
 
 
 class TestAbsLoss:
-
     abs_loss = (
         (loss.SquaredL2AbsLoss, snp.abs),
         (loss.SquaredL2SquaredAbsLoss, lambda x: snp.abs(x) ** 2),
@@ -218,7 +227,7 @@ def test_cubic_root():
     r = loss._dep_cubic_root(p, q)
    err = snp.abs(r**3 + p * r + q)
     assert err.max() < 2e-4
-    # Test
+    # Test loss of precision warning
     p = snp.array(1e-4, dtype=snp.float32)
     q = snp.array(1e1, dtype=snp.float32)
     with pytest.warns(UserWarning):
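
The ``new_loss._grad = scico.grad(new_loss.__call__)`` lines above are needed because
``copy`` produces a shallow copy whose cached ``_grad`` is still the gradient of the
original loss's bound ``__call__``, so rescaling the copy would otherwise not be
reflected in its gradient. A minimal generic sketch of the pitfall (illustrative only,
not SCICO code; it assumes the gradient is cached from the bound ``__call__``, as in
``scico.loss.Loss``):

    import copy

    import jax


    class Toy:
        """Hypothetical stand-in for a loss that caches a gradient of its bound __call__."""

        def __init__(self, scale=1.0):
            self.scale = scale
            self._grad = jax.grad(self.__call__)  # bound to *this* instance

        def __call__(self, x):
            return self.scale * x**2


    f = Toy()
    g = copy.copy(f)                 # shallow copy: g._grad is still f's gradient
    g.scale = 2.0
    print(g._grad(3.0))              # 6.0: gradient of f (scale 1.0), not of the rescaled g
    g._grad = jax.grad(g.__call__)   # rebuild the gradient, as the patch does
    print(g._grad(3.0))              # 12.0: gradient of the rescaled copy
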
From eb833d95cef50c7b05fae58ec00253d13e847691 Mon Sep 17 00:00:00 2001
From: Brendt Wohlberg
Date: Tue, 14 Nov 2023 07:24:52 -0700
Subject: [PATCH 3/7] Shorten comment

---
 scico/loss.py | 9 +++------
 1 file changed, 3 insertions(+), 6 deletions(-)

diff --git a/scico/loss.py b/scico/loss.py
index 41e4b3131..0d09db4f5 100644
--- a/scico/loss.py
+++ b/scico/loss.py
@@ -219,12 +219,9 @@ def prox(
             ATWA = c * A.conj() * W * A  # type: ignore
             return lhs / (ATWA + 1.0)
 
-        # prox_{f}(v) = arg min  1/2 || v - x ||_2^2 + λ 𝛼 || A x - y ||^2_W
-        #                    x
-        # solution at:
-        #
-        #   (I + λ 2𝛼 A^T W A) x = v + λ 2𝛼 A^T W y
-        #
+        # prox_f(v) = arg min  1/2 || v - x ||_2^2 + λ 𝛼 || A x - y ||^2_W
+        #                  x
+        # with solution:  (I + λ 2𝛼 A^T W A) x = v + λ 2𝛼 A^T W y
         W = self.W
         A = self.A
         𝛼 = self.scale


From bb11202e11134ac2da88f48174f88215ae2a2bc1 Mon Sep 17 00:00:00 2001
From: Brendt Wohlberg
Date: Tue, 14 Nov 2023 07:28:54 -0700
Subject: [PATCH 4/7] Resolve some oversights in prox definitions

---
 scico/functional/_functional.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/scico/functional/_functional.py b/scico/functional/_functional.py
index 3216dac40..a59f77b47 100644
--- a/scico/functional/_functional.py
+++ b/scico/functional/_functional.py
@@ -183,7 +183,7 @@ def prox(
            \prox_{(\alpha \beta) f}(\mb{v}) \;.
 
         """
-        return self.functional.prox(v, lam * self.scale)
+        return self.functional.prox(v, lam * self.scale, **kwargs)
 
 
 class SeparableFunctional(Functional):
@@ -245,7 +245,9 @@ def prox(self, v: BlockArray, lam: float = 1.0, **kwargs) -> BlockArray:
 
         """
         if len(v.shape) == len(self.functional_list):
-            return snp.blockarray([fi.prox(vi, lam) for fi, vi in zip(self.functional_list, v)])
+            return snp.blockarray(
+                [fi.prox(vi, lam, **kwargs) for fi, vi in zip(self.functional_list, v)]
+            )
         raise ValueError(
             f"Number of blocks in v, {len(v.shape)}, and length of functional_list, "
             f"{len(self.functional_list)}, do not match."


From 03f8635088883e50e4c6e2a31bc3b8a510ed3d1a Mon Sep 17 00:00:00 2001
From: Brendt Wohlberg
Date: Tue, 14 Nov 2023 07:33:18 -0700
Subject: [PATCH 5/7] Minor edit

---
 scico/functional/_functional.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/scico/functional/_functional.py b/scico/functional/_functional.py
index a59f77b47..f2a897625 100644
--- a/scico/functional/_functional.py
+++ b/scico/functional/_functional.py
@@ -173,10 +173,10 @@ def prox(
         proximal operator of the unscaled functional with the proximal
         operator scaling consisting of the product of the two scaling
         factors, i.e., for functional :math:`f` and scaling factors
-        :math:`\alpha` and :math:`\beta`, the proximal operator with scaling
-        parameter :math:`\alpha` of scaled functional :math:`\beta f` is
-        the proximal operator with scaling parameter :math:`\alpha \beta`
-        of functional :math:`f`,
+        :math:`\alpha` and :math:`\beta`, the proximal operator with
+        scaling parameter :math:`\alpha` of scaled functional
+        :math:`\beta f` is the proximal operator with scaling parameter
+        :math:`\alpha \beta` of functional :math:`f`,
 
         .. math::
            \prox_{\alpha (\beta f)}(\mb{v}) =
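
The docstring re-wrapped above states the identity that ``ScaledFunctional.prox``
implements: the proximal operator with scaling parameter alpha of the scaled functional
beta f equals the proximal operator with scaling parameter alpha * beta of f itself. A
short usage sketch of that identity (illustrative only; it assumes ``functional.L2Norm``
provides a proximal operator, as exercised elsewhere in the test suite):

    import numpy as np

    import scico.numpy as snp
    from scico import functional

    f = functional.L2Norm()
    v = snp.array([3.0, -4.0])

    p1 = (2.0 * f).prox(v, 0.5)  # prox of the scaled functional 2f with scaling 0.5
    p2 = f.prox(v, 1.0)          # prox of f itself with scaling 0.5 * 2 = 1
    np.testing.assert_allclose(p1, p2)
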
From 21f56687fa3c0999f443906413245b69cd70476f Mon Sep 17 00:00:00 2001
From: Brendt Wohlberg
Date: Tue, 14 Nov 2023 08:02:59 -0700
Subject: [PATCH 6/7] Avoid chaining of ScaledFunctional and some code re-organization

---
 scico/functional/_functional.py    | 70 +++++++++++++++++-------------
 scico/test/functional/test_misc.py | 12 +++++
 2 files changed, 52 insertions(+), 30 deletions(-)

diff --git a/scico/functional/_functional.py b/scico/functional/_functional.py
index f2a897625..60b22064e 100644
--- a/scico/functional/_functional.py
+++ b/scico/functional/_functional.py
@@ -7,6 +7,11 @@
 
 """Functional base class."""
 
+
+# Needed to annotate a class method that returns the encapsulating class;
+# see https://www.python.org/dev/peps/pep-0563/
+from __future__ import annotations
+
 from typing import List, Optional, Union
 
 import jax
@@ -38,15 +43,15 @@ def __init__(self):
     def __repr__(self):
         return f"""{type(self)} (has_eval = {self.has_eval}, has_prox = {self.has_prox})"""
 
-    def __mul__(self, other):
+    def __mul__(self, other: Union[float, int]) -> ScaledFunctional:
         if snp.isscalar(other) or isinstance(other, jax.core.Tracer):
             return ScaledFunctional(self, other)
         return NotImplemented
 
-    def __rmul__(self, other):
+    def __rmul__(self, other: Union[float, int]) -> ScaledFunctional:
        return self.__mul__(other)
 
-    def __add__(self, other):
+    def __add__(self, other: Functional) -> FunctionalSum:
         if isinstance(other, Functional):
             return FunctionalSum(self, other)
         return NotImplemented
@@ -122,36 +127,9 @@ def grad(self, x: Union[Array, BlockArray]):
         return self._grad(x)
 
 
-class FunctionalSum(Functional):
-    r"""A sum of two functionals."""
-
-    def __repr__(self):
-        return (
-            "Sum of functionals of types "
-            + str(type(self.functional1))
-            + " and "
-            + str(type(self.functional2))
-        )
-
-    def __init__(self, functional1: Functional, functional2: Functional):
-        self.functional1 = functional1
-        self.functional2 = functional2
-        self.has_eval = functional1.has_eval and functional2.has_eval
-        self.has_prox = False
-        super().__init__()
-
-    def __call__(self, x: Union[Array, BlockArray]) -> float:
-        return self.functional1(x) + self.functional2(x)
-
-
 class ScaledFunctional(Functional):
     r"""A functional multiplied by a scalar."""
 
-    def __repr__(self):
-        return (
-            "Scaled functional of type " + str(type(self.functional)) + f" (scale = {self.scale})"
-        )
-
     def __init__(self, functional: Functional, scale: float):
         self.functional = functional
         self.scale = scale
@@ -159,9 +137,19 @@ def __init__(self, functional: Functional, scale: float):
         self.has_prox = functional.has_prox
         super().__init__()
 
+    def __repr__(self):
+        return (
+            "Scaled functional of type " + str(type(self.functional)) + f" (scale = {self.scale})"
+        )
+
     def __call__(self, x: Union[Array, BlockArray]) -> float:
         return self.scale * self.functional(x)
 
+    def __mul__(self, other: Union[float, int]) -> ScaledFunctional:
+        if snp.isscalar(other) or isinstance(other, jax.core.Tracer):
+            return ScaledFunctional(self.functional, other * self.scale)
+        return NotImplemented
+
     def prox(
         self, v: Union[Array, BlockArray], lam: float = 1.0, **kwargs
     ) -> Union[Array, BlockArray]:
@@ -254,6 +242,28 @@ def prox(self, v: BlockArray, lam: float = 1.0, **kwargs) -> BlockArray:
         )
 
 
+class FunctionalSum(Functional):
+    r"""A sum of two functionals."""
+
+    def __init__(self, functional1: Functional, functional2: Functional):
+        self.functional1 = functional1
+        self.functional2 = functional2
+        self.has_eval = functional1.has_eval and functional2.has_eval
+        self.has_prox = False
+        super().__init__()
+
+    def __repr__(self):
+        return (
+            "Sum of functionals of types "
+            + str(type(self.functional1))
+            + " and "
+            + str(type(self.functional2))
+        )
+
+    def __call__(self, x: Union[Array, BlockArray]) -> float:
+        return self.functional1(x) + self.functional2(x)
+
+
 class ZeroFunctional(Functional):
     r"""Zero functional, :math:`f(\mb{x}) = 0 \in \mbb{R}` for any input."""
 
diff --git a/scico/test/functional/test_misc.py b/scico/test/functional/test_misc.py
index 3c8ac98b4..69a8ceec8 100644
--- a/scico/test/functional/test_misc.py
+++ b/scico/test/functional/test_misc.py
@@ -130,3 +130,15 @@ def test_l21norm(axis):
     prxana = (l2ana - 1.0) / l2ana * x
     prxnum = F.prox(x, 1.0)
     np.testing.assert_allclose(prxana, prxnum, rtol=1e-5)
+
+
+def test_scalar_aggregation():
+    f = functional.L2Norm()
+    g = 2.0 * f
+    h = 5.0 * g
+    assert isinstance(g, functional.ScaledFunctional)
+    assert isinstance(g.functional, functional.L2Norm)
+    assert g.scale == 2.0
+    assert isinstance(h, functional.ScaledFunctional)
+    assert isinstance(h.functional, functional.L2Norm)
+    assert h.scale == 10.0
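
With the new ``ScaledFunctional.__mul__`` above, repeated scalar multiplication collapses
into a single ``ScaledFunctional`` wrapping the base functional with the product of the
scales (as checked by ``test_scalar_aggregation``), so the proximal operator is still
delegated directly to the base functional. A small sketch combining this with the
scaled-prox identity (illustrative only; same assumptions as the sketch above):

    import numpy as np

    import scico.numpy as snp
    from scico import functional

    f = functional.L2Norm()
    h = 5.0 * (2.0 * f)  # one ScaledFunctional with scale 10.0, not a nested chain
    assert isinstance(h, functional.ScaledFunctional)
    assert h.scale == 10.0

    # the prox of the aggregated scaling is delegated to the base functional
    v = snp.array([3.0, -4.0])
    np.testing.assert_allclose(h.prox(v, 1.0), f.prox(v, 10.0))
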
From 124b39197be2e8866fe4d600e036abf06843ca9e Mon Sep 17 00:00:00 2001
From: Brendt Wohlberg
Date: Wed, 15 Nov 2023 10:35:00 -0700
Subject: [PATCH 7/7] Address review comment

---
 scico/functional/_functional.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/scico/functional/_functional.py b/scico/functional/_functional.py
index 60b22064e..256791a73 100644
--- a/scico/functional/_functional.py
+++ b/scico/functional/_functional.py
@@ -14,8 +14,6 @@
 
 from typing import List, Optional, Union
 
-import jax
-
 import scico
 from scico import numpy as snp
 from scico.numpy import Array, BlockArray
@@ -44,7 +42,7 @@ def __repr__(self):
         return f"""{type(self)} (has_eval = {self.has_eval}, has_prox = {self.has_prox})"""
 
     def __mul__(self, other: Union[float, int]) -> ScaledFunctional:
-        if snp.isscalar(other) or isinstance(other, jax.core.Tracer):
+        if snp.util.is_scalar_equiv(other):
             return ScaledFunctional(self, other)
         return NotImplemented
 
@@ -146,7 +144,7 @@ def __call__(self, x: Union[Array, BlockArray]) -> float:
         return self.scale * self.functional(x)
 
     def __mul__(self, other: Union[float, int]) -> ScaledFunctional:
-        if snp.isscalar(other) or isinstance(other, jax.core.Tracer):
+        if snp.util.is_scalar_equiv(other):
             return ScaledFunctional(self.functional, other * self.scale)
         return NotImplemented