From e1294682a7f5e0d749f8a89c99b98df6565eec50 Mon Sep 17 00:00:00 2001 From: Mike McCann <57153404+Michael-T-McCann@users.noreply.github.com> Date: Tue, 9 May 2023 09:35:09 -0600 Subject: [PATCH] Make adjoint and gram get auto set separately (#414) * Add test * Have adjoint and gram get set separately * Bump test threshold to avoid CI test failure on OSX --------- Co-authored-by: Brendt Wohlberg --- scico/linop/_linop.py | 14 ++++++++++---- scico/linop/radon_astra.py | 4 ++-- scico/test/linop/test_radon_astra.py | 9 +++++++++ scico/test/linop/test_radon_svmbir.py | 2 +- 4 files changed, 22 insertions(+), 7 deletions(-) diff --git a/scico/linop/_linop.py b/scico/linop/_linop.py index 3522ba628..9b3f2676e 100644 --- a/scico/linop/_linop.py +++ b/scico/linop/_linop.py @@ -154,18 +154,24 @@ def __init__( self.jit() def _set_adjoint(self): - """Automatically create adjoint and gram methods.""" + """Automatically create adjoint method.""" adj_fun = linear_adjoint(self.__call__, snp.zeros(self.input_shape, dtype=self.input_dtype)) self._adj = lambda x: adj_fun(x)[0] + + def _set_gram(self): + """Automatically create gram method.""" self._gram = lambda x: self.adj(self(x)) def jit(self): """Replace the private functions :meth:`._eval`, :meth:`_adj`, :meth:`._gram` with jitted versions. """ - if (self._adj is None) or (self._gram is None): + if self._adj is None: self._set_adjoint() + if self._gram is None: + self._set_gram() + self._eval = jax.jit(self._eval) self._adj = jax.jit(self._adj) self._gram = jax.jit(self._gram) @@ -374,7 +380,7 @@ def gram_op(self) -> LinearOperator: `G(x) = A.adj(A(x)))`. """ if self._gram is None: - self._set_adjoint() + self._set_gram() return LinearOperator( input_shape=self.input_shape, @@ -401,7 +407,7 @@ def gram( Result of `A.adj(A(x))`. """ if self._gram is None: - self._set_adjoint() + self._set_gram() assert self._gram is not None return self._gram(x) diff --git a/scico/linop/radon_astra.py b/scico/linop/radon_astra.py index 355f8ec3d..9238d56f3 100644 --- a/scico/linop/radon_astra.py +++ b/scico/linop/radon_astra.py @@ -86,10 +86,10 @@ def __init__( # Set up all the ASTRA config self.detector_spacing: float = detector_spacing self.det_count: int = det_count - self.angles: np.ndarray = angles + self.angles: np.ndarray = np.array(angles) self.proj_geom: dict = astra.create_proj_geom( - "parallel", detector_spacing, det_count, angles + "parallel", detector_spacing, det_count, self.angles ) self.proj_id: int self.input_shape: tuple = input_shape diff --git a/scico/test/linop/test_radon_astra.py b/scico/test/linop/test_radon_astra.py index 491c31ce7..3f5b20d44 100644 --- a/scico/test/linop/test_radon_astra.py +++ b/scico/test/linop/test_radon_astra.py @@ -5,6 +5,8 @@ import pytest import scico +import scico.numpy as snp +from scico.linop import DiagonalStack from scico.test.linop.test_linop import adjoint_test from scico.test.linop.test_radon_svmbir import make_im @@ -118,3 +120,10 @@ def test_adjoint_typical_input(testobj): x = make_im(A.input_shape[0], A.input_shape[1], is_3d=False) adjoint_test(A, x=x, rtol=get_tol()) + + +def test_jit_in_DiagonalStack(): + """See https://github.com/lanl/scico/issues/331""" + N = 10 + H = DiagonalStack([TomographicProjector((N, N), 1.0, N, snp.linspace(0, snp.pi, N))]) + H.T @ snp.zeros(H.output_shape, dtype=snp.float32) diff --git a/scico/test/linop/test_radon_svmbir.py b/scico/test/linop/test_radon_svmbir.py index 777b8725a..d9927af2c 100644 --- a/scico/test/linop/test_radon_svmbir.py +++ b/scico/test/linop/test_radon_svmbir.py @@ -374,4 +374,4 @@ def test_approx_prox( xprox_approx = snp.array(f_approx.prox(v, lam=λ, v0=xprox)) - assert snp.linalg.norm(xprox - xprox_approx) / snp.linalg.norm(xprox) < 4e-6 + assert snp.linalg.norm(xprox - xprox_approx) / snp.linalg.norm(xprox) < 5e-5