Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Resolve #468 #470

Merged
merged 8 commits into from
Nov 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@ SCICO Release Notes
Version 0.0.5 (unreleased)
----------------------------

• New functional ``functional.AnisotropicTVNorm`` with proximal operator
approximation.
• New functionals ``functional.AnisotropicTVNorm`` and
``functional.ProximalAverage`` with proximal operator approximations.
• New integrated Radon/X-ray transform ``linop.XRayTransform``.
• Rename modules ``radon_astra`` and ``radon_svmbir`` to ``xray.astra`` and
``xray.svmbir`` respectively, and rename ``TomographicProjector`` classes
to ``XRayTransform``.
• Rename ``AbelProjector`` to ``AbelTransform``.
• Rename ``solver.ATADSolver`` to ``solver.MatrixATADSolver``.
• Support ``jaxlib`` and ``jax`` versions 0.4.3 to 0.4.19.
• Support ``jaxlib`` and ``jax`` versions 0.4.3 to 0.4.20.



Expand Down
88 changes: 49 additions & 39 deletions scico/functional/_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,12 @@

"""Functional base class."""

from typing import List, Optional, Union

import jax
# Needed to annotate a class method that returns the encapsulating class;
# see https://www.python.org/dev/peps/pep-0563/
from __future__ import annotations

from typing import List, Optional, Union

import scico
from scico import numpy as snp
Expand Down Expand Up @@ -38,15 +41,15 @@ def __init__(self):
def __repr__(self):
return f"""{type(self)} (has_eval = {self.has_eval}, has_prox = {self.has_prox})"""

def __mul__(self, other):
if snp.isscalar(other) or isinstance(other, jax.core.Tracer):
def __mul__(self, other: Union[float, int]) -> ScaledFunctional:
    """Return a :class:`ScaledFunctional` scaling this functional by ``other``."""
    # Reject non-scalar operands so Python can fall back to the
    # reflected operation on the other operand.
    if not snp.util.is_scalar_equiv(other):
        return NotImplemented
    return ScaledFunctional(self, other)

def __rmul__(self, other):
def __rmul__(self, other: Union[float, int]) -> ScaledFunctional:
return self.__mul__(other)

def __add__(self, other):
def __add__(self, other: Functional) -> FunctionalSum:
if isinstance(other, Functional):
return FunctionalSum(self, other)
return NotImplemented
Expand Down Expand Up @@ -122,46 +125,29 @@ def grad(self, x: Union[Array, BlockArray]):
return self._grad(x)


class FunctionalSum(Functional):
    r"""A sum of two functionals."""

    def __repr__(self):
        """Describe the sum via the types of its two terms."""
        t1, t2 = type(self.functional1), type(self.functional2)
        return f"Sum of functionals of types {t1} and {t2}"

    def __init__(self, functional1: Functional, functional2: Functional):
        """Construct the pointwise sum of ``functional1`` and ``functional2``."""
        self.functional1 = functional1
        self.functional2 = functional2
        # Evaluable only when both terms are evaluable; no general
        # closed-form prox exists for a sum of functionals.
        self.has_eval = functional1.has_eval and functional2.has_eval
        self.has_prox = False
        super().__init__()

    def __call__(self, x: Union[Array, BlockArray]) -> float:
        """Evaluate the sum of the two functionals at ``x``."""
        return self.functional1(x) + self.functional2(x)


class ScaledFunctional(Functional):
r"""A functional multiplied by a scalar."""

def __repr__(self):
return (
"Scaled functional of type " + str(type(self.functional)) + f" (scale = {self.scale})"
)

def __init__(self, functional: Functional, scale: float):
    """Wrap ``functional`` so that evaluation is multiplied by ``scale``."""
    # Capability flags mirror those of the wrapped functional: scaling
    # by a scalar preserves both evaluability and prox availability.
    self.has_eval = functional.has_eval
    self.has_prox = functional.has_prox
    self.scale = scale
    self.functional = functional
    super().__init__()

def __repr__(self):
return (
"Scaled functional of type " + str(type(self.functional)) + f" (scale = {self.scale})"
)

