Skip to content

Commit

Permalink
Move trajectory scaling into KTrajectory (#582)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Patrick Schuenke <[email protected]>
  • Loading branch information
3 people authored Dec 17, 2024
1 parent 4096ead commit 92e3193
Show file tree
Hide file tree
Showing 5 changed files with 101 additions and 33 deletions.
45 changes: 35 additions & 10 deletions src/mrpro/data/KTrajectory.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
"""KTrajectory dataclass."""

from dataclasses import dataclass
from typing import Literal

import numpy as np
import torch
from typing_extensions import Self

from mrpro.data.enums import TrajType
from mrpro.data.MoveDataMixin import MoveDataMixin
from mrpro.data.SpatialDimension import SpatialDimension
from mrpro.utils import remove_repeat
from mrpro.utils.summarize_tensorvalues import summarize_tensorvalues

Expand Down Expand Up @@ -69,29 +71,52 @@ def from_tensor(
cls,
tensor: torch.Tensor,
stack_dim: int = 0,
repeat_detection_tolerance: float | None = 1e-8,
axes_order: Literal['zxy', 'zyx', 'yxz', 'yzx', 'xyz', 'xzy'] = 'zyx',
repeat_detection_tolerance: float | None = 1e-6,
grid_detection_tolerance: float = 1e-3,
scaling_matrix: SpatialDimension | None = None,
) -> Self:
"""Create a KTrajectory from a tensor representation of the trajectory.
Reduces repeated dimensions to singletons if repeat_detection_tolerance
is not set to None.
Reduces repeated dimensions to singletons if repeat_detection_tolerance is not set to None.
Parameters
----------
tensor
The tensor representation of the trajectory.
This should be a 5-dim tensor, with (kz,ky,kx) stacked in this order along stack_dim
This should be a 5-dim tensor, with (kz, ky, kx) stacked in this order along `stack_dim`.
stack_dim
The dimension in the tensor the directions have been stacked along.
The dimension in the tensor along which the directions are stacked.
axes_order
The order of the axes in the tensor. The MRpro convention is 'zyx'.
repeat_detection_tolerance
detects if broadcasting can be used, i.e. if dimensions are repeated.
Set to None to disable.
Tolerance for detecting repeated dimensions (broadcasting).
If trajectory points differ by less than this value, they are considered identical.
Set to None to disable this feature.
grid_detection_tolerance
tolerance to detect if trajectory points are on integer grid positions
Tolerance for detecting whether trajectory points align with integer grid positions.
This tolerance is applied after rescaling if `scaling_matrix` is provided.
scaling_matrix
If a scaling matrix is provided, the trajectory is rescaled to fit within
the dimensions of the matrix. If not provided, the trajectory remains unchanged.
"""
kz, ky, kx = torch.unbind(tensor, dim=stack_dim)
ks = tensor.unbind(dim=stack_dim)
kz, ky, kx = (ks[axes_order.index(axis)] for axis in 'zyx')

def rescale(k: torch.Tensor, size: float) -> torch.Tensor:
max_abs_range = 2 * k.abs().max()
if size < 2 or max_abs_range < 1e-6:
# a single encoding point should be at zero
# avoid division by zero
return torch.zeros_like(k)
return k * (size / max_abs_range)

if scaling_matrix is not None:
kz = rescale(kz, scaling_matrix.z)
ky = rescale(ky, scaling_matrix.y)
kx = rescale(kx, scaling_matrix.x)

return cls(
kz,
ky,
Expand Down
49 changes: 49 additions & 0 deletions src/mrpro/data/KTrajectoryRawShape.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
"""KTrajectoryRawShape dataclass."""

from dataclasses import dataclass
from typing import Literal

import numpy as np
import torch
from einops import rearrange
from typing_extensions import Self

from mrpro.data.KTrajectory import KTrajectory
from mrpro.data.MoveDataMixin import MoveDataMixin
from mrpro.data.SpatialDimension import SpatialDimension


@dataclass(slots=True, frozen=True)
Expand All @@ -32,6 +35,52 @@ class KTrajectoryRawShape(MoveDataMixin):
repeat_detection_tolerance: None | float = 1e-3
"""tolerance for repeat detection. Set to None to disable."""

@classmethod
def from_tensor(
cls,
tensor: torch.Tensor,
stack_dim: int = 0,
axes_order: Literal['zxy', 'zyx', 'yxz', 'yzx', 'xyz', 'xzy'] = 'zyx',
repeat_detection_tolerance: float | None = 1e-6,
scaling_matrix: SpatialDimension | None = None,
) -> Self:
"""Create a KTrajectoryRawShape from a tensor representation of the trajectory.
Parameters
----------
tensor
The tensor representation of the trajectory.
This should be a 5-dim tensor, with (kz, ky, kx) stacked in this order along `stack_dim`.
stack_dim
The dimension in the tensor along which the directions are stacked.
axes_order
The order of the axes in the tensor. The MRpro convention is 'zyx'.
repeat_detection_tolerance
Tolerance for detecting repeated dimensions (broadcasting).
If trajectory points differ by less than this value, they are considered identical.
Set to None to disable this feature.
scaling_matrix
If a scaling matrix is provided, the trajectory is rescaled to fit within
the dimensions of the matrix. If not provided, the trajectory remains unchanged.
"""
ks = tensor.unbind(dim=stack_dim)
kz, ky, kx = (ks[axes_order.index(axis)] for axis in 'zyx')

def rescale(k: torch.Tensor, size: float) -> torch.Tensor:
max_abs_range = 2 * k.abs().max()
if size < 2 or max_abs_range < 1e-6:
# a single encoding point should be at zero
# avoid division by zero
return torch.zeros_like(k)
return k * (size / max_abs_range)

if scaling_matrix is not None:
kz = rescale(kz, scaling_matrix.z)
ky = rescale(ky, scaling_matrix.y)
kx = rescale(kx, scaling_matrix.x)

return cls(kz, ky, kx, repeat_detection_tolerance=repeat_detection_tolerance)

def sort_and_reshape(
self,
sort_idx: np.ndarray,
Expand Down
28 changes: 10 additions & 18 deletions src/mrpro/data/traj_calculators/KTrajectoryPulseq.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from pathlib import Path

import pypulseq as pp
import torch
from einops import rearrange

Expand Down Expand Up @@ -40,8 +39,10 @@ def __call__(self, kheader: KHeader) -> KTrajectoryRawShape:
-------
trajectory of type KTrajectoryRawShape
"""
from pypulseq import Sequence

# create PyPulseq Sequence object and read .seq file
seq = pp.Sequence()
seq = Sequence()
seq.read(file_path=str(self.seq_path))

# calculate k-space trajectory using PyPulseq
Expand All @@ -52,20 +53,11 @@ def __call__(self, kheader: KHeader) -> KTrajectoryRawShape:
n_samples = torch.unique(n_samples)
if len(n_samples) > 1:
raise ValueError('We currently only support constant number of samples')
n_k0 = int(n_samples.item())

def rescale_and_reshape_traj(k_traj: torch.Tensor, encoding_size: int):
if encoding_size > 1 and torch.max(torch.abs(k_traj)) > 0:
k_traj = k_traj * encoding_size / (2 * torch.max(torch.abs(k_traj)))
else:
# We force k_traj to be 0 if encoding_size = 1. This is typically the case for kz in 2D sequences.
# However, it happens that seq.calculate_kspace() returns values != 0 (numerical noise) in such cases.
k_traj = torch.zeros_like(k_traj)
return rearrange(k_traj, '(other k0) -> other k0', k0=n_k0)

# rearrange k-space trajectory to match MRpro convention
kx = rescale_and_reshape_traj(k_traj_adc[0], kheader.encoding_matrix.x)
ky = rescale_and_reshape_traj(k_traj_adc[1], kheader.encoding_matrix.y)
kz = rescale_and_reshape_traj(k_traj_adc[2], kheader.encoding_matrix.z)

return KTrajectoryRawShape(kz, ky, kx, self.repeat_detection_tolerance)
k_traj_reshaped = rearrange(k_traj_adc, 'xyz (other k0) -> xyz other k0', k0=int(n_samples.item()))
return KTrajectoryRawShape.from_tensor(
k_traj_reshaped,
axes_order='xyz',
scaling_matrix=kheader.encoding_matrix,
repeat_detection_tolerance=self.repeat_detection_tolerance,
)
4 changes: 3 additions & 1 deletion tests/data/_PulseqRadialTestSeq.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ def __init__(self, seq_filename: str, n_x=256, n_spokes=10):

system = pypulseq.Opts()
rf, gz, _ = pypulseq.make_sinc_pulse(flip_angle=0.1, slice_thickness=1e-3, system=system, return_gz=True)
gx = pypulseq.make_trapezoid(channel='x', flat_area=n_x * delta_k, flat_time=2e-3, system=system)
gx = pypulseq.make_trapezoid(
channel='x', flat_area=n_x * delta_k, flat_time=n_x * system.grad_raster_time, system=system
)
adc = pypulseq.make_adc(num_samples=n_x, duration=gx.flat_time, delay=gx.rise_time, system=system)
gx_pre = pypulseq.make_trapezoid(channel='x', area=-gx.area / 2 - delta_k / 2, duration=2e-3, system=system)
gz_reph = pypulseq.make_trapezoid(channel='z', area=-gz.area / 2, duration=2e-3, system=system)
Expand Down
8 changes: 4 additions & 4 deletions tests/data/test_traj_calculators.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,11 +260,11 @@ def test_KTrajectoryPulseq_validseq_random_header(pulseq_example_rad_seq, valid_
trajectory_calculator = KTrajectoryPulseq(seq_path=pulseq_example_rad_seq.seq_filename)
trajectory = trajectory_calculator(kheader=valid_rad2d_kheader)

kx_test = pulseq_example_rad_seq.traj_analytical.kx.squeeze(0).squeeze(0)
kx_test *= valid_rad2d_kheader.encoding_matrix.x / (2 * torch.max(torch.abs(kx_test)))
kx_test = pulseq_example_rad_seq.traj_analytical.kx.squeeze()
kx_test = kx_test * valid_rad2d_kheader.encoding_matrix.x / (2 * kx_test.abs().max())

ky_test = pulseq_example_rad_seq.traj_analytical.ky.squeeze(0).squeeze(0)
ky_test *= valid_rad2d_kheader.encoding_matrix.y / (2 * torch.max(torch.abs(ky_test)))
ky_test = pulseq_example_rad_seq.traj_analytical.ky.squeeze()
ky_test = ky_test * valid_rad2d_kheader.encoding_matrix.y / (2 * ky_test.abs().max())

torch.testing.assert_close(trajectory.kx.to(torch.float32), kx_test.to(torch.float32), atol=1e-2, rtol=1e-3)
torch.testing.assert_close(trajectory.ky.to(torch.float32), ky_test.to(torch.float32), atol=1e-2, rtol=1e-3)

3 comments on commit 92e3193

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Coverage

Coverage Report
FileStmtsMissCoverMissing
src/mrpro/algorithms/csm
   inati.py24196%44
   walsh.py16194%34
src/mrpro/algorithms/dcf
   dcf_voronoi.py53492%15, 48–49, 76
src/mrpro/algorithms/optimizers
   adam.py20195%69
src/mrpro/algorithms/reconstruction
   DirectReconstruction.py281643%51–71, 85
   IterativeSENSEReconstruction.py13192%76
   Reconstruction.py502256%42, 54–56, 80–87, 104–113
   RegularizedIterativeSENSEReconstruction.py411759%96–100, 114–139
src/mrpro/data
   AcqInfo.py128398%26, 169, 207
   CsmData.py29390%15, 82–84
   DcfData.py45882%18, 66, 78–83
   IData.py67987%119, 125, 129, 159–167
   IHeader.py75791%75, 109, 127–131
   KHeader.py1531789%25, 119–123, 150, 199, 210, 217–218, 221, 228, 260–271
   KNoise.py311552%39–52, 56–61
   KTrajectory.py811285%108–113, 116–118, 203–207
   MoveDataMixin.py1401887%15, 113, 129, 143–145, 207, 323–325, 338, 417, 437–438, 440, 455–456, 458
   QData.py39782%42, 65–73
   Rotation.py6743595%100, 198, 335, 433, 477, 495, 581, 583, 592, 626, 628, 691, 768, 773, 776, 791, 808, 813, 889, 1077, 1082, 1085, 1109, 1113, 1240, 1242, 1250–1251, 1315, 1397, 1690, 1846, 1881, 1885, 1996
   SpatialDimension.py2322191%34, 104, 141, 148, 154, 274–276, 289–291, 325, 343, 356, 369, 382, 395, 404–405, 420, 429
   acq_filters.py12192%47
src/mrpro/data/_kdata
   KData.py1341887%109–110, 125, 132, 142, 150, 204–205, 243, 248–249, 268–279
   KDataRemoveOsMixin.py29293%44, 46
   KDataSelectMixin.py19289%48, 63
   KDataSplitMixin.py48394%53, 84, 93
src/mrpro/data/traj_calculators
   KTrajectoryCalculator.py25292%23, 45
   KTrajectoryIsmrmrd.py13285%41, 50
   KTrajectoryPulseq.py23196%55
src/mrpro/operators
   CartesianSamplingOp.py89397%118, 157, 280
   ConstraintsOp.py60297%46, 48
   EndomorphOperator.py65297%228, 234
   FiniteDifferenceOp.py27293%40, 105
   FourierOp.py158398%263, 381, 386
   Functional.py71593%20–22, 117, 119
   GridSamplingOp.py136993%72–73, 82–83, 90–91, 94, 96, 98
   LinearOperator.py1681094%55, 91, 190, 220, 261, 270, 278, 287, 295, 320
   LinearOperatorMatrix.py1581690%82, 119, 152, 161, 166, 175–178, 191–194, 203, 215, 304, 331, 359
   MultiIdentityOp.py13285%43, 48
   Operator.py78297%25, 74
   ProximableFunctionalSeparableSum.py39392%50, 103, 110
   SliceProjectionOp.py173895%44, 61, 63, 69, 206, 227, 260, 300
   WaveletOp.py120596%152, 170, 205, 210, 233
   ZeroPadOp.py16194%30
src/mrpro/utils
   filters.py62297%44, 49
   reshape.py60198%191
   slice_profiles.py46687%20, 36, 113–116, 149
   sliding_window.py34197%34
   split_idx.py10280%43, 47
   summarize_tensorvalues.py11918%20–29
   typing.py181139%8–23
   zero_pad_or_crop.py31681%26, 30, 54, 57, 60, 63
TOTAL493336093% 

Tests Skipped Failures Errors Time
2262 0 💤 0 ❌ 0 🔥 1m 48s ⏱️

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Coverage

Coverage Report
FileStmtsMissCoverMissing
src/mrpro/algorithms/csm
   inati.py24196%44
   walsh.py16194%34
src/mrpro/algorithms/dcf
   dcf_voronoi.py53492%15, 48–49, 76
src/mrpro/algorithms/optimizers
   adam.py20195%69
src/mrpro/algorithms/reconstruction
   DirectReconstruction.py281643%51–71, 85
   IterativeSENSEReconstruction.py13192%76
   Reconstruction.py502256%42, 54–56, 80–87, 104–113
   RegularizedIterativeSENSEReconstruction.py411759%96–100, 114–139
src/mrpro/data
   AcqInfo.py128398%26, 169, 207
   CsmData.py29390%15, 82–84
   DcfData.py45882%18, 66, 78–83
   IData.py67987%119, 125, 129, 159–167
   IHeader.py75791%75, 109, 127–131
   KHeader.py1531789%25, 119–123, 150, 199, 210, 217–218, 221, 228, 260–271
   KNoise.py311552%39–52, 56–61
   KTrajectory.py811285%108–113, 116–118, 203–207
   MoveDataMixin.py1401887%15, 113, 129, 143–145, 207, 323–325, 338, 417, 437–438, 440, 455–456, 458
   QData.py39782%42, 65–73
   Rotation.py6743595%100, 198, 335, 433, 477, 495, 581, 583, 592, 626, 628, 691, 768, 773, 776, 791, 808, 813, 889, 1077, 1082, 1085, 1109, 1113, 1240, 1242, 1250–1251, 1315, 1397, 1690, 1846, 1881, 1885, 1996
   SpatialDimension.py2322191%34, 104, 141, 148, 154, 274–276, 289–291, 325, 343, 356, 369, 382, 395, 404–405, 420, 429
   acq_filters.py12192%47
src/mrpro/data/_kdata
   KData.py1341887%109–110, 125, 132, 142, 150, 204–205, 243, 248–249, 268–279
   KDataRemoveOsMixin.py29293%44, 46
   KDataSelectMixin.py19289%48, 63
   KDataSplitMixin.py48394%53, 84, 93
src/mrpro/data/traj_calculators
   KTrajectoryCalculator.py25292%23, 45
   KTrajectoryIsmrmrd.py13285%41, 50
   KTrajectoryPulseq.py23196%55
src/mrpro/operators
   CartesianSamplingOp.py89397%118, 157, 280
   ConstraintsOp.py60297%46, 48
   EndomorphOperator.py65297%228, 234
   FiniteDifferenceOp.py27293%40, 105
   FourierOp.py158398%263, 381, 386
   Functional.py71593%20–22, 117, 119
   GridSamplingOp.py136993%72–73, 82–83, 90–91, 94, 96, 98
   LinearOperator.py1681094%55, 91, 190, 220, 261, 270, 278, 287, 295, 320
   LinearOperatorMatrix.py1581690%82, 119, 152, 161, 166, 175–178, 191–194, 203, 215, 304, 331, 359
   MultiIdentityOp.py13285%43, 48
   Operator.py78297%25, 74
   ProximableFunctionalSeparableSum.py39392%50, 103, 110
   SliceProjectionOp.py173895%44, 61, 63, 69, 206, 227, 260, 300
   WaveletOp.py120596%152, 170, 205, 210, 233
   ZeroPadOp.py16194%30
src/mrpro/utils
   filters.py62297%44, 49
   reshape.py60198%191
   slice_profiles.py46687%20, 36, 113–116, 149
   sliding_window.py34197%34
   split_idx.py10280%43, 47
   summarize_tensorvalues.py11918%20–29
   typing.py181139%8–23
   zero_pad_or_crop.py31681%26, 30, 54, 57, 60, 63
TOTAL493336093% 

Tests Skipped Failures Errors Time
2262 0 💤 0 ❌ 0 🔥 1m 55s ⏱️

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Coverage

Coverage Report
FileStmtsMissCoverMissing
src/mrpro/algorithms/csm
   inati.py24196%44
   walsh.py16194%34
src/mrpro/algorithms/dcf
   dcf_voronoi.py53492%15, 48–49, 76
src/mrpro/algorithms/optimizers
   adam.py20195%69
src/mrpro/algorithms/reconstruction
   DirectReconstruction.py281643%51–71, 85
   IterativeSENSEReconstruction.py13192%76
   Reconstruction.py502256%42, 54–56, 80–87, 104–113
   RegularizedIterativeSENSEReconstruction.py411759%96–100, 114–139
src/mrpro/data
   AcqInfo.py128398%26, 169, 207
   CsmData.py29390%15, 82–84
   DcfData.py45882%18, 66, 78–83
   IData.py67987%119, 125, 129, 159–167
   IHeader.py75791%75, 109, 127–131
   KHeader.py1531789%25, 119–123, 150, 199, 210, 217–218, 221, 228, 260–271
   KNoise.py311552%39–52, 56–61
   KTrajectory.py811285%108–113, 116–118, 203–207
   MoveDataMixin.py1401887%15, 113, 129, 143–145, 207, 323–325, 338, 417, 437–438, 440, 455–456, 458
   QData.py39782%42, 65–73
   Rotation.py6743595%100, 198, 335, 433, 477, 495, 581, 583, 592, 626, 628, 691, 768, 773, 776, 791, 808, 813, 889, 1077, 1082, 1085, 1109, 1113, 1240, 1242, 1250–1251, 1315, 1397, 1690, 1846, 1881, 1885, 1996
   SpatialDimension.py2322191%34, 104, 141, 148, 154, 274–276, 289–291, 325, 343, 356, 369, 382, 395, 404–405, 420, 429
   acq_filters.py12192%47
src/mrpro/data/_kdata
   KData.py1341887%109–110, 125, 132, 142, 150, 204–205, 243, 248–249, 268–279
   KDataRemoveOsMixin.py29293%44, 46
   KDataSelectMixin.py19289%48, 63
   KDataSplitMixin.py48394%53, 84, 93
src/mrpro/data/traj_calculators
   KTrajectoryCalculator.py25292%23, 45
   KTrajectoryIsmrmrd.py13285%41, 50
   KTrajectoryPulseq.py23196%55
src/mrpro/operators
   CartesianSamplingOp.py89397%118, 157, 280
   ConstraintsOp.py60297%46, 48
   EndomorphOperator.py65297%228, 234
   FiniteDifferenceOp.py27293%40, 105
   FourierOp.py158398%263, 381, 386
   Functional.py71593%20–22, 117, 119
   GridSamplingOp.py136993%72–73, 82–83, 90–91, 94, 96, 98
   LinearOperator.py1681094%55, 91, 190, 220, 261, 270, 278, 287, 295, 320
   LinearOperatorMatrix.py1581690%82, 119, 152, 161, 166, 175–178, 191–194, 203, 215, 304, 331, 359
   MultiIdentityOp.py13285%43, 48
   Operator.py78297%25, 74
   ProximableFunctionalSeparableSum.py39392%50, 103, 110
   SliceProjectionOp.py173895%44, 61, 63, 69, 206, 227, 260, 300
   WaveletOp.py120596%152, 170, 205, 210, 233
   ZeroPadOp.py16194%30
src/mrpro/utils
   filters.py62297%44, 49
   reshape.py60198%191
   slice_profiles.py46687%20, 36, 113–116, 149
   sliding_window.py34197%34
   split_idx.py10280%43, 47
   summarize_tensorvalues.py11918%20–29
   typing.py181139%8–23
   zero_pad_or_crop.py31681%26, 30, 54, 57, 60, 63
TOTAL493336093% 

Tests Skipped Failures Errors Time
2262 0 💤 0 ❌ 0 🔥 2m 11s ⏱️

Please sign in to comment.