From 3b92eefb2bafe31a41d620f11fcf4c6060c1ec0e Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Wed, 28 Feb 2024 14:26:50 -0700 Subject: [PATCH] Clean up initialization --- scico/functional/_tvnorm.py | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/scico/functional/_tvnorm.py b/scico/functional/_tvnorm.py index 930b71e31..386843b60 100644 --- a/scico/functional/_tvnorm.py +++ b/scico/functional/_tvnorm.py @@ -78,10 +78,24 @@ def __init__( if input_shape is not None: if ndims is None: ndims = len(input_shape) + self.G = self._call_operator(ndims, input_shape, input_dtype) self.WP, self.CWT = self._prox_operators(ndims, input_shape, input_dtype) + def _call_operator(self, ndims: int, input_shape: Shape, input_dtype: DType) -> LinearOperator: + """Construct operator required by __call__ method.""" + axes = tuple(range(len(input_shape) - ndims, len(input_shape))) + G = FiniteDifference( + input_shape, + input_dtype=input_dtype, + axes=axes, + circular=self.circular, + append=None if self.circular else 0, + jit=True, + ) + return G + def __call__(self, x: Array) -> float: - r"""Compute the TV norm of an array. + """Compute the TV norm of an array. Args: x: Array for which the TV norm should be computed. @@ -94,15 +108,7 @@ def __call__(self, x: Array) -> float: ndims = x.ndim else: ndims = self.ndims - axes = tuple(range(x.ndim - ndims, x.ndim)) - self.G = FiniteDifference( - x.shape, - input_dtype=x.dtype, - axes=axes, - circular=self.circular, - append=None if self.circular else 0, - jit=True, - ) + self.G = self._call_operator(ndims, x.shape, x.dtype) return self.norm(self.G @ x) @staticmethod