Skip to content

Commit

Permalink
Add some error checking, ensure ValueError raised as expected by test…
Browse files Browse the repository at this point in the history
…_misc.py for functionals
  • Loading branch information
bwohlberg committed Mar 4, 2024
1 parent 97c7693 commit e59701d
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions scico/functional/_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,9 +409,7 @@ def __init__(self, delta: float = 1.0, separable: bool = True):

def _call_sep(self, x: Union[Array, BlockArray]) -> float:
xabs = snp.abs(x)
hx = snp.where(
xabs <= self.delta, 0.5 * xabs**2, self.delta * (xabs - (self.delta / 2.0))
)
hx = snp.where(xabs <= self.delta, 0.5 * xabs**2, self.delta * (xabs - (self.delta / 2.0)))
return snp.sum(hx)

def _call_nonsep(self, x: Union[Array, BlockArray]) -> float:
Expand Down Expand Up @@ -481,6 +479,8 @@ class NuclearNorm(Functional):
has_prox = True

def __call__(self, x: Union[Array, BlockArray]) -> float:
if x.ndim != 2:
raise ValueError("Input array must be two dimensional.")
return snp.sum(snp.linalg.svd(x, full_matrices=False, compute_uv=False))

def prox(
Expand All @@ -492,12 +492,13 @@ def prox(
:cite:`cai-2010-singular`.
Args:
v: Input array :math:`\mb{v}`.
v: Input array :math:`\mb{v}`. Required to be two-dimensional.
lam: Proximal parameter :math:`\lambda`.
kwargs: Additional arguments that may be used by derived
classes.
"""

if v.ndim != 2:
raise ValueError("Input array must be two dimensional.")
svdU, svdS, svdV = snp.linalg.svd(v, full_matrices=False)
svdS = snp.maximum(0, svdS - lam)
return svdU @ snp.diag(svdS) @ svdV

0 comments on commit e59701d

Please sign in to comment.