From 12b0049dcb020c44d67be759182c2c016469dc9d Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Thu, 18 Jan 2024 13:52:13 -0700 Subject: [PATCH] Add a test --- scico/test/linop/xray/test_astra.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/scico/test/linop/xray/test_astra.py b/scico/test/linop/xray/test_astra.py index bbacd7292..aa2157948 100644 --- a/scico/test/linop/xray/test_astra.py +++ b/scico/test/linop/xray/test_astra.py @@ -11,7 +11,7 @@ from scico.test.linop.xray.test_svmbir import make_im try: - from scico.linop.xray.astra import XRayTransform2D, XRayTransform3D + from scico.linop.xray.astra import XRayTransform2D, XRayTransform3D, angle_to_vector except ModuleNotFoundError as e: if e.name == "astra": pytest.skip("astra not installed", allow_module_level=True) @@ -148,3 +148,17 @@ def test_3D_on_GPU(): assert A.num_dims == 3 y = A @ x ATy = A.T @ y + + +@pytest.mark.skipif(jax.devices()[0].platform != "gpu", reason="GPU required for test") +def test_3D_api_equiv(): + x = np.random.randn(4, 5, 6).astype(np.float32) + det_count = [7, 8] + det_spacing = [1.0, 1.5] + angles = snp.linspace(0, snp.pi, 10) + A = XRayTransform3D(x.shape, det_count=det_count, det_spacing=det_spacing, angles=angles) + vectors = angle_to_vector(det_spacing, angles) + B = XRayTransform3D(x.shape, det_count=det_count, vectors=vectors) + ya = A @ x + yb = B @ x + np.testing.assert_allclose(ya, yb, rtol=get_tol())