From c4ebb685efaec40ff176a1f5ecd0da7b2851f9b4 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Thu, 23 Jan 2025 10:36:49 -0700 Subject: [PATCH] Add a utility function for rotating CT volumes (#574) * Add utility function * Fix docs * Add test * Bump jax version upper bound * Bump max supported jaxlib/jax version * Update submodule * Bump test tolerances * Update copyright year * Looser test tolerances on GPU * Looser test tolerances required for GPU --------- Co-authored-by: Brendt Wohlberg --- CHANGES.rst | 2 +- LICENSE | 2 +- README.md | 2 +- docs/source/conf/10-project.py | 2 +- requirements.txt | 8 ++-- scico/linop/xray/__init__.py | 8 ++-- scico/linop/xray/_util.py | 58 +++++++++++++++++++++++++ scico/test/linop/test_linop.py | 7 ++- scico/test/linop/xray/test_astra.py | 6 +-- scico/test/linop/xray/test_xray_util.py | 12 +++++ 10 files changed, 87 insertions(+), 20 deletions(-) create mode 100644 scico/linop/xray/_util.py create mode 100644 scico/test/linop/xray/test_xray_util.py diff --git a/CHANGES.rst b/CHANGES.rst index 8f24033a0..f4682bd04 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -7,7 +7,7 @@ Version 0.0.7 (unreleased) ---------------------------- • New module ``scico.trace`` for tracing function/method calls. -• Support ``jaxlib`` and ``jax`` versions 0.4.13 to 0.4.37. +• Support ``jaxlib`` and ``jax`` versions 0.4.13 to 0.5.0. • Support ``flax`` versions 0.8.0 to 0.10.2. diff --git a/LICENSE b/LICENSE index a98a6e1f5..fd9f379c4 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ BSD 3-Clause License -Copyright (c) 2021-2024, Los Alamos National Laboratory +Copyright (c) 2021-2025, Los Alamos National Laboratory All rights reserved. Redistribution and use in source and binary forms, with or without diff --git a/README.md b/README.md index cd333e647..d21d2e671 100644 --- a/README.md +++ b/README.md @@ -61,7 +61,7 @@ License (see the `LICENSE` file for details). LANL open source approval reference C20091. -\(c\) 2020-2024. Triad National Security, LLC. All rights reserved. This +\(c\) 2020-2025. Triad National Security, LLC. All rights reserved. This program was produced under U.S. Government contract 89233218CNA000001 for Los Alamos National Laboratory (LANL), which is operated by Triad National Security, LLC for the U.S. Department of Energy/National diff --git a/docs/source/conf/10-project.py b/docs/source/conf/10-project.py index 1690f9dc6..98eca4b72 100644 --- a/docs/source/conf/10-project.py +++ b/docs/source/conf/10-project.py @@ -2,7 +2,7 @@ # General information about the project. project = "SCICO" -copyright = "2020-2024, SCICO Developers" +copyright = "2020-2025, SCICO Developers" # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the diff --git a/requirements.txt b/requirements.txt index a59cd60b5..a313ec8b3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,11 +1,11 @@ typing_extensions -numpy>=1.20.0 -scipy>=1.6.0 +numpy>=1.25.0 +scipy>=1.11.0 imageio>=2.17 tifffile matplotlib -jaxlib>=0.4.13,<=0.4.37 -jax>=0.4.13,<=0.4.37 +jaxlib>=0.4.13,<=0.5.0 +jax>=0.4.13,<=0.5.0 orbax-checkpoint>=0.5.0 flax>=0.8.0,<=0.10.2 pyabel>=0.9.0 diff --git a/scico/linop/xray/__init__.py b/scico/linop/xray/__init__.py index 75c66368b..b1288ce18 100644 --- a/scico/linop/xray/__init__.py +++ b/scico/linop/xray/__init__.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright (C) 2023-2024 by SCICO Developers +# Copyright (C) 2023-2025 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the @@ -66,12 +66,10 @@ import sys +from ._util import rotate_volume from ._xray import XRayTransform2D, XRayTransform3D -__all__ = [ - "XRayTransform2D", - "XRayTransform3D", -] +__all__ = ["XRayTransform2D", "XRayTransform3D", "rotate_volume"] # Imported items in __all__ appear to originate in top-level xray module diff --git a/scico/linop/xray/_util.py b/scico/linop/xray/_util.py new file mode 100644 index 000000000..08701a63f --- /dev/null +++ b/scico/linop/xray/_util.py @@ -0,0 +1,58 @@ +# -*- coding: utf-8 -*- +# Copyright (C) 2024-2025 by SCICO Developers +# All rights reserved. BSD 3-clause License. +# This file is part of the SCICO package. Details of the copyright and +# user license can be found in the 'LICENSE' file distributed with the +# package. + +"""Utilities for CT data.""" + +from typing import Optional + +import jax.numpy as jnp +from jax.scipy.ndimage import map_coordinates +from jax.scipy.spatial.transform import Rotation +from jax.typing import ArrayLike + + +def rotate_volume( + vol: ArrayLike, + rot: Rotation, + x: Optional[ArrayLike] = None, + y: Optional[ArrayLike] = None, + z: Optional[ArrayLike] = None, + center: Optional[ArrayLike] = None, +): + """Rotate a 3D array. + + Rotate a 3D array as specified by an instance of + :class:`~jax.scipy.spatial.transform.Rotation`. Any axis coordinates + that are not specified default to a range corresponding to the size + of the array on that axis, starting at zero. + + Args: + vol: Array to be rotated. + rot: Rotation specification. + x: Coordinates for :code:`x` axis (axis 0). + y: Coordinates for :code:`y` axis (axis 1). + z: Coordinates for :code:`z` axis (axis 2). + center: A 3-vector specifying the center of rotation. + Defaults to the center of the array. + + Returns: + Rotated array. + """ + shape = vol.shape + if x is None: + x = jnp.arange(shape[0]) + if y is None: + y = jnp.arange(shape[1]) + if z is None: + z = jnp.arange(shape[2]) + if center is None: + center = (jnp.array(shape, dtype=jnp.float32) - 1.0) / 2.0 + gx, gy, gz = jnp.meshgrid(x - center[0], y - center[1], z - center[2], indexing="ij") + crd = jnp.stack((gx.ravel(), gy.ravel(), gz.ravel())) + rot_crd = rot.as_matrix() @ crd + center[:, jnp.newaxis] # faster than rot.apply(crd.T) + rot_vol = map_coordinates(vol, rot_crd.reshape((3,) + shape), order=1) + return rot_vol diff --git a/scico/test/linop/test_linop.py b/scico/test/linop/test_linop.py index 7f97f9caa..462e677a2 100644 --- a/scico/test/linop/test_linop.py +++ b/scico/test/linop/test_linop.py @@ -109,8 +109,7 @@ def test_scalar_left(testobj, operator, scalar): assert isinstance(comp_op, linop.LinearOperator) # Ensure we don't get a Map assert comp_op.input_dtype == testobj.A.dtype np.testing.assert_allclose(comp_mat @ testobj.x, comp_op @ testobj.x, rtol=5e-5) - - np.testing.assert_allclose(comp_mat.conj().T @ testobj.y, comp_op.adj(testobj.y), rtol=5e-5) + np.testing.assert_allclose(comp_mat.conj().T @ testobj.y, comp_op.adj(testobj.y), rtol=1e-4) @pytest.mark.parametrize("operator", [op.mul, op.truediv]) @@ -218,7 +217,7 @@ def test_transpose_matvec(testobj): assert a.dtype == testobj.A.dtype assert b.dtype == testobj.A.dtype - np.testing.assert_allclose(a, comp_mat, rtol=5e-5) + np.testing.assert_allclose(a, comp_mat, rtol=1e-4) np.testing.assert_allclose(a, b, rtol=5e-5) @@ -264,7 +263,7 @@ def test_adjoint_matvec(testobj): assert a.dtype == testobj.A.dtype assert b.dtype == testobj.A.dtype assert c.dtype == testobj.A.dtype - np.testing.assert_allclose(a, comp_mat, rtol=5e-5) + np.testing.assert_allclose(a, comp_mat, rtol=1e-4) np.testing.assert_allclose(a, b, rtol=5e-5) np.testing.assert_allclose(a, c, rtol=5e-5) diff --git a/scico/test/linop/xray/test_astra.py b/scico/test/linop/xray/test_astra.py index f2cb156f6..d0eb39ef9 100644 --- a/scico/test/linop/xray/test_astra.py +++ b/scico/test/linop/xray/test_astra.py @@ -24,9 +24,9 @@ N = 128 -RTOL_CPU = 5e-5 -RTOL_GPU = 7e-2 -RTOL_GPU_RANDOM_INPUT = 1.0 +RTOL_CPU = 1e-4 +RTOL_GPU = 1e-1 +RTOL_GPU_RANDOM_INPUT = 2.0 def make_im(Nx, Ny, is_3d=True): diff --git a/scico/test/linop/xray/test_xray_util.py b/scico/test/linop/xray/test_xray_util.py new file mode 100644 index 000000000..dc1a447d7 --- /dev/null +++ b/scico/test/linop/xray/test_xray_util.py @@ -0,0 +1,12 @@ +import numpy as np + +from jax.scipy.spatial.transform import Rotation + +from scico.linop.xray import rotate_volume + + +def test_rotate_volume(): + vol = np.arange(27).reshape((3, 3, 3)) + rot = Rotation.from_euler("XY", [90, 90], degrees=True) + vol_rot = rotate_volume(vol, rot) + np.testing.assert_allclose(vol.transpose((1, 2, 0)), vol_rot, rtol=1e-7)