Skip to content

Commit

Permalink
Clean up initialization
Browse files Browse the repository at this point in the history
  • Loading branch information
bwohlberg committed Feb 28, 2024
1 parent 31e35eb commit 3b92eef
Showing 1 changed file with 16 additions and 10 deletions.
26 changes: 16 additions & 10 deletions scico/functional/_tvnorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,24 @@ def __init__(
if input_shape is not None:
if ndims is None:
ndims = len(input_shape)
self.G = self._call_operator(ndims, input_shape, input_dtype)
self.WP, self.CWT = self._prox_operators(ndims, input_shape, input_dtype)

def _call_operator(self, ndims: int, input_shape: Shape, input_dtype: DType) -> LinearOperator:
"""Construct operator required by __call__ method."""
axes = tuple(range(len(input_shape) - ndims, len(input_shape)))
G = FiniteDifference(
input_shape,
input_dtype=input_dtype,
axes=axes,
circular=self.circular,
append=None if self.circular else 0,
jit=True,
)
return G

def __call__(self, x: Array) -> float:
r"""Compute the TV norm of an array.
"""Compute the TV norm of an array.
Args:
x: Array for which the TV norm should be computed.
Expand All @@ -94,15 +108,7 @@ def __call__(self, x: Array) -> float:
ndims = x.ndim
else:
ndims = self.ndims
axes = tuple(range(x.ndim - ndims, x.ndim))
self.G = FiniteDifference(
x.shape,
input_dtype=x.dtype,
axes=axes,
circular=self.circular,
append=None if self.circular else 0,
jit=True,
)
self.G = self._call_operator(ndims, x.shape, x.dtype)
return self.norm(self.G @ x)

@staticmethod
Expand Down

0 comments on commit 3b92eef

Please sign in to comment.