Skip to content

Commit

Permalink
Add utility function
Browse files Browse the repository at this point in the history
  • Loading branch information
bwohlberg committed Jan 30, 2025
1 parent b8b8282 commit 14acf0c
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 1 deletion.
22 changes: 21 additions & 1 deletion scico/linop/xray/astra.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2020-2024 by SCICO Developers
# Copyright (C) 2020-2025 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
Expand All @@ -25,6 +25,8 @@
import jax
from jax.typing import ArrayLike

from scipy.spatial.transform import Rotation

try:
import astra
except ModuleNotFoundError as e:
Expand Down Expand Up @@ -729,6 +731,24 @@ def angle_to_vector(det_spacing: Tuple[float, float], angles: np.ndarray) -> np.
return vectors


def rotate_vectors(vectors: np.ndarray, rot: Rotation) -> np.ndarray:
"""Rotate geometry specification vectors.
Rotate ASTRA "parallel3d_vec" geometry specification vectors.
Args:
vectors: Array of geometry specification vectors.
rot: Rotation.
Returns:
Rotated geometry specification vectors.
"""
rot_vecs = vectors.copy()
for k in range(0, 12, 3):
rot_vecs[:, k : k + 3] = rot.apply(rot_vecs[:, k : k + 3])
return rot_vecs


def _ensure_writeable(x):
"""Ensure that `x.flags.writeable` is ``True``, copying if needed."""
if hasattr(x, "flags"): # x is a numpy array
Expand Down
10 changes: 10 additions & 0 deletions scico/test/linop/xray/test_astra.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@
import scico.numpy as snp
from scico.linop import DiagonalStack
from scico.test.linop.test_linop import adjoint_test
from scipy.spatial.transform import Rotation

try:
from scico.linop.xray.astra import (
XRayTransform2D,
XRayTransform3D,
_ensure_writeable,
angle_to_vector,
rotate_vectors,
)
except ModuleNotFoundError as e:
if e.name == "astra":
Expand Down Expand Up @@ -211,6 +213,14 @@ def test_angle_to_vector():
assert vectors.shape == (angles.size, 12)


def test_rotate_vectors():
v0 = angle_to_vector([1.0, 1.0], np.linspace(0, np.pi / 2, 4, endpoint=False))
v1 = angle_to_vector([1.0, 1.0], np.linspace(np.pi / 2, np.pi, 4, endpoint=False))
r = Rotation.from_euler("z", np.pi / 2)
v0r = rotate_vectors(v0, r)
np.testing.assert_allclose(v1, v0r, atol=1e-7)


## conversion functions
@pytest.fixture(scope="module")
def test_geometry():
Expand Down

0 comments on commit 14acf0c

Please sign in to comment.