From 717c7fb1c62b382d446fc0a1f9d1620c098cfe4b Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Wed, 27 Nov 2024 14:29:24 +0100 Subject: [PATCH] Refactor AcqInfo and Header information ghstack-source-id: d1ac51bfb1f9c81a7bc7cdbcaf4c0f4e95de0959 ghstack-comment-id: 2501006700 Pull Request resolved: https://github.com/PTB-MR/mrpro/pull/560 --- src/mrpro/data/AcqInfo.py | 142 ++++---- src/mrpro/data/KData.py | 37 ++- src/mrpro/data/_kdata/KDataRemoveOsMixin.py | 6 - .../traj_calculators/KTrajectoryCalculator.py | 82 +++-- .../traj_calculators/KTrajectoryCartesian.py | 36 +- .../traj_calculators/KTrajectoryIsmrmrd.py | 6 +- .../traj_calculators/KTrajectoryPulseq.py | 39 ++- .../traj_calculators/KTrajectoryRadial2D.py | 35 +- .../data/traj_calculators/KTrajectoryRpe.py | 107 +++--- .../KTrajectorySunflowerGoldenRpe.py | 112 ++++--- tests/conftest.py | 7 +- tests/data/test_kdata.py | 8 +- tests/data/test_traj_calculators.py | 310 ++++++++---------- 13 files changed, 507 insertions(+), 420 deletions(-) diff --git a/src/mrpro/data/AcqInfo.py b/src/mrpro/data/AcqInfo.py index f5d677f9..44d2c1e5 100644 --- a/src/mrpro/data/AcqInfo.py +++ b/src/mrpro/data/AcqInfo.py @@ -2,6 +2,7 @@ from collections.abc import Sequence from dataclasses import dataclass +from typing import overload import ismrmrd import numpy as np @@ -82,6 +83,37 @@ class AcqIdx(MoveDataMixin): """User index 7.""" +@dataclass(slots=True) +class UserValues(MoveDataMixin): + """User Values used in AcqInfo.""" + + float1: torch.Tensor + float2: torch.Tensor + float3: torch.Tensor + float4: torch.Tensor + float5: torch.Tensor + float6: torch.Tensor + float7: torch.Tensor + float8: torch.Tensor + int1: torch.Tensor + int2: torch.Tensor + int3: torch.Tensor + int4: torch.Tensor + int5: torch.Tensor + int6: torch.Tensor + int7: torch.Tensor + int8: torch.Tensor + + +@dataclass(slots=True) +class PhysiologyTimestamps: + """Time stamps relative to physiological triggering, e.g. ECG. Not in s but in vendor-specific time units.""" + + timestamp1: torch.Tensor + timestamp2: torch.Tensor + timestamp3: torch.Tensor + + @dataclass(slots=True) class AcqInfo(MoveDataMixin): """Acquisition information for each readout.""" @@ -92,43 +124,19 @@ class AcqInfo(MoveDataMixin): acquisition_time_stamp: torch.Tensor """Clock time stamp. Not in s but in vendor-specific time units (e.g. 2.5ms for Siemens)""" - active_channels: torch.Tensor - """Number of active receiver coil elements.""" - - available_channels: torch.Tensor - """Number of available receiver coil elements.""" - - center_sample: torch.Tensor - """Index of the readout sample corresponding to k-space center (zero indexed).""" - - channel_mask: torch.Tensor - """Bit mask indicating active coils (64*16 = 1024 bits).""" - - discard_post: torch.Tensor - """Number of readout samples to be discarded at the end (e.g. if the ADC is active during gradient events).""" - - discard_pre: torch.Tensor - """Number of readout samples to be discarded at the beginning (e.g. if the ADC is active during gradient events)""" - - encoding_space_ref: torch.Tensor - """Indexed reference to the encoding spaces enumerated in the MRD (xml) header.""" - flags: torch.Tensor """A bit mask of common attributes applicable to individual acquisition readouts.""" measurement_uid: torch.Tensor """Unique ID corresponding to the readout.""" - number_of_samples: torch.Tensor - """Number of sample points per readout (readouts may have different number of sample points).""" - orientation: Rotation """Rotation describing the orientation of the readout, phase and slice encoding direction.""" patient_table_position: SpatialDimension[torch.Tensor] """Offset position of the patient table, in LPS coordinates [m].""" - physiology_time_stamp: torch.Tensor + physiology_time_stamps: PhysiologyTimestamps """Time stamps relative to physiological triggering, e.g. ECG. Not in s but in vendor-specific time units""" position: SpatialDimension[torch.Tensor] @@ -140,26 +148,34 @@ class AcqInfo(MoveDataMixin): scan_counter: torch.Tensor """Zero-indexed incrementing counter for readouts.""" - trajectory_dimensions: torch.Tensor # =3. We only support 3D Trajectories: kz always exists. - """Dimensionality of the k-space trajectory vector.""" - - user_float: torch.Tensor - """User-defined float parameters.""" + user: UserValues + """User defined float or int values""" - user_int: torch.Tensor - """User-defined int parameters.""" + @overload + @classmethod + def from_ismrmrd_acquisitions( + cls, acquisitions: Sequence[ismrmrd.Acquisition], *, additional_fields: None + ) -> Self: ... - version: torch.Tensor - """Major version number.""" + @overload + @classmethod + def from_ismrmrd_acquisitions( + cls, acquisitions: Sequence[ismrmrd.Acquisition], *, additional_fields: Sequence[str] + ) -> tuple[Self, tuple[torch.Tensor, ...]]: ... @classmethod - def from_ismrmrd_acquisitions(cls, acquisitions: Sequence[ismrmrd.Acquisition]) -> Self: + def from_ismrmrd_acquisitions( + cls, acquisitions: Sequence[ismrmrd.Acquisition], *, additional_fields: Sequence[str] | None = None + ) -> Self | tuple[Self, tuple[torch.Tensor, ...]]: """Read the header of a list of acquisition and store information. Parameters ---------- - acquisitions: + acquisitions list of ismrmrd acquisistions to read from. Needs at least one acquisition. + additional_fields + if supplied, additional fields with these names will be from the ismrmrd acquisitions + and returned as tensors. """ # Idea: create array of structs, then a struct of arrays, # convert it into tensors to store in our dataclass. @@ -169,9 +185,9 @@ def from_ismrmrd_acquisitions(cls, acquisitions: Sequence[ismrmrd.Acquisition]) raise ValueError('Acquisition list must not be empty.') # Creating the dtype first and casting to bytes - # is a workaround for a bug in cpython > 3.12 causing a warning - # is np.array(AcquisitionHeader) is called directly. - # also, this needs to check the dtyoe only once. + # is a workaround for a bug in cpython causing a warning + # if np.array(AcquisitionHeader) is called directly. + # also, this needs to check the dtype only once. acquisition_head_dtype = np.dtype(ismrmrd.AcquisitionHeader) headers = np.frombuffer( np.array([memoryview(a._head).cast('B') for a in acquisitions]), @@ -228,33 +244,49 @@ def spatialdimension_2d(data: np.ndarray) -> SpatialDimension[torch.Tensor]: user6=tensor(idx['user'][:, 6]), user7=tensor(idx['user'][:, 7]), ) - + user = UserValues( + tensor_2d(headers['user_float'][:, 0]), + tensor_2d(headers['user_float'][:, 1]), + tensor_2d(headers['user_float'][:, 2]), + tensor_2d(headers['user_float'][:, 3]), + tensor_2d(headers['user_float'][:, 4]), + tensor_2d(headers['user_float'][:, 5]), + tensor_2d(headers['user_float'][:, 6]), + tensor_2d(headers['user_float'][:, 7]), + tensor_2d(headers['user_int'][:, 0]), + tensor_2d(headers['user_int'][:, 1]), + tensor_2d(headers['user_int'][:, 2]), + tensor_2d(headers['user_int'][:, 3]), + tensor_2d(headers['user_int'][:, 4]), + tensor_2d(headers['user_int'][:, 5]), + tensor_2d(headers['user_int'][:, 6]), + tensor_2d(headers['user_int'][:, 7]), + ) + physiology_time_stamps = PhysiologyTimestamps( + tensor_2d(headers['physiology_time_stamp'][:, 0]).double(), + tensor_2d(headers['physiology_time_stamp'][:, 1]).double(), + tensor_2d(headers['physiology_time_stamp'][:, 2]).double(), + ) acq_info = cls( idx=acq_idx, - acquisition_time_stamp=tensor_2d(headers['acquisition_time_stamp']), - active_channels=tensor_2d(headers['active_channels']), - available_channels=tensor_2d(headers['available_channels']), - center_sample=tensor_2d(headers['center_sample']), - channel_mask=tensor_2d(headers['channel_mask']), - discard_post=tensor_2d(headers['discard_post']), - discard_pre=tensor_2d(headers['discard_pre']), - encoding_space_ref=tensor_2d(headers['encoding_space_ref']), + acquisition_time_stamp=tensor_2d(headers['acquisition_time_stamp']).double(), flags=tensor_2d(headers['flags']), measurement_uid=tensor_2d(headers['measurement_uid']), - number_of_samples=tensor_2d(headers['number_of_samples']), orientation=Rotation.from_directions( spatialdimension_2d(headers['slice_dir']), spatialdimension_2d(headers['phase_dir']), spatialdimension_2d(headers['read_dir']), ), patient_table_position=spatialdimension_2d(headers['patient_table_position']).apply_(mm_to_m), - physiology_time_stamp=tensor_2d(headers['physiology_time_stamp']), position=spatialdimension_2d(headers['position']).apply_(mm_to_m), sample_time_us=tensor_2d(headers['sample_time_us']), scan_counter=tensor_2d(headers['scan_counter']), - trajectory_dimensions=tensor_2d(headers['trajectory_dimensions']).fill_(3), # see above - user_float=tensor_2d(headers['user_float']), - user_int=tensor_2d(headers['user_int']), - version=tensor_2d(headers['version']), + user=user, + physiology_time_stamps=physiology_time_stamps, ) - return acq_info + + if additional_fields is None: + return acq_info + else: + additional_values = tuple(tensor_2d(headers[field]) for field in additional_fields) + return acq_info, additional_values diff --git a/src/mrpro/data/KData.py b/src/mrpro/data/KData.py index 4b5df625..614d48fb 100644 --- a/src/mrpro/data/KData.py +++ b/src/mrpro/data/KData.py @@ -21,6 +21,7 @@ from mrpro.data.acq_filters import has_n_coils, is_image_acquisition from mrpro.data.AcqInfo import AcqInfo, rearrange_acq_info_fields from mrpro.data.EncodingLimits import Limits +from mrpro.data.enums import AcqFlags from mrpro.data.KHeader import KHeader from mrpro.data.KTrajectory import KTrajectory from mrpro.data.KTrajectoryRawShape import KTrajectoryRawShape @@ -136,9 +137,12 @@ def from_file( kdata = torch.stack([torch.as_tensor(acq.data, dtype=torch.complex64) for acq in acquisitions]) - acqinfo = AcqInfo.from_ismrmrd_acquisitions(acquisitions) + acq_info, (k0_center, n_k0_tensor, discard_pre, discard_post) = AcqInfo.from_ismrmrd_acquisitions( + acquisitions, + additional_fields=('center_sample', 'number_of_samples', 'discard_pre', 'discard_post'), + ) - if len(torch.unique(acqinfo.idx.user5)) > 1: + if len(torch.unique(acq_info.idx.user5)) > 1: warnings.warn( 'The Siemens to ismrmrd converter currently (ab)uses ' 'the user 5 indices for storing the kspace center line number.\n' @@ -146,7 +150,7 @@ def from_file( stacklevel=1, ) - if len(torch.unique(acqinfo.idx.user6)) > 1: + if len(torch.unique(acq_info.idx.user6)) > 1: warnings.warn( 'The Siemens to ismrmrd converter currently (ab)uses ' 'the user 6 indices for storing the kspace center partition number.\n' @@ -157,7 +161,7 @@ def from_file( # Raises ValueError if required fields are missing in the header kheader = KHeader.from_ismrmrd( ismrmrd_header, - acqinfo, + acq_info, defaults={ 'datetime': modification_time, # use the modification time of the dataset as fallback 'trajectory': ktrajectory, @@ -171,9 +175,9 @@ def from_file( # (number_of_samples, center_sample) of (100, 20) (e.g. partial Fourier in the negative k0 direction) and # (100, 80) (e.g. partial Fourier in the positive k0 direction) then this should lead to encoding limits of # [min=0, max=159, center=80] - max_center_sample = int(torch.max(kheader.acq_info.center_sample)) - max_pos_k0_extend = int(torch.max(kheader.acq_info.number_of_samples - kheader.acq_info.center_sample)) - kheader.encoding_limits.k0 = Limits(0, max_center_sample + max_pos_k0_extend - 1, max_center_sample) + max_center_sample = int(torch.max(k0_center)) + max_positive_k0_extend = int(torch.max(n_k0_tensor - k0_center)) + kheader.encoding_limits.k0 = Limits(0, max_center_sample + max_positive_k0_extend - 1, max_center_sample) # Sort and reshape the kdata and the acquisistion info according to the indices. # within "other", the aquisistions are sorted in the order determined by KDIM_SORT_LABELS. @@ -238,7 +242,24 @@ def from_file( case KTrajectoryIsmrmrd(): ktrajectory_final = ktrajectory(acquisitions).sort_and_reshape(sort_idx, n_k2, n_k1) case KTrajectoryCalculator(): - ktrajectory_or_rawshape = ktrajectory(kheader) + reversed_readout_mask = (kheader.acq_info.flags[..., 0] & AcqFlags.ACQ_IS_REVERSE.value).bool() + n_k0_unique = torch.unique(n_k0_tensor) + if len(n_k0_unique) > 1: + raise ValueError( + 'Trajectory can only be calculated for constant number of readout samples.\n' + f'Got unique values {list(n_k0_unique)}' + ) + ktrajectory_or_rawshape = ktrajectory( + n_k0=int(n_k0_unique[0]), + k0_center=k0_center, + k1_idx=kheader.acq_info.idx.k1, + k1_center=kheader.encoding_limits.k1.center, + k2_idx=kheader.acq_info.idx.k2, + k2_center=kheader.encoding_limits.k2.center, + reversed_readout_mask=reversed_readout_mask, + encoding_matrix=kheader.encoding_matrix, + ) + if isinstance(ktrajectory_or_rawshape, KTrajectoryRawShape): ktrajectory_final = ktrajectory_or_rawshape.sort_and_reshape(sort_idx, n_k2, n_k1) else: diff --git a/src/mrpro/data/_kdata/KDataRemoveOsMixin.py b/src/mrpro/data/_kdata/KDataRemoveOsMixin.py index 555f56a3..592d123b 100644 --- a/src/mrpro/data/_kdata/KDataRemoveOsMixin.py +++ b/src/mrpro/data/_kdata/KDataRemoveOsMixin.py @@ -65,11 +65,5 @@ def crop_readout(data_to_crop: torch.Tensor) -> torch.Tensor: # Adapt header parameters header = deepcopy(self.header) - header.acq_info.center_sample -= start_cropped_readout - header.acq_info.number_of_samples[:] = cropped_data.shape[-1] header.encoding_matrix.x = cropped_data.shape[-1] - - header.acq_info.discard_post = (header.acq_info.discard_post * x_ratio).to(torch.int32) - header.acq_info.discard_pre = (header.acq_info.discard_pre * x_ratio).to(torch.int32) - return type(self)(header, cropped_data, cropped_traj) diff --git a/src/mrpro/data/traj_calculators/KTrajectoryCalculator.py b/src/mrpro/data/traj_calculators/KTrajectoryCalculator.py index 1893d761..1061ae6a 100644 --- a/src/mrpro/data/traj_calculators/KTrajectoryCalculator.py +++ b/src/mrpro/data/traj_calculators/KTrajectoryCalculator.py @@ -4,52 +4,85 @@ import torch -from mrpro.data.enums import AcqFlags -from mrpro.data.KHeader import KHeader from mrpro.data.KTrajectory import KTrajectory from mrpro.data.KTrajectoryRawShape import KTrajectoryRawShape +from mrpro.data.SpatialDimension import SpatialDimension class KTrajectoryCalculator(ABC): """Base class for k-space trajectories.""" @abstractmethod - def __call__(self, header: KHeader) -> KTrajectory | KTrajectoryRawShape: + def __call__( + self, + *, + n_k0: int, + k0_center: int | torch.Tensor, + k1_idx: torch.Tensor, + k1_center: int | torch.Tensor, + k2_idx: torch.Tensor, + k2_center: int | torch.Tensor, + encoding_matrix: SpatialDimension, + reversed_readout_mask: torch.Tensor | None = None, + ) -> KTrajectory | KTrajectoryRawShape: """Calculate the trajectory for given KHeader. The shapes of kz, ky and kx of the calculated trajectory must be broadcastable to (prod(all_other_dimensions), k2, k1, k0). + + Not all of the parameters will be used by all implementations. + + Parameters + ---------- + n_k0 + number of samples in k0 + k1_idx + indices of k1 + k2_idx + indices of k2 + k0_center + position of k-space center in k0 + k1_center + position of k-space center in k1 + k2_center + position of k-space center in k2 + reversed_readout_mask + boolean tensor indicating reversed redout + encoding_matrix + encoding matrix, describing the extend of the k-space coordinates + + + + Returns + ------- + Trajectory + """ - ... - def _kfreq(self, kheader: KHeader) -> torch.Tensor: + def _readout( + self, n_k0: int, k0_center: int | torch.Tensor, reversed_readout_mask: torch.Tensor | None + ) -> torch.Tensor: """Calculate the trajectory along one readout (k0 dimension). Parameters ---------- - kheader - MR raw data header (KHeader) containing required meta data + n_k0 + number of samples in readout + k0_center + position of k-space center in readout + reversed_readout_mask + boolean tensor indicating reversed readout, e.g bipolar readout Returns ------- - trajectory along ONE readout + trajectory along one readout - Raises - ------ - ValueError - Number of samples have to be the same for each readout """ - n_samples = torch.unique(kheader.acq_info.number_of_samples) - center_sample = kheader.acq_info.center_sample - if len(n_samples) > 1: - raise ValueError('Trajectory can only be calculated if each acquisition has the same number of samples') - n_k0 = int(n_samples.item()) - - # Data can be obtained with standard or reversed readout (e.g. bipolar readout). - k0 = torch.linspace(0, n_k0 - 1, n_k0, dtype=torch.float32) - center_sample + k0 = torch.linspace(0, n_k0 - 1, n_k0, dtype=torch.float32) - k0_center # Data can be obtained with standard or reversed readout (e.g. bipolar readout). - reversed_readout_mask = (kheader.acq_info.flags[..., 0] & AcqFlags.ACQ_IS_REVERSE.value).bool() - k0[reversed_readout_mask, :] = torch.flip(k0[reversed_readout_mask, :], (-1,)) + if reversed_readout_mask is not None: + k0, reversed_readout_mask = torch.broadcast_tensors(k0, reversed_readout_mask) + k0[reversed_readout_mask] = torch.flip(k0[reversed_readout_mask], (-1,)) return k0 @@ -59,8 +92,9 @@ class DummyTrajectory(KTrajectoryCalculator): Shape will fit to all data. Only used as dummy for testing. """ - def __call__(self, header: KHeader) -> KTrajectory: # noqa: ARG002 + def __call__(self, **_) -> KTrajectory: """Calculate dummy trajectory.""" kx = torch.zeros(1, 1, 1, 1) - ky = kz = torch.zeros(1, 1, 1, 1) + ky = torch.zeros(1, 1, 1, 1) + kz = torch.zeros(1, 1, 1, 1) return KTrajectory(kz, ky, kx) diff --git a/src/mrpro/data/traj_calculators/KTrajectoryCartesian.py b/src/mrpro/data/traj_calculators/KTrajectoryCartesian.py index 1b0742ee..efa6b0b9 100644 --- a/src/mrpro/data/traj_calculators/KTrajectoryCartesian.py +++ b/src/mrpro/data/traj_calculators/KTrajectoryCartesian.py @@ -3,7 +3,6 @@ import torch from einops import repeat -from mrpro.data.KHeader import KHeader from mrpro.data.KTrajectory import KTrajectory from mrpro.data.traj_calculators.KTrajectoryCalculator import KTrajectoryCalculator @@ -11,24 +10,47 @@ class KTrajectoryCartesian(KTrajectoryCalculator): """Cartesian trajectory.""" - def __call__(self, kheader: KHeader) -> KTrajectory: + def __call__( + self, + *, + n_k0: int, + k0_center: int | torch.Tensor, + k1_idx: torch.Tensor, + k1_center: int | torch.Tensor, + k2_idx: torch.Tensor, + k2_center: int | torch.Tensor, + reversed_readout_mask: torch.Tensor | None = None, + **_, + ) -> KTrajectory: """Calculate Cartesian trajectory for given KHeader. Parameters ---------- - kheader - MR raw data header (KHeader) containing required meta data + n_k0 + number of samples in k0 + k0_center + position of k-space center in k0 + k1_idx + indices of k1 + k1_center + position of k-space center in k1 + k2_idx + indices of k2 + k2_center + position of k-space center in k2 + reversed_readout_mask + boolean tensor indicating reversed readout Returns ------- Cartesian trajectory for given KHeader """ # K-space locations along readout lines - kx = self._kfreq(kheader) + kx = self._readout(n_k0, k0_center, reversed_readout_mask=reversed_readout_mask) # Trajectory along phase and slice encoding - ky = (kheader.acq_info.idx.k1 - kheader.encoding_limits.k1.center).to(torch.float32) - kz = (kheader.acq_info.idx.k2 - kheader.encoding_limits.k2.center).to(torch.float32) + ky = (k1_idx - k1_center).to(torch.float32) + kz = (k2_idx - k2_center).to(torch.float32) # Bring to correct dimensions ky = repeat(ky, '... k2 k1-> ... k2 k1 k0', k0=1) diff --git a/src/mrpro/data/traj_calculators/KTrajectoryIsmrmrd.py b/src/mrpro/data/traj_calculators/KTrajectoryIsmrmrd.py index 2b6aad36..be0a09c4 100644 --- a/src/mrpro/data/traj_calculators/KTrajectoryIsmrmrd.py +++ b/src/mrpro/data/traj_calculators/KTrajectoryIsmrmrd.py @@ -47,6 +47,10 @@ def __call__(self, acquisitions: Sequence[ismrmrd.Acquisition]) -> KTrajectoryRa kx=ktraj_mrd[..., 0], ) else: - ktraj = KTrajectoryRawShape(kz=ktraj_mrd[..., 2], ky=ktraj_mrd[..., 1], kx=ktraj_mrd[..., 0]) + ktraj = KTrajectoryRawShape( + kz=ktraj_mrd[..., 2], + ky=ktraj_mrd[..., 1], + kx=ktraj_mrd[..., 0], + ) return ktraj diff --git a/src/mrpro/data/traj_calculators/KTrajectoryPulseq.py b/src/mrpro/data/traj_calculators/KTrajectoryPulseq.py index 7c843572..91dda9d1 100644 --- a/src/mrpro/data/traj_calculators/KTrajectoryPulseq.py +++ b/src/mrpro/data/traj_calculators/KTrajectoryPulseq.py @@ -6,8 +6,8 @@ import torch from einops import rearrange -from mrpro.data.KHeader import KHeader from mrpro.data.KTrajectoryRawShape import KTrajectoryRawShape +from mrpro.data.SpatialDimension import SpatialDimension from mrpro.data.traj_calculators.KTrajectoryCalculator import KTrajectoryCalculator @@ -28,13 +28,21 @@ def __init__(self, seq_path: str | Path, repeat_detection_tolerance: None | floa self.seq_path = seq_path self.repeat_detection_tolerance = repeat_detection_tolerance - def __call__(self, kheader: KHeader) -> KTrajectoryRawShape: + def __call__( + self, + *, + n_k0: int, + encoding_matrix: SpatialDimension, + **_, + ) -> KTrajectoryRawShape: """Calculate trajectory from given .seq file and header information. Parameters ---------- - kheader - MR raw data header (KHeader) containing required meta data + n_k0 + number of samples in k0 + encoding_matrix + encoding matrix, describing the extend of the k-space coordinates Returns ------- @@ -48,19 +56,20 @@ def __call__(self, kheader: KHeader) -> KTrajectoryRawShape: k_traj_adc_numpy, _, _, _, _ = seq.calculate_kspace() k_traj_adc = torch.tensor(k_traj_adc_numpy, dtype=torch.float32) - n_samples = kheader.acq_info.number_of_samples - n_samples = torch.unique(n_samples) - if len(n_samples) > 1: - raise ValueError('We currently only support constant number of samples') - n_k0 = int(n_samples.item()) - - def reshape_pulseq_traj(k_traj: torch.Tensor, encoding_size: int): - k_traj *= encoding_size / (2 * torch.max(torch.abs(k_traj))) + def reshape(k_traj: torch.Tensor, encoding_size: int) -> torch.Tensor: + max_value_range = 2 * torch.max(torch.abs(k_traj)) + if max_value_range > 1e-9 and encoding_size > 1: + k_traj = k_traj * encoding_size / max_value_range + else: + # If encoding matrix is 1, we force k_traj to be 0. We assume here that the values are + # numerical noise returned by pulseq, not real trajectory values + # even if pulseq returned some numerical noise. Also we avoid division by zero. + k_traj = torch.zeros_like(k_traj) return rearrange(k_traj, '(other k0) -> other k0', k0=n_k0) # rearrange k-space trajectory to match MRpro convention - kx = reshape_pulseq_traj(k_traj_adc[0], kheader.encoding_matrix.x) - ky = reshape_pulseq_traj(k_traj_adc[1], kheader.encoding_matrix.y) - kz = reshape_pulseq_traj(k_traj_adc[2], kheader.encoding_matrix.z) + kx = reshape(k_traj_adc[0], encoding_matrix.x) + ky = reshape(k_traj_adc[1], encoding_matrix.y) + kz = reshape(k_traj_adc[2], encoding_matrix.z) return KTrajectoryRawShape(kz, ky, kx, self.repeat_detection_tolerance) diff --git a/src/mrpro/data/traj_calculators/KTrajectoryRadial2D.py b/src/mrpro/data/traj_calculators/KTrajectoryRadial2D.py index 1616af90..c458a69f 100644 --- a/src/mrpro/data/traj_calculators/KTrajectoryRadial2D.py +++ b/src/mrpro/data/traj_calculators/KTrajectoryRadial2D.py @@ -3,7 +3,6 @@ import torch from einops import repeat -from mrpro.data.KHeader import KHeader from mrpro.data.KTrajectory import KTrajectory from mrpro.data.traj_calculators.KTrajectoryCalculator import KTrajectoryCalculator @@ -22,27 +21,35 @@ def __init__(self, angle: float = torch.pi * 0.618034) -> None: super().__init__() self.angle: float = angle - def __call__(self, kheader: KHeader) -> KTrajectory: + def __call__( + self, + *, + n_k0: int, + k0_center: int | torch.Tensor, + k1_idx: torch.Tensor, + reversed_readout_mask: torch.Tensor | None = None, + **_, + ) -> KTrajectory: """Calculate radial 2D trajectory for given KHeader. Parameters ---------- - kheader - MR raw data header (KHeader) containing required meta data + n_k0 + number of samples in k0 + k0_center + position of k-space center in k0 + k1_idx + indices of k1 + reversed_readout_mask + boolean tensor indicating reversed readout Returns ------- radial 2D trajectory for given KHeader """ - # K-space locations along readout lines - krad = self._kfreq(kheader) - - # Angles of readout lines - kang = repeat(kheader.acq_info.idx.k1 * self.angle, '... k2 k1 -> ... k2 k1 k0', k0=1) - - # K-space radial coordinates - kx = krad * torch.cos(kang) - ky = krad * torch.sin(kang) + radial = self._readout(n_k0=n_k0, k0_center=k0_center, reversed_readout_mask=reversed_readout_mask) + angle = repeat(k1_idx * self.angle, '... k2 k1 -> ... k2 k1 k0', k0=1) + kx = radial * torch.cos(angle) + ky = radial * torch.sin(angle) kz = torch.zeros(1, 1, 1, 1) - return KTrajectory(kz, ky, kx) diff --git a/src/mrpro/data/traj_calculators/KTrajectoryRpe.py b/src/mrpro/data/traj_calculators/KTrajectoryRpe.py index 377a7ae9..86f45813 100644 --- a/src/mrpro/data/traj_calculators/KTrajectoryRpe.py +++ b/src/mrpro/data/traj_calculators/KTrajectoryRpe.py @@ -3,7 +3,6 @@ import torch from einops import repeat -from mrpro.data.KHeader import KHeader from mrpro.data.KTrajectory import KTrajectory from mrpro.data.traj_calculators.KTrajectoryCalculator import KTrajectoryCalculator @@ -39,7 +38,7 @@ def __init__(self, angle: float, shift_between_rpe_lines: tuple | torch.Tensor = self.angle: float = angle self.shift_between_rpe_lines: torch.Tensor = torch.as_tensor(shift_between_rpe_lines) - def _apply_shifts_between_rpe_lines(self, krad: torch.Tensor, kang_idx: torch.Tensor) -> torch.Tensor: + def _apply_shifts_between_rpe_lines(self, k_radial: torch.Tensor, idx: torch.Tensor) -> torch.Tensor: """Shift radial phase encoding lines relative to each other. Example: shift_between_rpe_lines = [0, 0.5, 0.25, 0.75] leads to a shift of the 0th line by 0, @@ -56,80 +55,70 @@ def _apply_shifts_between_rpe_lines(self, krad: torch.Tensor, kang_idx: torch.Te Parameters ---------- - krad - k-space positions along each phase encoding line - kang_idx - indices of angles to be used for shift calculation - - References - ---------- - .. [PRI2010] Prieto C, Schaeffter T (2010) 3D undersampled golden-radial phase encoding - for DCE-MRA using inherently regularized iterative SENSE. MRM 64(2). https://doi.org/10.1002/mrm.22446 - """ - for ind, shift in enumerate(self.shift_between_rpe_lines): - curr_angle_idx = torch.nonzero( - torch.fmod(kang_idx, len(self.shift_between_rpe_lines)) == ind, - as_tuple=True, - ) - curr_krad = krad[curr_angle_idx] - - # Do not shift the k-space center - curr_krad += shift * (curr_krad != 0) - - krad[curr_angle_idx] = curr_krad - return krad - - def _kang(self, kheader: KHeader) -> torch.Tensor: - """Calculate the angles of the phase encoding lines. - - Parameters - ---------- - kheader - MR raw data header (KHeader) containing required meta data + k_radial + k-space positions along each phase encoding line, zo be shifted + idx + indices used for shift calculation Returns ------- - angles of phase encoding lines - """ - return repeat(kheader.acq_info.idx.k2 * self.angle, '... k2 k1 -> ... k2 k1 k0', k0=1) - - def _krad(self, kheader: KHeader) -> torch.Tensor: - """Calculate the k-space locations along the phase encoding lines. + shifted radial k-space positions - Parameters + References ---------- - kheader - MR raw data header (KHeader) containing required meta data - - Returns - ------- - k-space locations along the phase encoding lines + .. [PRI2010] Prieto C, Schaeffter T (2010) 3D undersampled golden-radial phase encoding + for DCE-MRA using inherently regularized iterative SENSE. MRM 64(2). https://doi.org/10.1002/mrm.22446 """ - krad = (kheader.acq_info.idx.k1 - kheader.encoding_limits.k1.center).to(torch.float32) - krad = self._apply_shifts_between_rpe_lines(krad, kheader.acq_info.idx.k2) - return repeat(krad, '... k2 k1 -> ... k2 k1 k0', k0=1) + # do not shift k-space center + not_center = ~torch.isclose(k_radial, torch.tensor(0)) - def __call__(self, kheader: KHeader) -> KTrajectory: + for ind, shift in enumerate(self.shift_between_rpe_lines): + current_mask = (idx % len(self.shift_between_rpe_lines)) == ind + current_mask &= not_center + k_radial[current_mask] += shift + + return k_radial + + def __call__( + self, + *, + n_k0: int, + k0_center: int | torch.Tensor, + k1_idx: torch.Tensor, + k1_center: int | torch.Tensor, + k2_idx: torch.Tensor, + reversed_readout_mask: torch.Tensor | None = None, + **_, + ) -> KTrajectory: """Calculate radial phase encoding trajectory for given KHeader. Parameters ---------- - kheader - MR raw data header (KHeader) containing required meta data + n_k0 + number of samples in k0 + k0_center + position of k-space center in k0 + k1_idx + indices of k1 + k1_center + position of k-space center in k1 + k2_idx + indices of k2 + reversed_readout_mask + boolean tensor indicating reversed readout Returns ------- radial phase encoding trajectory for given KHeader """ - # Trajectory along readout - kx = self._kfreq(kheader) + angles = repeat(k2_idx * self.angle, '... k2 k1 -> ... k2 k1 k0', k0=1) - # Angles of phase encoding lines - kang = self._kang(kheader) + radial = (k1_idx - k1_center).to(torch.float32) + radial = self._apply_shifts_between_rpe_lines(radial, k2_idx) + radial = repeat(radial, '... k2 k1 -> ... k2 k1 k0', k0=1) - # K-space locations along phase encoding lines - krad = self._krad(kheader) + kz = radial * torch.sin(angles) + ky = radial * torch.cos(angles) + kx = self._readout(n_k0, k0_center, reversed_readout_mask=reversed_readout_mask) - kz = krad * torch.sin(kang) - ky = krad * torch.cos(kang) return KTrajectory(kz, ky, kx) diff --git a/src/mrpro/data/traj_calculators/KTrajectorySunflowerGoldenRpe.py b/src/mrpro/data/traj_calculators/KTrajectorySunflowerGoldenRpe.py index 8ab4abd9..2c34b7e6 100644 --- a/src/mrpro/data/traj_calculators/KTrajectorySunflowerGoldenRpe.py +++ b/src/mrpro/data/traj_calculators/KTrajectorySunflowerGoldenRpe.py @@ -4,90 +4,94 @@ import torch from einops import repeat -from mrpro.data.KHeader import KHeader -from mrpro.data.traj_calculators.KTrajectoryRpe import KTrajectoryRpe +from mrpro.data.KTrajectory import KTrajectory +from mrpro.data.traj_calculators.KTrajectoryCalculator import KTrajectoryCalculator +GOLDEN_RATIO = 0.5 * (5**0.5 + 1) -class KTrajectorySunflowerGoldenRpe(KTrajectoryRpe): + +class KTrajectorySunflowerGoldenRpe(KTrajectoryCalculator): """Radial phase encoding trajectory with a sunflower pattern.""" - def __init__(self, rad_us_factor: float = 1.0) -> None: + def __init__(self, radial_undersampling_factor: float = 1.0) -> None: """Initialize KTrajectorySunflowerGoldenRpe. Parameters ---------- - rad_us_factor + radial_undersampling_factor undersampling factor along radial phase encoding direction. """ - super().__init__(angle=torch.pi * 0.618034) - self.rad_us_factor: float = rad_us_factor + self.angle = torch.pi * 0.618034 + + if radial_undersampling_factor != 1: + raise NotImplementedError('Radial undersampling is not yet implemented') def _apply_sunflower_shift_between_rpe_lines( self, - krad: torch.Tensor, - kang: torch.Tensor, - kheader: KHeader, + radial: torch.Tensor, + angles: torch.Tensor, + k2_idx: torch.Tensor, ) -> torch.Tensor: """Shift radial phase encoding lines relative to each other. The shifts are applied to create a sunflower pattern of k-space points in the ky-kz phase encoding plane. - The applied shifts can lead to a scaling of the FOV. This scaling depends on the undersampling factor along the - radial phase encoding direction and is compensated for at the end. Parameters ---------- - krad - k-space positions along each phase encoding line - kang - angles of the radial phase encoding lines - kheader - MR raw data header (KHeader) containing required meta data + radial + position along radial direction + angles + angle of spokes + k2_idx + indices in k2 """ - kang = kang.flatten() - _, indices = np.unique(kang, return_index=True) + angles = angles.flatten() + _, indices = np.unique(angles, return_index=True) shift_idx = np.argsort(indices) - - # Apply sunflower shift - golden_ratio = 0.5 * (np.sqrt(5) + 1) for ind, shift in enumerate(shift_idx): - krad[kheader.acq_info.idx.k2 == ind] += ((shift * golden_ratio) % 1) - 0.5 + radial[k2_idx == ind] += ((shift * GOLDEN_RATIO) % 1) - 0.5 + return radial - # Set asym k-space point to 0 because this point was used to obtain a self-navigator signal. - krad[kheader.acq_info.idx.k1 == 0] = 0 - - return krad - - def _kang(self, kheader: KHeader) -> torch.Tensor: - """Calculate the angles of the phase encoding lines. + def __call__( + self, + *, + n_k0: int, + k0_center: int | torch.Tensor, + k1_idx: torch.Tensor, + k1_center: int | torch.Tensor, + k2_idx: torch.Tensor, + reversed_readout_mask: torch.Tensor | None = None, + **_, + ) -> KTrajectory: + """Calculate radial phase encoding trajectory for given KHeader. Parameters ---------- - kheader - MR raw data header (KHeader) containing required meta data + n_k0 + number of samples in k0 + k0_center + position of k-space center in k0 + k1_idx + indices of k1 + k1_center + position of k-space center in k1 + k2_idx + indices of k2 + reversed_readout_mask + boolean tensor indicating reversed readout Returns ------- - angles of phase encoding lines + radial phase encoding trajectory for given KHeader """ - return repeat((kheader.acq_info.idx.k2 * self.angle) % torch.pi, '... k2 k1 -> ... k2 k1 k0', k0=1) + angles = repeat((k2_idx * self.angle) % torch.pi, '... k2 k1 -> ... k2 k1 k0', k0=1) + radial = repeat((k1_idx - k1_center).to(torch.float32), '... k2 k1 -> ... k2 k1 k0', k0=1) + radial = self._apply_sunflower_shift_between_rpe_lines(radial, angles, k2_idx) - def _krad(self, kheader: KHeader) -> torch.Tensor: - """Calculate the k-space locations along the phase encoding lines. + # Asymmetric k-space point is used to obtain a self-navigator signal, thus should be in k-space center + radial[k1_idx == 0] = 0 - Parameters - ---------- - kheader - MR raw data header (KHeader) containing required meta data - - Returns - ------- - k-space locations along the phase encoding lines - """ - kang = self._kang(kheader) - krad = repeat( - (kheader.acq_info.idx.k1 - kheader.encoding_limits.k1.center).to(torch.float32), - '... k2 k1 -> ... k2 k1 k0', - k0=1, - ) - krad = self._apply_sunflower_shift_between_rpe_lines(krad, kang, kheader) - return krad + kz = radial * torch.sin(angles) + ky = radial * torch.cos(angles) + kx = self._readout(n_k0, k0_center, reversed_readout_mask=reversed_readout_mask) + return KTrajectory(kz, ky, kx) diff --git a/tests/conftest.py b/tests/conftest.py index 30ae9c22..0918e862 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -192,7 +192,7 @@ def random_acq_info(random_acquisition): return acq_info -@pytest.fixture(params=({'seed': 0, 'n_other': 10, 'n_k2': 40, 'n_k1': 20},)) +@pytest.fixture(params=({'seed': 0, 'n_other': 10, 'n_k2': 40, 'n_k1': 20, 'n_k0': 64, 'n_coils': 2},)) def random_kheader_shape(request, random_acquisition, random_full_ismrmrd_header): """Random (not necessarily valid) KHeader with defined shape.""" # Get dimensions @@ -206,14 +206,15 @@ def random_kheader_shape(request, random_acquisition, random_full_ismrmrd_header # Generate acquisitions random_acq_info = AcqInfo.from_ismrmrd_acquisitions([random_acquisition for _ in range(n_k1 * n_k2 * n_other)]) - n_k0 = int(random_acq_info.number_of_samples[0]) - n_coils = int(random_acq_info.active_channels[0]) # Generate trajectory + n_k0 = int(request.param['n_k0']) ktraj = [generate_random_trajectory(generator, shape=(n_k0, 2)) for _ in range(n_k1 * n_k2 * n_other)] # Put it all together to a KHeader object kheader = KHeader.from_ismrmrd(random_full_ismrmrd_header, acq_info=random_acq_info, defaults={'trajectory': ktraj}) + n_coils = int(request.param['n_coils']) + return kheader, n_other, n_coils, n_k2, n_k1, n_k0 diff --git a/tests/data/test_kdata.py b/tests/data/test_kdata.py index fa3e4ebd..386ef60b 100644 --- a/tests/data/test_kdata.py +++ b/tests/data/test_kdata.py @@ -260,8 +260,8 @@ def test_KData_to_complex128_header(ismrmrd_cart): """Change KData dtype complex128: test header""" kdata = KData.from_file(ismrmrd_cart.filename, DummyTrajectory()) kdata_complex128 = kdata.to(dtype=torch.complex128) - assert kdata_complex128.header.acq_info.user_float.dtype == torch.float64 - assert kdata_complex128.header.acq_info.user_int.dtype == torch.int32 + assert kdata_complex128.header.acq_info.user.float1.dtype == torch.float64 + assert kdata_complex128.header.acq_info.user.int1.dtype == torch.int32 @pytest.mark.cuda @@ -285,7 +285,7 @@ def test_KData_cuda(ismrmrd_cart): assert kdata_cuda.traj.kz.is_cuda assert kdata_cuda.traj.ky.is_cuda assert kdata_cuda.traj.kx.is_cuda - assert kdata_cuda.header.acq_info.user_int.is_cuda + assert kdata_cuda.header.acq_info.user.int1.is_cuda assert kdata_cuda.device == torch.device(torch.cuda.current_device()) assert kdata_cuda.header.acq_info.device == torch.device(torch.cuda.current_device()) assert kdata_cuda.is_cuda @@ -301,7 +301,7 @@ def test_KData_cpu(ismrmrd_cart): assert kdata_cpu.traj.kz.is_cpu assert kdata_cpu.traj.ky.is_cpu assert kdata_cpu.traj.kx.is_cpu - assert kdata_cpu.header.acq_info.user_int.is_cpu + assert kdata_cpu.header.acq_info.user.int1.is_cpu assert kdata_cpu.device == torch.device('cpu') assert kdata_cpu.header.acq_info.device == torch.device('cpu') diff --git a/tests/data/test_traj_calculators.py b/tests/data/test_traj_calculators.py index 7ddcb30a..69d4e364 100644 --- a/tests/data/test_traj_calculators.py +++ b/tests/data/test_traj_calculators.py @@ -1,6 +1,5 @@ """Tests for KTrajectory Calculator classes.""" -import numpy as np import pytest import torch from einops import repeat @@ -18,202 +17,127 @@ from tests.data import IsmrmrdRawTestData, PulseqRadialTestSeq -@pytest.fixture -def valid_rad2d_kheader(monkeypatch, random_kheader): - """KHeader with all necessary parameters for radial 2D trajectories.""" - # K-space dimensions +def test_KTrajectoryRadial2D(): + """Test shapes returned by KTrajectoryRadial2D.""" + n_k0 = 256 n_k1 = 10 - n_k2 = 1 - - # List of k1 indices in the shape - idx_k1 = repeat(torch.arange(n_k1, dtype=torch.int32), 'k1 -> other k2 k1', other=1, k2=1) - - # Set parameters for radial 2D trajectory (AcqInfo is of shape (other k2 k1 dim=1 or 3)) - monkeypatch.setattr(random_kheader.acq_info, 'number_of_samples', torch.zeros_like(idx_k1)[..., None] + n_k0) - monkeypatch.setattr(random_kheader.acq_info, 'center_sample', torch.zeros_like(idx_k1)[..., None] + n_k0 // 2) - monkeypatch.setattr(random_kheader.acq_info, 'flags', torch.zeros_like(idx_k1)[..., None]) - monkeypatch.setattr(random_kheader.acq_info.idx, 'k1', idx_k1) - # This is only needed for Pulseq trajectory calculation - monkeypatch.setattr(random_kheader.encoding_matrix, 'x', n_k0) - monkeypatch.setattr(random_kheader.encoding_matrix, 'y', n_k0) # square encoding in kx-ky plane - monkeypatch.setattr(random_kheader.encoding_matrix, 'z', n_k2) - - return random_kheader - - -def radial2D_traj_shape(valid_rad2d_kheader): - """Expected shape of trajectory based on KHeader.""" - n_k0 = valid_rad2d_kheader.acq_info.number_of_samples[0, 0, 0] - n_k1 = valid_rad2d_kheader.acq_info.idx.k1.shape[2] - n_k2 = 1 - n_other = 1 - return ( - torch.Size([n_other, 1, 1, 1]), - torch.Size([n_other, n_k2, n_k1, n_k0]), - torch.Size([n_other, n_k2, n_k1, n_k0]), + trajectory_calculator = KTrajectoryRadial2D() + trajectory = trajectory_calculator( + n_k0=n_k0, + k0_center=n_k0 // 2, + k1_idx=torch.arange(n_k1)[None, None, :, None], ) + assert trajectory.kz.shape == (1, 1, 1, 1) + assert trajectory.ky.shape == (1, 1, n_k1, n_k0) + assert trajectory.kx.shape == (1, 1, n_k1, n_k0) -def test_KTrajectoryRadial2D_golden(valid_rad2d_kheader): - """Calculate Radial 2D trajectory with golden angle.""" - trajectory_calculator = KTrajectoryRadial2D(angle=torch.pi * 0.618034) - trajectory = trajectory_calculator(valid_rad2d_kheader) - valid_shape = radial2D_traj_shape(valid_rad2d_kheader) - assert trajectory.kx.shape == valid_shape[2] - assert trajectory.ky.shape == valid_shape[1] - assert trajectory.kz.shape == valid_shape[0] - - -@pytest.fixture -def valid_rpe_kheader(monkeypatch, random_kheader): - """KHeader with all necessary parameters for RPE trajectories.""" - # K-space dimensions - n_k0 = 200 +def test_KTrajectoryRpe(): + """Test shapes returned by KTrajectoryRpe""" + n_k0 = 100 n_k1 = 20 n_k2 = 10 - # List of k1 and k2 indices in the shape (other, k2, k1) - k1 = torch.linspace(0, n_k1 - 1, n_k1, dtype=torch.int32) - k2 = torch.linspace(0, n_k2 - 1, n_k2, dtype=torch.int32) - idx_k1, idx_k2 = torch.meshgrid(k1, k2, indexing='xy') - idx_k1 = torch.reshape(idx_k1, (1, n_k2, n_k1)) - idx_k2 = torch.reshape(idx_k2, (1, n_k2, n_k1)) - - # Set parameters for RPE trajectory (AcqInfo is of shape (other k2 k1 dim=1 or 3)) - monkeypatch.setattr(random_kheader.acq_info, 'number_of_samples', torch.zeros_like(idx_k1)[..., None] + n_k0) - monkeypatch.setattr(random_kheader.acq_info, 'center_sample', torch.zeros_like(idx_k1)[..., None] + n_k0 // 2) - monkeypatch.setattr(random_kheader.acq_info, 'flags', torch.zeros_like(idx_k1)[..., None]) - monkeypatch.setattr(random_kheader.acq_info.idx, 'k1', idx_k1) - monkeypatch.setattr(random_kheader.acq_info.idx, 'k2', idx_k2) - monkeypatch.setattr(random_kheader.encoding_limits.k1, 'center', int(n_k1 // 2)) - monkeypatch.setattr(random_kheader.encoding_limits.k1, 'max', int(n_k1 - 1)) - return random_kheader - - -def rpe_traj_shape(valid_rpe_kheader): - """Expected shape of trajectory based on KHeader.""" - n_k0 = valid_rpe_kheader.acq_info.number_of_samples[0, 0, 0] - n_k1 = valid_rpe_kheader.acq_info.idx.k1.shape[2] - n_k2 = valid_rpe_kheader.acq_info.idx.k1.shape[1] - n_other = 1 - return ( - torch.Size([n_other, n_k2, n_k1, 1]), - torch.Size([n_other, n_k2, n_k1, 1]), - torch.Size([n_other, 1, 1, n_k0]), + trajectory_calculator = KTrajectoryRpe(angle=torch.pi * 0.618034) + trajectory = trajectory_calculator( + n_k0=n_k0, + k0_center=n_k0 // 2, + k1_idx=torch.arange(n_k2)[None, :, None, None], + k1_center=n_k1 // 2, + k2_idx=torch.arange(n_k1)[None, None, :, None], ) + assert trajectory.kz.shape == (1, n_k2, n_k1, 1) + assert trajectory.ky.shape == (1, n_k2, n_k1, 1) + assert trajectory.kx.shape == (1, 1, 1, n_k0) -def test_KTrajectoryRpe_golden(valid_rpe_kheader): - """Calculate RPE trajectory with golden angle.""" - trajectory_calculator = KTrajectoryRpe(angle=torch.pi * 0.618034) - trajectory = trajectory_calculator(valid_rpe_kheader) - valid_shape = rpe_traj_shape(valid_rpe_kheader) - assert trajectory.kz.shape == valid_shape[0] - assert trajectory.ky.shape == valid_shape[1] - assert trajectory.kx.shape == valid_shape[2] - - -def test_KTrajectoryRpe_uniform(valid_rpe_kheader): - """Calculate RPE trajectory with uniform angle.""" - n_rpe_lines = valid_rpe_kheader.acq_info.idx.k1.shape[1] - trajectory1_calculator = KTrajectoryRpe(angle=torch.pi / n_rpe_lines, shift_between_rpe_lines=torch.tensor([0])) - trajectory1 = trajectory1_calculator(valid_rpe_kheader) +def test_KTrajectoryRpe_angle(): + """Test that every second line matches the first half of lines of a trajectory with double the angular gap.""" + n_k0 = 100 + n_k1 = 20 + n_k2 = 10 + k1_idx = torch.arange(n_k1)[None, None, :, None] + k2_idx = torch.arange(n_k2)[None, :, None, None] + trajectory1_calculator = KTrajectoryRpe(angle=torch.pi / n_k1, shift_between_rpe_lines=(0,)) + trajectory1 = trajectory1_calculator( + n_k0=n_k0, + k0_center=n_k0 // 2, + k1_idx=k1_idx, + k1_center=n_k1 // 2, + k2_idx=k2_idx, + ) # Calculate trajectory with half the angular gap such that every second line should be the same as above trajectory2_calculator = KTrajectoryRpe( - angle=torch.pi / (2 * n_rpe_lines), - shift_between_rpe_lines=torch.tensor([0]), + angle=torch.pi / (2 * n_k1), + shift_between_rpe_lines=torch.tensor([0, 0, 0, 0]), ) - trajectory2 = trajectory2_calculator(valid_rpe_kheader) - - torch.testing.assert_close(trajectory1.kx[:, : n_rpe_lines // 2, :, :], trajectory2.kx[:, ::2, :, :]) - torch.testing.assert_close(trajectory1.ky[:, : n_rpe_lines // 2, :, :], trajectory2.ky[:, ::2, :, :]) - torch.testing.assert_close(trajectory1.kz[:, : n_rpe_lines // 2, :, :], trajectory2.kz[:, ::2, :, :]) - - -def test_KTrajectoryRpe_shift(valid_rpe_kheader): - """Evaluate radial shifts for RPE trajectory.""" - trajectory1_calculator = KTrajectoryRpe(angle=torch.pi * 0.618034, shift_between_rpe_lines=torch.tensor([0.25])) - trajectory1 = trajectory1_calculator(valid_rpe_kheader) - trajectory2_calculator = KTrajectoryRpe( - angle=torch.pi * 0.618034, - shift_between_rpe_lines=torch.tensor([0.25, 0.25, 0.25, 0.25]), + trajectory2 = trajectory2_calculator( + n_k0=n_k0, + k0_center=n_k0 // 2, + k1_idx=k1_idx, + k1_center=n_k1 // 2, + k2_idx=k2_idx, ) - trajectory2 = trajectory2_calculator(valid_rpe_kheader) - torch.testing.assert_close(trajectory1.as_tensor(), trajectory2.as_tensor()) - -def test_KTrajectorySunflowerGoldenRpe(valid_rpe_kheader): - """Calculate RPE Sunflower trajectory.""" - trajectory_calculator = KTrajectorySunflowerGoldenRpe(rad_us_factor=2) - trajectory = trajectory_calculator(valid_rpe_kheader) - assert trajectory.broadcasted_shape == np.broadcast_shapes(*rpe_traj_shape(valid_rpe_kheader)) + torch.testing.assert_close(trajectory1.kx[:, : n_k1 // 2, :, :], trajectory2.kx[:, ::2, :, :]) + torch.testing.assert_close(trajectory1.ky[:, : n_k1 // 2, :, :], trajectory2.ky[:, ::2, :, :]) + torch.testing.assert_close(trajectory1.kz[:, : n_k1 // 2, :, :], trajectory2.kz[:, ::2, :, :]) -@pytest.fixture -def valid_cartesian_kheader(monkeypatch, random_kheader): - """KHeader with all necessary parameters for Cartesian trajectories.""" - # K-space dimensions - n_k0 = 200 +def test_KTrajectorySunflowerGoldenRpe(): + """Test shape returned by KTrajectorySunflowerGoldenRpe""" + n_k0 = 100 n_k1 = 20 n_k2 = 10 - n_other = 2 - - # List of k1 and k2 indices in the shape (other, k2, k1) - k1 = torch.linspace(0, n_k1 - 1, n_k1, dtype=torch.int32) - k2 = torch.linspace(0, n_k2 - 1, n_k2, dtype=torch.int32) - idx_k1, idx_k2 = torch.meshgrid(k1, k2, indexing='xy') - idx_k1 = repeat(torch.reshape(idx_k1, (n_k2, n_k1)), 'k2 k1->other k2 k1', other=n_other) - idx_k2 = repeat(torch.reshape(idx_k2, (n_k2, n_k1)), 'k2 k1->other k2 k1', other=n_other) - - # Set parameters for Cartesian trajectory (AcqInfo is of shape (other k2 k1 dim=1 or 3)) - monkeypatch.setattr(random_kheader.acq_info, 'number_of_samples', torch.zeros_like(idx_k1)[..., None] + n_k0) - monkeypatch.setattr(random_kheader.acq_info, 'center_sample', torch.zeros_like(idx_k1)[..., None] + n_k0 // 2) - monkeypatch.setattr(random_kheader.acq_info, 'flags', torch.zeros_like(idx_k1)[..., None]) - monkeypatch.setattr(random_kheader.acq_info.idx, 'k1', idx_k1) - monkeypatch.setattr(random_kheader.acq_info.idx, 'k2', idx_k2) - monkeypatch.setattr(random_kheader.encoding_limits.k1, 'center', int(n_k1 // 2)) - monkeypatch.setattr(random_kheader.encoding_limits.k1, 'max', int(n_k1 - 1)) - monkeypatch.setattr(random_kheader.encoding_limits.k2, 'center', int(n_k2 // 2)) - monkeypatch.setattr(random_kheader.encoding_limits.k2, 'max', int(n_k2 - 1)) - return random_kheader - - -def cartesian_traj_shape(valid_cartesian_kheader): - """Expected shape of trajectory based on KHeader.""" - n_k0 = valid_cartesian_kheader.acq_info.number_of_samples[0, 0, 0] - n_k1 = valid_cartesian_kheader.acq_info.idx.k1.shape[2] - n_k2 = valid_cartesian_kheader.acq_info.idx.k1.shape[1] - n_other = 1 # trajectory along other is the same - return (torch.Size([n_other, n_k2, 1, 1]), torch.Size([n_other, 1, n_k1, 1]), torch.Size([n_other, 1, 1, n_k0])) + k1_idx = torch.arange(n_k1)[None, None, :, None] + k2_idx = torch.arange(n_k2)[None, :, None, None] + trajectory_calculator = KTrajectorySunflowerGoldenRpe() + trajectory = trajectory_calculator( + n_k0=n_k0, k0_center=n_k0 // 2, k1_idx=k1_idx, k1_center=n_k1 // 2, k2_idx=k2_idx + ) + assert trajectory.broadcasted_shape == (1, n_k2, n_k1, n_k0) def test_KTrajectoryCartesian(valid_cartesian_kheader): """Calculate Cartesian trajectory.""" + n_k0 = 30 + n_k1 = 20 + n_k2 = 10 trajectory_calculator = KTrajectoryCartesian() - trajectory = trajectory_calculator(valid_cartesian_kheader) - valid_shape = cartesian_traj_shape(valid_cartesian_kheader) - assert trajectory.kz.shape == valid_shape[0] - assert trajectory.ky.shape == valid_shape[1] - assert trajectory.kx.shape == valid_shape[2] - - -@pytest.fixture -def valid_cartesian_kheader_bipolar(monkeypatch, valid_cartesian_kheader): - """Set readout of other==1 to reversed.""" - acq_info_flags = valid_cartesian_kheader.acq_info.flags - acq_info_flags[1, ...] = AcqFlags.ACQ_IS_REVERSE.value - monkeypatch.setattr(valid_cartesian_kheader.acq_info, 'flags', acq_info_flags) - return valid_cartesian_kheader + trajectory = trajectory_calculator( + n_k0=n_k0, + k0_center=n_k0 // 2, + k1_idx=torch.arange(n_k1)[None, None, :, None], + k1_center=n_k1 // 2, + k2_idx=torch.arange(n_k2)[None, :, None, None], + k2_center=n_k2 // 2, + ) + assert trajectory.kz.shape == (1, n_k2, 1, 1) + assert trajectory.ky.shape == (1, 1, n_k1, 1) + assert trajectory.kx.shape == (1, 1, 1, n_k0) def test_KTrajectoryCartesian_bipolar(valid_cartesian_kheader_bipolar): - """Calculate Cartesian trajectory with bipolar readout.""" + """Verify that the readout for the second part of a bipolar readout is reversed""" + trajectory_calculator = KTrajectoryCartesian() + n_k0 = 30 + n_k1 = 20 + n_k2 = 10 + reversed_readout_mask = torch.zeros(n_k1, 1, dtype=torch.bool) + reversed_readout_mask[1] = True trajectory_calculator = KTrajectoryCartesian() - trajectory = trajectory_calculator(valid_cartesian_kheader_bipolar) - # Verify that the readout for the second part of the bipolar readout is reversed - torch.testing.assert_close(trajectory.kx[0, ...], torch.flip(trajectory.kx[1, ...], dims=(-1,))) + trajectory = trajectory_calculator( + n_k0=n_k0, + k0_center=n_k0 // 2, + k1_idx=torch.arange(n_k1)[None, None, :, None], + k1_center=n_k1 // 2, + k2_idx=torch.arange(n_k2)[None, :, None, None], + k2_center=n_k2 // 2, + reversed_readout_mask=reversed_readout_mask, + ) + torch.testing.assert_close(trajectory.kx[..., 0, :], torch.flip(trajectory.kx[..., 1, :], dims=(-1,))) @pytest.fixture(scope='session') @@ -252,13 +176,12 @@ def pulseq_example_rad_seq(tmp_path_factory): return seq -def test_KTrajectoryPulseq_validseq_random_header(pulseq_example_rad_seq, valid_rad2d_kheader): +def test_KTrajectoryPulseq(pulseq_example_rad_seq, valid_rad2d_kheader): """Test pulseq File reader with valid seq File.""" - # TODO: Test with valid header # TODO: Test with invalid seq file trajectory_calculator = KTrajectoryPulseq(seq_path=pulseq_example_rad_seq.seq_filename) - trajectory = trajectory_calculator(kheader=valid_rad2d_kheader) + trajectory = trajectory_calculator(n_k0=n_k0, encoding_matrix=encoding_matrix) kx_test = pulseq_example_rad_seq.traj_analytical.kx.squeeze(0).squeeze(0) kx_test *= valid_rad2d_kheader.encoding_matrix.x / (2 * torch.max(torch.abs(kx_test))) @@ -268,3 +191,50 @@ def test_KTrajectoryPulseq_validseq_random_header(pulseq_example_rad_seq, valid_ torch.testing.assert_close(trajectory.kx.to(torch.float32), kx_test.to(torch.float32), atol=1e-2, rtol=1e-3) torch.testing.assert_close(trajectory.ky.to(torch.float32), ky_test.to(torch.float32), atol=1e-2, rtol=1e-3) + + +@pytest.fixture +def valid_cartesian_kheader(monkeypatch, random_kheader): + """KHeader with all necessary parameters for Cartesian trajectories.""" + # K-space dimensions + n_k0 = 200 + n_k1 = 20 + n_k2 = 10 + n_other = 2 + + # List of k1 and k2 indices in the shape (other, k2, k1) + k1 = torch.linspace(0, n_k1 - 1, n_k1, dtype=torch.int32) + k2 = torch.linspace(0, n_k2 - 1, n_k2, dtype=torch.int32) + idx_k1, idx_k2 = torch.meshgrid(k1, k2, indexing='xy') + idx_k1 = repeat(torch.reshape(idx_k1, (n_k2, n_k1)), 'k2 k1->other k2 k1', other=n_other) + idx_k2 = repeat(torch.reshape(idx_k2, (n_k2, n_k1)), 'k2 k1->other k2 k1', other=n_other) + + # Set parameters for Cartesian trajectory (AcqInfo is of shape (other k2 k1 dim=1 or 3)) + monkeypatch.setattr(random_kheader.acq_info, 'number_of_samples', torch.zeros_like(idx_k1)[..., None] + n_k0) + monkeypatch.setattr(random_kheader.acq_info, 'center_sample', torch.zeros_like(idx_k1)[..., None] + n_k0 // 2) + monkeypatch.setattr(random_kheader.acq_info, 'flags', torch.zeros_like(idx_k1)[..., None]) + monkeypatch.setattr(random_kheader.acq_info.idx, 'k1', idx_k1) + monkeypatch.setattr(random_kheader.acq_info.idx, 'k2', idx_k2) + monkeypatch.setattr(random_kheader.encoding_limits.k1, 'center', int(n_k1 // 2)) + monkeypatch.setattr(random_kheader.encoding_limits.k1, 'max', int(n_k1 - 1)) + monkeypatch.setattr(random_kheader.encoding_limits.k2, 'center', int(n_k2 // 2)) + monkeypatch.setattr(random_kheader.encoding_limits.k2, 'max', int(n_k2 - 1)) + return random_kheader + + +def cartesian_traj_shape(valid_cartesian_kheader): + """Expected shape of trajectory based on KHeader.""" + n_k0 = valid_cartesian_kheader.acq_info.number_of_samples[0, 0, 0] + n_k1 = valid_cartesian_kheader.acq_info.idx.k1.shape[2] + n_k2 = valid_cartesian_kheader.acq_info.idx.k1.shape[1] + n_other = 1 # trajectory along other is the same + return + + +@pytest.fixture +def valid_cartesian_kheader_bipolar(monkeypatch, valid_cartesian_kheader): + """Set readout of other==1 to reversed.""" + acq_info_flags = valid_cartesian_kheader.acq_info.flags + acq_info_flags[1, ...] = AcqFlags.ACQ_IS_REVERSE.value + monkeypatch.setattr(valid_cartesian_kheader.acq_info, 'flags', acq_info_flags) + return valid_cartesian_kheader