Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move scaling to KTrajectory #582

Merged
merged 23 commits into from
Dec 17, 2024
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 26 additions & 2 deletions src/mrpro/data/KTrajectory.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
"""KTrajectory dataclass."""

from dataclasses import dataclass
from typing import Literal

import numpy as np
import torch
from typing_extensions import Self

from mrpro.data.enums import TrajType
from mrpro.data.MoveDataMixin import MoveDataMixin
from mrpro.data.SpatialDimension import SpatialDimension
from mrpro.utils import remove_repeat
from mrpro.utils.summarize_tensorvalues import summarize_tensorvalues

Expand Down Expand Up @@ -69,8 +71,10 @@ def from_tensor(
cls,
tensor: torch.Tensor,
stack_dim: int = 0,
repeat_detection_tolerance: float | None = 1e-8,
axes_order: Literal['zxy', 'zyx', 'yxz', 'yzx', 'xyz', 'xzy'] = 'zyx',
repeat_detection_tolerance: float | None = 1e-6,
grid_detection_tolerance: float = 1e-3,
encoding_matrix: SpatialDimension | None = None,
) -> Self:
"""Create a KTrajectory from a tensor representation of the trajectory.

Expand All @@ -85,13 +89,33 @@ def from_tensor(
This should be a 5-dim tensor, with (kz,ky,kx) stacked in this order along stack_dim
stack_dim
The dimension in the tensor the directions have been stacked along.
axes_order
Order of the axes in the tensor. Our convention usually is zyx order.
repeat_detection_tolerance
detects if broadcasting can be used, i.e. if dimensions are repeated.
Set to None to disable.
fzimmermann89 marked this conversation as resolved.
Show resolved Hide resolved
grid_detection_tolerance
tolerance to detect if trajectory points are on integer grid positions
fzimmermann89 marked this conversation as resolved.
Show resolved Hide resolved
encoding_matrix
fzimmermann89 marked this conversation as resolved.
Show resolved Hide resolved
if an encoding matrix is supplied, the trajectory is rescaled to fit
within the matrix. Otherwise, it is left as-is.
fzimmermann89 marked this conversation as resolved.
Show resolved Hide resolved
"""
kz, ky, kx = torch.unbind(tensor, dim=stack_dim)
ks = tensor.unbind(dim=stack_dim)
kz, ky, kx = (ks[axes_order.index(axis)] for axis in 'zyx')

def normalize(k: torch.Tensor, encoding_size: int) -> torch.Tensor:
fzimmermann89 marked this conversation as resolved.
Show resolved Hide resolved
max_abs_range = 2 * k.abs().max()
if encoding_size == 1 or max_abs_range < 1e-6:
# a single encoding point should be at zero
# avoid division by zero
return torch.zeros_like(k)
return k * (encoding_size / max_abs_range)

if encoding_matrix is not None:
kz = normalize(kz, encoding_matrix.z)
ky = normalize(ky, encoding_matrix.y)
kx = normalize(kx, encoding_matrix.x)

