Skip to content

Commit

Permalink
Fix behaviour for non-default ndims parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
bwohlberg committed Feb 26, 2024
1 parent 203b3c9 commit 5ab605f
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 7 deletions.
23 changes: 16 additions & 7 deletions scico/functional/_tvnorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def __call__(self, x: Array) -> float:
ndims = x.ndim
else:
ndims = self.ndims
axes = tuple(range(ndims))
axes = tuple(range(x.ndim - ndims, x.ndim))
self.G = FiniteDifference(
x.shape,
input_dtype=x.dtype,
Expand Down Expand Up @@ -133,13 +133,18 @@ def prox(self, v: Array, lam: float = 1.0, **kwargs) -> Array:
ndims = self.ndims
K = 2 * ndims

w_input_shape = v.shape if self.circular else tuple([n + 1 for n in v.shape])
w_input_shape = (
v.shape
if self.circular
else v.shape[0 : (v.ndim - ndims)] + tuple([n + 1 for n in v.shape[-ndims:]])
)
if self.W is None or self.W.shape[1] != w_input_shape:
self.W = self._haar_operator(ndims, w_input_shape, v.dtype)
if not self.circular:
P = Pad(v.shape, pad_width=(((0, 1),) * ndims), mode="edge", jit=True)
pad_width = ((0, 0),) * (v.ndim - ndims) + ((0, 1),) * ndims
P = Pad(v.shape, pad_width=pad_width, mode="edge", jit=True)
self.WP = self.W @ P
C = Crop(crop_width=(((0, 1),) * ndims), input_shape=w_input_shape, jit=True)
C = Crop(crop_width=pad_width, input_shape=w_input_shape, jit=True)
self.CWT = C @ self.W.T

if self.circular:
Expand All @@ -152,9 +157,13 @@ def prox(self, v: Array, lam: float = 1.0, **kwargs) -> Array:
# Haar transform of padded input
WPv = self.WP(v)
slce = (
1,
snp.s_[:],
) + (snp.s_[:-1],) * ndims
(
1,
snp.s_[:],
)
+ (snp.s_[:],) * (v.ndim - ndims)
+ (snp.s_[:-1],) * ndims
)
WPv = WPv.at[slce].set(self.norm.prox(WPv[slce], snp.sqrt(2) * K * lam))
u = (1.0 / K) * self.CWT(WPv)

Expand Down
60 changes: 60 additions & 0 deletions scico/test/functional/test_tvnorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,63 @@ def test_2d(self, tvtype, circular):

assert metric.snr(x_tvdn, x_aprx) > 50
assert metric.rel_res(g(C(x_tvdn)), h(x_tvdn)) < 1e-6


class Test3D:
def setup_method(self):
N = 32
x2d = create_circular_phantom(
(N, N), [0.6 * N, 0.4 * N, 0.2 * N, 0.1 * N], [0.25, 1, 0, 0.5]
)
gr, gc = np.ogrid[0:N, 0:N]
x2d += (gr + gc) / (4 * N)
x_gt = np.stack((0.9 * x2d, np.zeros(x2d.shape), 1.1 * x2d))
σ = 0.02
noise, key = scico.random.randn(x_gt.shape, seed=0)
y = x_gt + σ * noise
self.x_gt = x_gt
self.y = y

@pytest.mark.parametrize("circular", [False])
@pytest.mark.parametrize("tvtype", ["iso"])
def test_3d(self, tvtype, circular):
x_gt = self.x_gt
y = self.y

λ = 5e-2
f = loss.SquaredL2Loss(y=y)
if tvtype == "aniso":
g = λ * functional.L1Norm()
else:
g = λ * functional.L21Norm()
C = linop.FiniteDifference(
input_shape=x_gt.shape, axes=(1, 2), circular=circular, append=None if circular else 0
)

solver = ADMM(
f=f,
g_list=[g],
C_list=[C],
rho_list=[5e0],
x0=y,
maxiter=150,
subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-4, "maxiter": 25}),
)
x_tvdn = solver.solve()

if tvtype == "aniso":
h = λ * functional.AnisotropicTVNorm(circular=circular, ndims=2)
else:
h = λ * functional.IsotropicTVNorm(circular=circular, ndims=2)

solver = AcceleratedPGM(
f=f,
g=h,
L0=1e3,
x0=y,
maxiter=400,
)
x_aprx = solver.solve()

assert metric.snr(x_tvdn, x_aprx) > 50
assert metric.rel_res(g(C(x_tvdn)), h(x_tvdn)) < 1e-6

0 comments on commit 5ab605f

Please sign in to comment.