From 92e3193624e6ee93fd7baaee08de4dfe42279179 Mon Sep 17 00:00:00 2001 From: Felix F Zimmermann Date: Tue, 17 Dec 2024 17:07:25 +0100 Subject: [PATCH] Move trajectory scaling into KTrajectory (#582) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Patrick Schuenke --- src/mrpro/data/KTrajectory.py | 45 +++++++++++++---- src/mrpro/data/KTrajectoryRawShape.py | 49 +++++++++++++++++++ .../traj_calculators/KTrajectoryPulseq.py | 28 ++++------- tests/data/_PulseqRadialTestSeq.py | 4 +- tests/data/test_traj_calculators.py | 8 +-- 5 files changed, 101 insertions(+), 33 deletions(-) diff --git a/src/mrpro/data/KTrajectory.py b/src/mrpro/data/KTrajectory.py index 8cbe207cc..f8bab4bae 100644 --- a/src/mrpro/data/KTrajectory.py +++ b/src/mrpro/data/KTrajectory.py @@ -1,6 +1,7 @@ """KTrajectory dataclass.""" from dataclasses import dataclass +from typing import Literal import numpy as np import torch @@ -8,6 +9,7 @@ from mrpro.data.enums import TrajType from mrpro.data.MoveDataMixin import MoveDataMixin +from mrpro.data.SpatialDimension import SpatialDimension from mrpro.utils import remove_repeat from mrpro.utils.summarize_tensorvalues import summarize_tensorvalues @@ -69,29 +71,52 @@ def from_tensor( cls, tensor: torch.Tensor, stack_dim: int = 0, - repeat_detection_tolerance: float | None = 1e-8, + axes_order: Literal['zxy', 'zyx', 'yxz', 'yzx', 'xyz', 'xzy'] = 'zyx', + repeat_detection_tolerance: float | None = 1e-6, grid_detection_tolerance: float = 1e-3, + scaling_matrix: SpatialDimension | None = None, ) -> Self: """Create a KTrajectory from a tensor representation of the trajectory. - Reduces repeated dimensions to singletons if repeat_detection_tolerance - is not set to None. - + Reduces repeated dimensions to singletons if repeat_detection_tolerance is not set to None. Parameters ---------- tensor The tensor representation of the trajectory. - This should be a 5-dim tensor, with (kz,ky,kx) stacked in this order along stack_dim + This should be a 5-dim tensor, with (kz, ky, kx) stacked in this order along `stack_dim`. stack_dim - The dimension in the tensor the directions have been stacked along. + The dimension in the tensor along which the directions are stacked. + axes_order + The order of the axes in the tensor. The MRpro convention is 'zyx'. repeat_detection_tolerance - detects if broadcasting can be used, i.e. if dimensions are repeated. - Set to None to disable. + Tolerance for detecting repeated dimensions (broadcasting). + If trajectory points differ by less than this value, they are considered identical. + Set to None to disable this feature. grid_detection_tolerance - tolerance to detect if trajectory points are on integer grid positions + Tolerance for detecting whether trajectory points align with integer grid positions. + This tolerance is applied after rescaling if `scaling_matrix` is provided. + scaling_matrix + If a scaling matrix is provided, the trajectory is rescaled to fit within + the dimensions of the matrix. If not provided, the trajectory remains unchanged. + """ - kz, ky, kx = torch.unbind(tensor, dim=stack_dim) + ks = tensor.unbind(dim=stack_dim) + kz, ky, kx = (ks[axes_order.index(axis)] for axis in 'zyx') + + def rescale(k: torch.Tensor, size: float) -> torch.Tensor: + max_abs_range = 2 * k.abs().max() + if size < 2 or max_abs_range < 1e-6: + # a single encoding point should be at zero + # avoid division by zero + return torch.zeros_like(k) + return k * (size / max_abs_range) + + if scaling_matrix is not None: + kz = rescale(kz, scaling_matrix.z) + ky = rescale(ky, scaling_matrix.y) + kx = rescale(kx, scaling_matrix.x) + return cls( kz, ky, diff --git a/src/mrpro/data/KTrajectoryRawShape.py b/src/mrpro/data/KTrajectoryRawShape.py index 3730e6669..e75ad8c16 100644 --- a/src/mrpro/data/KTrajectoryRawShape.py +++ b/src/mrpro/data/KTrajectoryRawShape.py @@ -1,13 +1,16 @@ """KTrajectoryRawShape dataclass.""" from dataclasses import dataclass +from typing import Literal import numpy as np import torch from einops import rearrange +from typing_extensions import Self from mrpro.data.KTrajectory import KTrajectory from mrpro.data.MoveDataMixin import MoveDataMixin +from mrpro.data.SpatialDimension import SpatialDimension @dataclass(slots=True, frozen=True) @@ -32,6 +35,52 @@ class KTrajectoryRawShape(MoveDataMixin): repeat_detection_tolerance: None | float = 1e-3 """tolerance for repeat detection. Set to None to disable.""" + @classmethod + def from_tensor( + cls, + tensor: torch.Tensor, + stack_dim: int = 0, + axes_order: Literal['zxy', 'zyx', 'yxz', 'yzx', 'xyz', 'xzy'] = 'zyx', + repeat_detection_tolerance: float | None = 1e-6, + scaling_matrix: SpatialDimension | None = None, + ) -> Self: + """Create a KTrajectoryRawShape from a tensor representation of the trajectory. + + Parameters + ---------- + tensor + The tensor representation of the trajectory. + This should be a 5-dim tensor, with (kz, ky, kx) stacked in this order along `stack_dim`. + stack_dim + The dimension in the tensor along which the directions are stacked. + axes_order + The order of the axes in the tensor. The MRpro convention is 'zyx'. + repeat_detection_tolerance + Tolerance for detecting repeated dimensions (broadcasting). + If trajectory points differ by less than this value, they are considered identical. + Set to None to disable this feature. + scaling_matrix + If a scaling matrix is provided, the trajectory is rescaled to fit within + the dimensions of the matrix. If not provided, the trajectory remains unchanged. + """ + ks = tensor.unbind(dim=stack_dim) + kz, ky, kx = (ks[axes_order.index(axis)] for axis in 'zyx') + + def rescale(k: torch.Tensor, size: float) -> torch.Tensor: + max_abs_range = 2 * k.abs().max() + if size < 2 or max_abs_range < 1e-6: + # a single encoding point should be at zero + # avoid division by zero + return torch.zeros_like(k) + return k * (size / max_abs_range) + + if scaling_matrix is not None: + kz = rescale(kz, scaling_matrix.z) + ky = rescale(ky, scaling_matrix.y) + kx = rescale(kx, scaling_matrix.x) + + return cls(kz, ky, kx, repeat_detection_tolerance=repeat_detection_tolerance) + def sort_and_reshape( self, sort_idx: np.ndarray, diff --git a/src/mrpro/data/traj_calculators/KTrajectoryPulseq.py b/src/mrpro/data/traj_calculators/KTrajectoryPulseq.py index 598aa2184..a446464a8 100644 --- a/src/mrpro/data/traj_calculators/KTrajectoryPulseq.py +++ b/src/mrpro/data/traj_calculators/KTrajectoryPulseq.py @@ -2,7 +2,6 @@ from pathlib import Path -import pypulseq as pp import torch from einops import rearrange @@ -40,8 +39,10 @@ def __call__(self, kheader: KHeader) -> KTrajectoryRawShape: ------- trajectory of type KTrajectoryRawShape """ + from pypulseq import Sequence + # create PyPulseq Sequence object and read .seq file - seq = pp.Sequence() + seq = Sequence() seq.read(file_path=str(self.seq_path)) # calculate k-space trajectory using PyPulseq @@ -52,20 +53,11 @@ def __call__(self, kheader: KHeader) -> KTrajectoryRawShape: n_samples = torch.unique(n_samples) if len(n_samples) > 1: raise ValueError('We currently only support constant number of samples') - n_k0 = int(n_samples.item()) - - def rescale_and_reshape_traj(k_traj: torch.Tensor, encoding_size: int): - if encoding_size > 1 and torch.max(torch.abs(k_traj)) > 0: - k_traj = k_traj * encoding_size / (2 * torch.max(torch.abs(k_traj))) - else: - # We force k_traj to be 0 if encoding_size = 1. This is typically the case for kz in 2D sequences. - # However, it happens that seq.calculate_kspace() returns values != 0 (numerical noise) in such cases. - k_traj = torch.zeros_like(k_traj) - return rearrange(k_traj, '(other k0) -> other k0', k0=n_k0) - - # rearrange k-space trajectory to match MRpro convention - kx = rescale_and_reshape_traj(k_traj_adc[0], kheader.encoding_matrix.x) - ky = rescale_and_reshape_traj(k_traj_adc[1], kheader.encoding_matrix.y) - kz = rescale_and_reshape_traj(k_traj_adc[2], kheader.encoding_matrix.z) - return KTrajectoryRawShape(kz, ky, kx, self.repeat_detection_tolerance) + k_traj_reshaped = rearrange(k_traj_adc, 'xyz (other k0) -> xyz other k0', k0=int(n_samples.item())) + return KTrajectoryRawShape.from_tensor( + k_traj_reshaped, + axes_order='xyz', + scaling_matrix=kheader.encoding_matrix, + repeat_detection_tolerance=self.repeat_detection_tolerance, + ) diff --git a/tests/data/_PulseqRadialTestSeq.py b/tests/data/_PulseqRadialTestSeq.py index 82cab5577..0cd9b6032 100644 --- a/tests/data/_PulseqRadialTestSeq.py +++ b/tests/data/_PulseqRadialTestSeq.py @@ -29,7 +29,9 @@ def __init__(self, seq_filename: str, n_x=256, n_spokes=10): system = pypulseq.Opts() rf, gz, _ = pypulseq.make_sinc_pulse(flip_angle=0.1, slice_thickness=1e-3, system=system, return_gz=True) - gx = pypulseq.make_trapezoid(channel='x', flat_area=n_x * delta_k, flat_time=2e-3, system=system) + gx = pypulseq.make_trapezoid( + channel='x', flat_area=n_x * delta_k, flat_time=n_x * system.grad_raster_time, system=system + ) adc = pypulseq.make_adc(num_samples=n_x, duration=gx.flat_time, delay=gx.rise_time, system=system) gx_pre = pypulseq.make_trapezoid(channel='x', area=-gx.area / 2 - delta_k / 2, duration=2e-3, system=system) gz_reph = pypulseq.make_trapezoid(channel='z', area=-gz.area / 2, duration=2e-3, system=system) diff --git a/tests/data/test_traj_calculators.py b/tests/data/test_traj_calculators.py index 7ddcb30aa..8ac3d0b71 100644 --- a/tests/data/test_traj_calculators.py +++ b/tests/data/test_traj_calculators.py @@ -260,11 +260,11 @@ def test_KTrajectoryPulseq_validseq_random_header(pulseq_example_rad_seq, valid_ trajectory_calculator = KTrajectoryPulseq(seq_path=pulseq_example_rad_seq.seq_filename) trajectory = trajectory_calculator(kheader=valid_rad2d_kheader) - kx_test = pulseq_example_rad_seq.traj_analytical.kx.squeeze(0).squeeze(0) - kx_test *= valid_rad2d_kheader.encoding_matrix.x / (2 * torch.max(torch.abs(kx_test))) + kx_test = pulseq_example_rad_seq.traj_analytical.kx.squeeze() + kx_test = kx_test * valid_rad2d_kheader.encoding_matrix.x / (2 * kx_test.abs().max()) - ky_test = pulseq_example_rad_seq.traj_analytical.ky.squeeze(0).squeeze(0) - ky_test *= valid_rad2d_kheader.encoding_matrix.y / (2 * torch.max(torch.abs(ky_test))) + ky_test = pulseq_example_rad_seq.traj_analytical.ky.squeeze() + ky_test = ky_test * valid_rad2d_kheader.encoding_matrix.y / (2 * ky_test.abs().max()) torch.testing.assert_close(trajectory.kx.to(torch.float32), kx_test.to(torch.float32), atol=1e-2, rtol=1e-3) torch.testing.assert_close(trajectory.ky.to(torch.float32), ky_test.to(torch.float32), atol=1e-2, rtol=1e-3)