Skip to content

Commit

Permalink
Resolve typing errors
Browse files Browse the repository at this point in the history
  • Loading branch information
bwohlberg committed Feb 28, 2024
1 parent 59949bc commit 697872e
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions scico/functional/_tvnorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def _center(idx: int, ndims: int) -> Tuple:
"""
return (0,) * idx + (1,) + (0,) * (ndims - idx - 1)

def _haar_operator(self, ndims, input_shape, input_dtype):
def _haar_operator(self, ndims: int, input_shape: Shape, input_dtype: DType) -> LinearOperator:
"""Construct single-level shift-invariant Haar transform."""
h0 = self.h0.astype(input_dtype)
h1 = self.h1.astype(input_dtype)
Expand All @@ -142,7 +142,9 @@ def _haar_operator(self, ndims, input_shape, input_dtype):
# single-level shift-invariant Haar transform
return VerticalStack([L, H], jit=True)

def _prox_operators(self, ndims, input_shape, input_dtype):
def _prox_operators(
self, ndims: int, input_shape: Shape, input_dtype: DType
) -> Tuple[LinearOperator, LinearOperator]:
"""Construct operators required by prox method."""
w_input_shape = (
# circular boundary: shape of input array
Expand Down Expand Up @@ -200,7 +202,7 @@ def prox(self, v: Array, lam: float = 1.0, **kwargs) -> Array:
# Apply shrinkage to highpass component of shift-invariant Haar transform
# of padded input (or to non-boundary region thereof for non-circular
# boundary conditions).
WPv = self.WP(v)
WPv: Array = self.WP(v)
WPv = WPv.at[slce].set(self.norm.prox(WPv[slce], snp.sqrt(2) * K * lam))
u = (1.0 / K) * self.CWT(WPv)

Expand Down

0 comments on commit 697872e

Please sign in to comment.