fzimmermann89 marked this conversation as resolved.
Show resolved Hide resolved
return cls(
kz,
ky,
Expand Down
53 changes: 53 additions & 0 deletions src/mrpro/data/KTrajectoryRawShape.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
"""KTrajectoryRawShape dataclass."""

from dataclasses import dataclass
from typing import Literal

import numpy as np
import torch
from einops import rearrange
from typing_extensions import Self

from mrpro.data.KTrajectory import KTrajectory
from mrpro.data.MoveDataMixin import MoveDataMixin
from mrpro.data.SpatialDimension import SpatialDimension


@dataclass(slots=True, frozen=True)
Expand All @@ -32,6 +35,56 @@ class KTrajectoryRawShape(MoveDataMixin):
repeat_detection_tolerance: None | float = 1e-3
"""tolerance for repeat detection. Set to None to disable."""

@classmethod
def from_tensor(
cls,
tensor: torch.Tensor,
stack_dim: int = 0,
axes_order: Literal['zxy', 'zyx', 'yxz', 'yzx', 'xyz', 'xzy'] = 'zyx',
repeat_detection_tolerance: float | None = 1e-6,
encoding_matrix: SpatialDimension | None = None,
) -> Self:
"""Create a KTrajectoryRawShape from a tensor representation of the trajectory.

Parameters
----------
tensor
The tensor representation of the trajectory.
This should be a 5-dim tensor, with (kz,ky,kx) stacked in this order along stack_dim
stack_dim
The dimension in the tensor the directions have been stacked along.
axes_order
Order of the axes in the tensor. Our convention usually is zyx order.
repeat_detection_tolerance
detects if broadcasting can be used, i.e. if dimensions are repeated.
Set to None to disable.
encoding_matrix
if an encoding matrix is supplied, the trajectory is rescaled to fit
within the matrix. Otherwise, it is left as-is.
"""
ks = tensor.unbind(dim=stack_dim)
kz, ky, kx = (ks[axes_order.index(axis)] for axis in 'zyx')

def normalize(k: torch.Tensor, encoding_size: int) -> torch.Tensor:
max_abs_range = 2 * k.abs().max()
if encoding_size == 1 or max_abs_range < 1e-6:
# a single encoding point should be at zero
# avoid division by zero
return torch.zeros_like(k)
return k * (encoding_size / max_abs_range)

if encoding_matrix is not None:
kz = normalize(kz, encoding_matrix.z)
ky = normalize(ky, encoding_matrix.y)
kx = normalize(kx, encoding_matrix.x)

return cls(
kz,
ky,
kx,
repeat_detection_tolerance=repeat_detection_tolerance,
)

def sort_and_reshape(
self,
sort_idx: np.ndarray,
Expand Down
28 changes: 10 additions & 18 deletions src/mrpro/data/traj_calculators/KTrajectoryPulseq.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from pathlib import Path

import pypulseq as pp
import torch
from einops import rearrange

Expand Down Expand Up @@ -40,8 +39,10 @@ def __call__(self, kheader: KHeader) -> KTrajectoryRawShape:
-------
trajectory of type KTrajectoryRawShape
"""
from pypulseq import Sequence

# create PyPulseq Sequence object and read .seq file
seq = pp.Sequence()
seq = Sequence()
seq.read(file_path=str(self.seq_path))

# calculate k-space trajectory using PyPulseq
Expand All @@ -52,20 +53,11 @@ def __call__(self, kheader: KHeader) -> KTrajectoryRawShape:
n_samples = torch.unique(n_samples)
if len(n_samples) > 1:
raise ValueError('We currently only support constant number of samples')
n_k0 = int(n_samples.item())

def rescale_and_reshape_traj(k_traj: torch.Tensor, encoding_size: int):
if encoding_size > 1 and torch.max(torch.abs(k_traj)) > 0:
k_traj = k_traj * encoding_size / (2 * torch.max(torch.abs(k_traj)))
else:
# We force k_traj to be 0 if encoding_size = 1. This is typically the case for kz in 2D sequences.
# However, it happens that seq.calculate_kspace() returns values != 0 (numerical noise) in such cases.
k_traj = torch.zeros_like(k_traj)
return rearrange(k_traj, '(other k0) -> other k0', k0=n_k0)

# rearrange k-space trajectory to match MRpro convention
kx = rescale_and_reshape_traj(k_traj_adc[0], kheader.encoding_matrix.x)
ky = rescale_and_reshape_traj(k_traj_adc[1], kheader.encoding_matrix.y)
kz = rescale_and_reshape_traj(k_traj_adc[2], kheader.encoding_matrix.z)

return KTrajectoryRawShape(kz, ky, kx, self.repeat_detection_tolerance)
k_traj_reshaped = rearrange(k_traj_adc, 'xyz (other k0) -> xyz other k0', k0=int(n_samples.item()))
return KTrajectoryRawShape.from_tensor(
k_traj_reshaped,
axes_order='xyz',
encoding_matrix=kheader.encoding_matrix,
repeat_detection_tolerance=self.repeat_detection_tolerance,
)
4 changes: 3 additions & 1 deletion tests/data/_PulseqRadialTestSeq.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ def __init__(self, seq_filename: str, n_x=256, n_spokes=10):

system = pypulseq.Opts()
rf, gz, _ = pypulseq.make_sinc_pulse(flip_angle=0.1, slice_thickness=1e-3, system=system, return_gz=True)
gx = pypulseq.make_trapezoid(channel='x', flat_area=n_x * delta_k, flat_time=2e-3, system=system)
gx = pypulseq.make_trapezoid(
channel='x', flat_area=n_x * delta_k, flat_time=n_x * system.grad_raster_time, system=system
)
adc = pypulseq.make_adc(num_samples=n_x, duration=gx.flat_time, delay=gx.rise_time, system=system)
gx_pre = pypulseq.make_trapezoid(channel='x', area=-gx.area / 2 - delta_k / 2, duration=2e-3, system=system)
gz_reph = pypulseq.make_trapezoid(channel='z', area=-gz.area / 2, duration=2e-3, system=system)
Expand Down
8 changes: 4 additions & 4 deletions tests/data/test_traj_calculators.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,11 +260,11 @@ def test_KTrajectoryPulseq_validseq_random_header(pulseq_example_rad_seq, valid_
trajectory_calculator = KTrajectoryPulseq(seq_path=pulseq_example_rad_seq.seq_filename)
trajectory = trajectory_calculator(kheader=valid_rad2d_kheader)

kx_test = pulseq_example_rad_seq.traj_analytical.kx.squeeze(0).squeeze(0)
kx_test *= valid_rad2d_kheader.encoding_matrix.x / (2 * torch.max(torch.abs(kx_test)))
kx_test = pulseq_example_rad_seq.traj_analytical.kx.squeeze()
kx_test = kx_test * valid_rad2d_kheader.encoding_matrix.x / (2 * kx_test.abs().max())

ky_test = pulseq_example_rad_seq.traj_analytical.ky.squeeze(0).squeeze(0)
ky_test *= valid_rad2d_kheader.encoding_matrix.y / (2 * torch.max(torch.abs(ky_test)))
ky_test = pulseq_example_rad_seq.traj_analytical.ky.squeeze()
ky_test = ky_test * valid_rad2d_kheader.encoding_matrix.y / (2 * ky_test.abs().max())

torch.testing.assert_close(trajectory.kx.to(torch.float32), kx_test.to(torch.float32), atol=1e-2, rtol=1e-3)
torch.testing.assert_close(trajectory.ky.to(torch.float32), ky_test.to(torch.float32), atol=1e-2, rtol=1e-3)
Loading