diff --git a/scico/linop/xray/_xray.py b/scico/linop/xray/_xray.py index 3ccd5d96..04bdabe8 100644 --- a/scico/linop/xray/_xray.py +++ b/scico/linop/xray/_xray.py @@ -18,6 +18,7 @@ import jax.numpy as jnp from jax.typing import ArrayLike +import scico.numpy as snp from scico.numpy.util import is_scalar_equiv from scico.typing import Shape from scipy.spatial.transform import Rotation @@ -115,17 +116,19 @@ def __init__( adj_fn=self.back_project, ) - def project(self, im): + def project(self, im: ArrayLike) -> snp.Array: """Compute X-ray projection.""" return XRayTransform2D._project(im, self.x0, self.dx, self.y0, self.ny, self.angles) - def back_project(self, y): + def back_project(self, y: ArrayLike) -> snp.Array: """Compute X-ray back projection""" return XRayTransform2D._back_project(y, self.x0, self.dx, self.nx, self.y0, self.angles) @staticmethod @partial(jax.jit, static_argnames=["ny"]) - def _project(im, x0, dx, y0, ny, angles): + def _project( + im: ArrayLike, x0: ArrayLike, dx: ArrayLike, y0: float, ny: int, angles: ArrayLike + ) -> snp.Array: r""" Args: im: Input array, (M, N). @@ -155,7 +158,9 @@ def _project(im, x0, dx, y0, ny, angles): @staticmethod @partial(jax.jit, static_argnames=["nx"]) - def _back_project(y, x0, dx, nx, y0, angles): + def _back_project( + y: ArrayLike, x0: ArrayLike, dx: ArrayLike, nx: Shape, y0: float, angles: ArrayLike + ) -> ArrayLike: r""" Args: y: Input projection, (num_angles, N). @@ -184,7 +189,9 @@ def _back_project(y, x0, dx, nx, y0, angles): @staticmethod @partial(jax.jit, static_argnames=["nx"]) @partial(jax.vmap, in_axes=(None, None, None, 0, None)) - def _calc_weights(x0, dx, nx, angle, y0): + def _calc_weights( + x0: ArrayLike, dx: ArrayLike, nx: Shape, angle: float, y0: float + ) -> snp.Array: """ Args: @@ -263,28 +270,27 @@ def __init__( det_shape: Shape of detector. """ - self.input_shape = input_shape + self.input_shape: Shape = input_shape self.matrices = matrices self.det_shape = det_shape self.output_shape = (len(matrices), *det_shape) - super().__init__( - input_shape=self.input_shape, + input_shape=input_shape, output_shape=self.output_shape, eval_fn=self.project, adj_fn=self.back_project, ) - def project(self, im): + def project(self, im: ArrayLike) -> snp.Array: """Compute X-ray projection.""" return XRayTransform3D._project(im, self.matrices, self.det_shape) - def back_project(self, proj): + def back_project(self, proj: ArrayLike) -> snp.Array: """Compute X-ray back projection""" return XRayTransform3D._back_project(proj, self.matrices, self.input_shape) @staticmethod - def _project(im: ArrayLike, matrices: ArrayLike, det_shape: Shape) -> ArrayLike: + def _project(im: ArrayLike, matrices: ArrayLike, det_shape: Shape) -> snp.Array: r""" Args: im: Input image. @@ -312,7 +318,7 @@ def _project(im: ArrayLike, matrices: ArrayLike, det_shape: Shape) -> ArrayLike: @partial(jax.jit, donate_argnames="proj") def _project_single( im: ArrayLike, matrix: ArrayLike, proj: ArrayLike, slice_offset: int = 0 - ) -> ArrayLike: + ) -> snp.Array: r""" Args: im: Input image. @@ -359,7 +365,7 @@ def _back_project(proj: ArrayLike, matrices: ArrayLike, input_shape: Shape) -> A @partial(jax.jit, donate_argnames="HTy") def _back_project_single( y: ArrayLike, matrix: ArrayLike, HTy: ArrayLike, slice_offset: int = 0 - ) -> ArrayLike: + ) -> snp.Array: ul_ind, ul_weight, ur_weight, ll_weight, lr_weight = XRayTransform3D._calc_weights( HTy.shape, matrix, y.shape, slice_offset ) @@ -370,7 +376,9 @@ def _back_project_single( return HTy @staticmethod - def _calc_weights(input_shape, matrix, output_shape, slice_offset: int = 0): + def _calc_weights( + input_shape: Shape, matrix: snp.Array, output_shape: Shape, slice_offset: int = 0 + ) -> snp.Array: # pixel (0, 0, 0) has its center at (0.5, 0.5, 0.5) x = jnp.mgrid[: input_shape[0], : input_shape[1], : input_shape[2]] + 0.5 # (3, ...) x = x.at[0].add(slice_offset) @@ -405,7 +413,7 @@ def matrices_from_euler_angles( degrees: bool = False, voxel_spacing: ArrayLike = None, det_spacing: ArrayLike = None, - ): + ) -> snp.Array: """ Create a set of projection matrices from Euler angles. The input voxels will undergo the specified rotation and then be @@ -450,6 +458,6 @@ def matrices_from_euler_angles( # add translation to line up the centers x0 = np.array(input_shape) / 2 t = -np.einsum("vmn,n->vm", matrices, x0) + np.array(output_shape) / 2 - matrices = np.concatenate((matrices, t[..., np.newaxis]), axis=2) + matrices = snp.concatenate((matrices, t[..., np.newaxis]), axis=2) return matrices diff --git a/scico/linop/xray/astra.py b/scico/linop/xray/astra.py index 18f89237..3f50df6f 100644 --- a/scico/linop/xray/astra.py +++ b/scico/linop/xray/astra.py @@ -62,6 +62,15 @@ def set_astra_gpu_index(idx: Union[int, Sequence[int]]): def _project_coords( x_volume: np.ndarray, vol_geom: VolumeGeometry, proj_geom: ProjectionGeometry ) -> np.ndarray: + """ + Transform volume (logical) coordinates into world coordinates based + on ASTRA geometry objects. + + Args: + x_volume: (..., 3) vector(s) of volume (AKA logical) coordinates + vol_geom: ASTRA volume geometry object. + proj_geom: ASTRA projection geometry object. + """ det_shape = (proj_geom["DetectorRowCount"], proj_geom["DetectorColCount"]) x_world = volume_coords_to_world_coords(x_volume, vol_geom=vol_geom) x_dets = [] @@ -110,7 +119,7 @@ def project_world_coordinates( return ind_ij -def volume_coords_to_world_coords(idx: np.ndarray, vol_geom: VolumeGeometry): +def volume_coords_to_world_coords(idx: np.ndarray, vol_geom: VolumeGeometry) -> np.ndarray: """Convert a volume coordinate into a world coordinate. Convert a volume coordinate into a world coordinate using ASTRA @@ -131,6 +140,7 @@ def volume_coords_to_world_coords(idx: np.ndarray, vol_geom: VolumeGeometry): def _volume_index_to_astra_world_2d(idx: np.ndarray, vol_geom: VolumeGeometry) -> np.ndarray: + """Convert a 2D volume coordinate into a 2D world coordinate.""" coord = idx[..., [2, 1]] # x:col, y:row, nx = np.array( # (x, y) order ( @@ -150,6 +160,7 @@ def _volume_index_to_astra_world_2d(idx: np.ndarray, vol_geom: VolumeGeometry) - def _volume_index_to_astra_world_3d(idx: np.ndarray, vol_geom: VolumeGeometry) -> np.ndarray: + """Convert a 3D volume coordinate into a 3D world coordinate.""" coord = idx[..., [2, 1, 0]] # x:col, y:row, z:slice nx = np.array( # (x, y, z) order (