Skip to content

Commit

Permalink
Avoid jax tracer errors
Browse files Browse the repository at this point in the history
  • Loading branch information
bwohlberg committed Feb 26, 2024
1 parent 37a2fef commit 59949bc
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 13 deletions.
4 changes: 2 additions & 2 deletions examples/scripts/denoise_approx_tv_admm.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@
Denoise with isotropic total variation, solved using an approximation of
the TV norm proximal operator.
"""
h = λ_iso * functional.IsotropicTVNorm()
h = λ_iso * functional.IsotropicTVNorm(circular=True, input_shape=y.shape)
solver = AcceleratedPGM(
f=f, g=h, L0=1e3, x0=y, maxiter=500, itstat_options={"display": True, "period": 50}
)
Expand All @@ -102,7 +102,7 @@
Denoise with anisotropic total variation, solved using an approximation
of the TV norm proximal operator.
"""
h = λ_aniso * functional.AnisotropicTVNorm()
h = λ_aniso * functional.AnisotropicTVNorm(circular=True, input_shape=y.shape)
solver = AcceleratedPGM(
f=f, g=h, L0=1e3, x0=y, maxiter=500, itstat_options={"display": True, "period": 50}
)
Expand Down
72 changes: 66 additions & 6 deletions scico/functional/_tvnorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
VerticalStack,
)
from scico.numpy import Array
from scico.typing import DType, Shape

from ._functional import Functional
from ._norm import L1Norm, L21Norm
Expand All @@ -35,14 +36,36 @@ class TVNorm(Functional):
has_eval = True
has_prox = True

