diff --git a/scico/functional/_tvnorm.py b/scico/functional/_tvnorm.py index 930b71e31..386843b60 100644 --- a/scico/functional/_tvnorm.py +++ b/scico/functional/_tvnorm.py @@ -78,10 +78,24 @@ def __init__( if input_shape is not None: if ndims is None: ndims = len(input_shape) + self.G = self._call_operator(ndims, input_shape, input_dtype) self.WP, self.CWT = self._prox_operators(ndims, input_shape, input_dtype) + def _call_operator(self, ndims: int, input_shape: Shape, input_dtype: DType) -> LinearOperator: + """Construct operator required by __call__ method.""" + axes = tuple(range(len(input_shape) - ndims, len(input_shape))) + G = FiniteDifference( + input_shape, + input_dtype=input_dtype, + axes=axes, + circular=self.circular, + append=None if self.circular else 0, + jit=True, + ) + return G + def __call__(self, x: Array) -> float: - r"""Compute the TV norm of an array. + """Compute the TV norm of an array. Args: x: Array for which the TV norm should be computed. @@ -94,15 +108,7 @@ def __call__(self, x: Array) -> float: ndims = x.ndim else: ndims = self.ndims - axes = tuple(range(x.ndim - ndims, x.ndim)) - self.G = FiniteDifference( - x.shape, - input_dtype=x.dtype, - axes=axes, - circular=self.circular, - append=None if self.circular else 0, - jit=True, - ) + self.G = self._call_operator(ndims, x.shape, x.dtype) return self.norm(self.G @ x) @staticmethod