Skip to content

Commit

Permalink
Add a test
Browse files Browse the repository at this point in the history
  • Loading branch information
bwohlberg committed Jan 18, 2024
1 parent 64ef13d commit 12b0049
Showing 1 changed file with 15 additions and 1 deletion.
16 changes: 15 additions & 1 deletion scico/test/linop/xray/test_astra.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from scico.test.linop.xray.test_svmbir import make_im

try:
from scico.linop.xray.astra import XRayTransform2D, XRayTransform3D
from scico.linop.xray.astra import XRayTransform2D, XRayTransform3D, angle_to_vector
except ModuleNotFoundError as e:
if e.name == "astra":
pytest.skip("astra not installed", allow_module_level=True)
Expand Down Expand Up @@ -148,3 +148,17 @@ def test_3D_on_GPU():
assert A.num_dims == 3
y = A @ x
ATy = A.T @ y


@pytest.mark.skipif(jax.devices()[0].platform != "gpu", reason="GPU required for test")
def test_3D_api_equiv():
x = np.random.randn(4, 5, 6).astype(np.float32)
det_count = [7, 8]
det_spacing = [1.0, 1.5]
angles = snp.linspace(0, snp.pi, 10)
A = XRayTransform3D(x.shape, det_count=det_count, det_spacing=det_spacing, angles=angles)
vectors = angle_to_vector(det_spacing, angles)
B = XRayTransform3D(x.shape, det_count=det_count, vectors=vectors)
ya = A @ x
yb = B @ x
np.testing.assert_allclose(ya, yb, rtol=get_tol())

0 comments on commit 12b0049

Please sign in to comment.