def __init__(self, norm: Functional, circular: bool = True, ndims: Optional[int] = None):
def __init__(
self,
norm: Functional,
circular: bool = True,
ndims: Optional[int] = None,
input_shape: Optional[Shape] = None,
input_dtype: DType = snp.float32,
):
"""
While initializers for :class:`.Functional` objects typically do
not take `input_shape` and `input_dtype` parameters, they are
included here because methods :meth:`__call__` and :meth:`prox`
require instantiation of some :class:`.LinearOperator` objects,
which do take these parameters. If these parameters are not
provided on intialization of a :class:`TVNorm` object, then
creation of the required :class:`.LinearOperator` objects is
deferred until these methods are called, which can result in
`JAX tracer <https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html#jit-mechanics-tracing-and-static-variables>`__
errors when they are components of a jitted function.
Args:
norm: Norm functional from which the TV norm is composed.
circular: Flag indicating use of circular boundary conditions.
ndims: Number of (trailing) dimensions of the input over
which to apply the finite difference operator. If
``None``, differences are evaluated along all axes.
input_shape: Shape of input arrays of :meth:`__call__` and
:meth:`prox`.
input_dtype: `dtype` of input arrays of :meth:`__call__` and
:meth:`prox`.
"""
self.norm = norm
self.circular = circular
Expand All @@ -52,6 +75,11 @@ def __init__(self, norm: Functional, circular: bool = True, ndims: Optional[int]
self.G: Optional[LinearOperator] = None
self.WP: Optional[LinearOperator] = None

if input_shape is not None:
if ndims is None:
ndims = len(input_shape)
self.WP, self.CWT = self._prox_operators(ndims, input_shape, input_dtype)

def __call__(self, x: Array) -> float:
r"""Compute the TV norm of an array.
Expand Down Expand Up @@ -171,7 +199,7 @@ def prox(self, v: Array, lam: float = 1.0, **kwargs) -> Array:
)
# Apply shrinkage to highpass component of shift-invariant Haar transform
# of padded input (or to non-boundary region thereof for non-circular
# boundary conditions)
# boundary conditions).
WPv = self.WP(v)
WPv = WPv.at[slce].set(self.norm.prox(WPv[slce], snp.sqrt(2) * K * lam))
u = (1.0 / K) * self.CWT(WPv)
Expand Down Expand Up @@ -210,15 +238,31 @@ class AnisotropicTVNorm(TVNorm):
in the `rho_list` algorithm parameter.
"""

def __init__(self, circular: bool = False, ndims: Optional[int] = None):
def __init__(
self,
circular: bool = False,
ndims: Optional[int] = None,
input_shape: Optional[Shape] = None,
input_dtype: DType = snp.float32,
):
"""
Args:
circular: Flag indicating use of circular boundary conditions.
ndims: Number of (trailing) dimensions of the input over
which to apply the finite difference operator. If
``None``, differences are evaluated along all axes.
input_shape: Shape of input arrays of :meth:`~.TVNorm.__call__` and
:meth:`~.TVNorm.prox`.
input_dtype: `dtype` of input arrays of :meth:`~.TVNorm.__call__` and
:meth:`~.TVNorm.prox`.
"""
super().__init__(L1Norm(), circular=circular, ndims=ndims)
super().__init__(
L1Norm(),
circular=circular,
ndims=ndims,
input_shape=input_shape,
input_dtype=input_dtype,
)


class IsotropicTVNorm(TVNorm):
Expand Down Expand Up @@ -252,12 +296,28 @@ class IsotropicTVNorm(TVNorm):
in the `rho_list` algorithm parameter.
"""

def __init__(self, circular: bool = False, ndims: Optional[int] = None):
def __init__(
self,
circular: bool = False,
ndims: Optional[int] = None,
input_shape: Optional[Shape] = None,
input_dtype: DType = snp.float32,
):
r"""
Args:
circular: Flag indicating use of circular boundary conditions.
ndims: Number of (trailing) dimensions of the input over
which to apply the finite difference operator. If
``None``, differences are evaluated along all axes.
input_shape: Shape of input arrays of :meth:`~.TVNorm.__call__` and
:meth:`~.TVNorm.prox`.
input_dtype: `dtype` of input arrays of :meth:`~.TVNorm.__call__` and
:meth:`~.TVNorm.prox`.
"""
super().__init__(L21Norm(), circular=circular, ndims=ndims)
super().__init__(
L21Norm(),
circular=circular,
ndims=ndims,
input_shape=input_shape,
input_dtype=input_dtype,
)
10 changes: 5 additions & 5 deletions scico/test/functional/test_tvnorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def test_aniso_1d(circular):
)
x_tvdn = solver.solve()

h = λ * functional.AnisotropicTVNorm(circular=circular)
h = λ * functional.AnisotropicTVNorm(circular=circular, input_shape=y.shape)
solver = AcceleratedPGM(f=f, g=h, L0=5e2, x0=y, maxiter=100)
x_approx = solver.solve()

Expand Down Expand Up @@ -88,9 +88,9 @@ def test_2d(self, tvtype, circular):
x_tvdn = solver.solve()

if tvtype == "aniso":
h = λ * functional.AnisotropicTVNorm(circular=circular)
h = λ * functional.AnisotropicTVNorm(circular=circular, input_shape=y.shape)
else:
h = λ * functional.IsotropicTVNorm(circular=circular)
h = λ * functional.IsotropicTVNorm(circular=circular, input_shape=y.shape)

solver = AcceleratedPGM(
f=f,
Expand Down Expand Up @@ -148,9 +148,9 @@ def test_3d(self, tvtype, circular):
x_tvdn = solver.solve()

if tvtype == "aniso":
h = λ * functional.AnisotropicTVNorm(circular=circular, ndims=2)
h = λ * functional.AnisotropicTVNorm(circular=circular, ndims=2, input_shape=y.shape)
else:
h = λ * functional.IsotropicTVNorm(circular=circular, ndims=2)
h = λ * functional.IsotropicTVNorm(circular=circular, ndims=2, input_shape=y.shape)

solver = AcceleratedPGM(
f=f,
Expand Down

0 comments on commit 59949bc

Please sign in to comment.