def __call__(self, x: Union[Array, BlockArray]) -> float:
return self.scale * self.functional(x)

def __mul__(self, other: Union[float, int]) -> ScaledFunctional:
    """Fold a further scalar factor into this scaled functional."""
    # Aggregate the scales into one wrapper rather than nesting
    # ScaledFunctional inside ScaledFunctional.
    if not snp.util.is_scalar_equiv(other):
        return NotImplemented
    combined = other * self.scale
    return ScaledFunctional(self.functional, combined)

def prox(
self, v: Union[Array, BlockArray], lam: float = 1.0, **kwargs
) -> Union[Array, BlockArray]:
Expand All @@ -173,17 +159,17 @@ def prox(
proximal operator of the unscaled functional with the proximal
operator scaling consisting of the product of the two scaling
factors, i.e., for functional :math:`f` and scaling factors
:math:`\alpha` and :math:`\beta`, the proximal operator with scaling
parameter :math:`\alpha` of scaled functional :math:`\beta f` is
the proximal operator with scaling parameter :math:`\alpha \beta`
of functional :math:`f`,
:math:`\alpha` and :math:`\beta`, the proximal operator with
scaling parameter :math:`\alpha` of scaled functional
:math:`\beta f` is the proximal operator with scaling parameter
:math:`\alpha \beta` of functional :math:`f`,

.. math::
\prox_{\alpha (\beta f)}(\mb{v}) =
\prox_{(\alpha \beta) f}(\mb{v}) \;.

"""
return self.functional.prox(v, lam * self.scale)
return self.functional.prox(v, lam * self.scale, **kwargs)


class SeparableFunctional(Functional):
Expand Down Expand Up @@ -245,13 +231,37 @@ def prox(self, v: BlockArray, lam: float = 1.0, **kwargs) -> BlockArray:

"""
if len(v.shape) == len(self.functional_list):
return snp.blockarray([fi.prox(vi, lam) for fi, vi in zip(self.functional_list, v)])
return snp.blockarray(
[fi.prox(vi, lam, **kwargs) for fi, vi in zip(self.functional_list, v)]
)
raise ValueError(
f"Number of blocks in v, {len(v.shape)}, and length of functional_list, "
f"{len(self.functional_list)}, do not match."
)


class FunctionalSum(Functional):
    r"""A sum of two functionals."""

    def __init__(self, functional1: Functional, functional2: Functional):
        """Represent the pointwise sum of ``functional1`` and ``functional2``."""
        self.functional1, self.functional2 = functional1, functional2
        # The sum is evaluable iff both terms are evaluable; there is no
        # general closed-form prox for a sum of functionals.
        self.has_eval = functional1.has_eval and functional2.has_eval
        self.has_prox = False
        super().__init__()

    def __repr__(self):
        """Describe the sum by the types of its two terms."""
        parts = (str(type(self.functional1)), str(type(self.functional2)))
        return "Sum of functionals of types " + parts[0] + " and " + parts[1]

    def __call__(self, x: Union[Array, BlockArray]) -> float:
        """Evaluate both terms at ``x`` and add the results."""
        total = self.functional1(x)
        return total + self.functional2(x)


class ZeroFunctional(Functional):
r"""Zero functional, :math:`f(\mb{x}) = 0 \in \mbb{R}` for any input."""

Expand Down
12 changes: 6 additions & 6 deletions scico/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import jax

import scico
import scico.numpy as snp
from scico import functional, linop, operator
from scico.numpy import Array, BlockArray
Expand Down Expand Up @@ -125,6 +126,7 @@ def prox(
@_loss_mul_div_wrapper
def __mul__(self, other):
    # Return a copy of this loss with its scale multiplied by ``other``;
    # the original loss is left unmodified.
    new_loss = copy(self)
    # Rebind the autograd gradient to the copy's own __call__ so that
    # grad reflects the copy's rescaled evaluation rather than sharing
    # the gradient function captured by the original instance.
    new_loss._grad = scico.grad(new_loss.__call__)
    new_loss.set_scale(self.scale * other)
    return new_loss

Expand All @@ -134,6 +136,7 @@ def __rmul__(self, other):
@_loss_mul_div_wrapper
def __truediv__(self, other):
    # Return a copy of this loss with its scale divided by ``other``;
    # the original loss is left unmodified.
    new_loss = copy(self)
    # Rebind the autograd gradient to the copy's own __call__ so that
    # grad reflects the copy's rescaled evaluation rather than sharing
    # the gradient function captured by the original instance.
    new_loss._grad = scico.grad(new_loss.__call__)
    new_loss.set_scale(self.scale / other)
    return new_loss

Expand Down Expand Up @@ -216,12 +219,9 @@ def prox(
ATWA = c * A.conj() * W * A # type: ignore
return lhs / (ATWA + 1.0)

# prox_{f}(v) = arg min 1/2 || v - x ||_2^2 + λ 𝛼 || A x - y ||^2_W
# x
# solution at:
#
# (I + λ 2𝛼 A^T W A) x = v + λ 2𝛼 A^T W y
#
# prox_f(v) = arg min 1/2 || v - x ||_2^2 + λ 𝛼 || A x - y ||^2_W
# x
# with solution: (I + λ 2𝛼 A^T W A) x = v + λ 2𝛼 A^T W y
W = self.W
A = self.A
𝛼 = self.scale
Expand Down
13 changes: 11 additions & 2 deletions scico/test/functional/test_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,16 @@ def test_squared_l2(self):
pf = prox_test(self.v, L_d, L_d.prox, 0.75)
pf = prox_test(self.v, L, L.prox, 0.75)

def test_squared_l2_grad(self):
    """Gradient of a scaled loss must scale with the loss."""
    base = loss.SquaredL2Loss(y=self.y)
    explicit = loss.SquaredL2Loss(y=self.y, scale=5e0)
    rescaled = 1e1 * base
    grad_base = base.grad(self.v)
    grad_explicit = explicit.grad(self.v)
    grad_rescaled = rescaled.grad(self.v)
    # Scaling by 1e1 via __mul__ must agree with an explicitly
    # constructed loss of the corresponding scale.
    np.testing.assert_allclose(1e1 * grad_base, grad_explicit)
    np.testing.assert_allclose(grad_explicit, grad_rescaled)

def test_weighted_squared_l2(self):
L = loss.SquaredL2Loss(y=self.y, A=self.Ao, W=self.W)
assert L.has_eval
Expand Down Expand Up @@ -119,7 +129,6 @@ def test_poisson(self):


class TestAbsLoss:

abs_loss = (
(loss.SquaredL2AbsLoss, snp.abs),
(loss.SquaredL2SquaredAbsLoss, lambda x: snp.abs(x) ** 2),
Expand Down Expand Up @@ -218,7 +227,7 @@ def test_cubic_root():
r = loss._dep_cubic_root(p, q)
err = snp.abs(r**3 + p * r + q)
assert err.max() < 2e-4
# Test
# Test loss of precision warning
p = snp.array(1e-4, dtype=snp.float32)
q = snp.array(1e1, dtype=snp.float32)
with pytest.warns(UserWarning):
Expand Down
12 changes: 12 additions & 0 deletions scico/test/functional/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,3 +130,15 @@ def test_l21norm(axis):
prxana = (l2ana - 1.0) / l2ana * x
prxnum = F.prox(x, 1.0)
np.testing.assert_allclose(prxana, prxnum, rtol=1e-5)


def test_scalar_aggregation():
    """Repeated scalar multiplication collapses into one scale factor."""
    base = functional.L2Norm()
    once = 2.0 * base
    twice = 5.0 * once
    # Each product is a single ScaledFunctional wrapping the original
    # norm, with the scalar factors multiplied together (no nesting).
    for scaled, expected_scale in ((once, 2.0), (twice, 10.0)):
        assert isinstance(scaled, functional.ScaledFunctional)
        assert isinstance(scaled.functional, functional.L2Norm)
        assert scaled.scale == expected_scale