diff --git a/scico/functional/_tvnorm.py b/scico/functional/_tvnorm.py index b48bb7dec..930b71e31 100644 --- a/scico/functional/_tvnorm.py +++ b/scico/functional/_tvnorm.py @@ -123,7 +123,7 @@ def _center(idx: int, ndims: int) -> Tuple: """ return (0,) * idx + (1,) + (0,) * (ndims - idx - 1) - def _haar_operator(self, ndims, input_shape, input_dtype): + def _haar_operator(self, ndims: int, input_shape: Shape, input_dtype: DType) -> LinearOperator: """Construct single-level shift-invariant Haar transform.""" h0 = self.h0.astype(input_dtype) h1 = self.h1.astype(input_dtype) @@ -142,7 +142,9 @@ def _haar_operator(self, ndims, input_shape, input_dtype): # single-level shift-invariant Haar transform return VerticalStack([L, H], jit=True) - def _prox_operators(self, ndims, input_shape, input_dtype): + def _prox_operators( + self, ndims: int, input_shape: Shape, input_dtype: DType + ) -> Tuple[LinearOperator, LinearOperator]: """Construct operators required by prox method.""" w_input_shape = ( # circular boundary: shape of input array @@ -200,7 +202,7 @@ def prox(self, v: Array, lam: float = 1.0, **kwargs) -> Array: # Apply shrinkage to highpass component of shift-invariant Haar transform # of padded input (or to non-boundary region thereof for non-circular # boundary conditions). - WPv = self.WP(v) + WPv: Array = self.WP(v) WPv = WPv.at[slce].set(self.norm.prox(WPv[slce], snp.sqrt(2) * K * lam)) u = (1.0 / K) * self.CWT(WPv)