Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move functions from mixins into KData #559

Merged
merged 10 commits into from
Jan 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/mrpro/algorithms/prewhiten_kspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch
from einops import einsum, parse_shape, rearrange

from mrpro.data._kdata.KData import KData
from mrpro.data.KData import KData
from mrpro.data.KNoise import KNoise


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
from collections.abc import Callable

from mrpro.algorithms.reconstruction.Reconstruction import Reconstruction
from mrpro.data._kdata.KData import KData
from mrpro.data.CsmData import CsmData
from mrpro.data.DcfData import DcfData
from mrpro.data.IData import IData
from mrpro.data.KData import KData
from mrpro.data.KNoise import KNoise
from mrpro.operators.FourierOp import FourierOp
from mrpro.operators.LinearOperator import LinearOperator
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
from mrpro.algorithms.reconstruction.RegularizedIterativeSENSEReconstruction import (
RegularizedIterativeSENSEReconstruction,
)
from mrpro.data._kdata.KData import KData
from mrpro.data.CsmData import CsmData
from mrpro.data.DcfData import DcfData
from mrpro.data.KData import KData
from mrpro.data.KNoise import KNoise
from mrpro.operators.LinearOperator import LinearOperator

Expand Down
2 changes: 1 addition & 1 deletion src/mrpro/algorithms/reconstruction/Reconstruction.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
from typing_extensions import Self

from mrpro.algorithms.prewhiten_kspace import prewhiten_kspace
from mrpro.data._kdata.KData import KData
from mrpro.data.CsmData import CsmData
from mrpro.data.DcfData import DcfData
from mrpro.data.IData import IData
from mrpro.data.KData import KData
from mrpro.data.KNoise import KNoise
from mrpro.operators.FourierOp import FourierOp
from mrpro.operators.LinearOperator import LinearOperator
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
from mrpro.algorithms.optimizers.cg import cg
from mrpro.algorithms.prewhiten_kspace import prewhiten_kspace
from mrpro.algorithms.reconstruction.DirectReconstruction import DirectReconstruction
from mrpro.data._kdata.KData import KData
from mrpro.data.CsmData import CsmData
from mrpro.data.DcfData import DcfData
from mrpro.data.IData import IData
from mrpro.data.KData import KData
from mrpro.data.KNoise import KNoise
from mrpro.operators.IdentityOp import IdentityOp
from mrpro.operators.LinearOperator import LinearOperator
Expand Down
300 changes: 293 additions & 7 deletions src/mrpro/data/_kdata/KData.py → src/mrpro/data/KData.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,21 @@
"""MR raw data / k-space data class."""

import copy
import dataclasses
import datetime
import warnings
from collections.abc import Callable, Sequence
from pathlib import Path
from types import EllipsisType
from typing import Literal, cast

import h5py
import ismrmrd
import numpy as np
import torch
from einops import rearrange
from typing_extensions import Self
from einops import rearrange, repeat
from typing_extensions import Self, TypeVar

from mrpro.data._kdata.KDataRearrangeMixin import KDataRearrangeMixin
from mrpro.data._kdata.KDataRemoveOsMixin import KDataRemoveOsMixin
from mrpro.data._kdata.KDataSelectMixin import KDataSelectMixin
from mrpro.data._kdata.KDataSplitMixin import KDataSplitMixin
from mrpro.data.acq_filters import has_n_coils, is_image_acquisition
from mrpro.data.AcqInfo import AcqInfo, rearrange_acq_info_fields
from mrpro.data.EncodingLimits import Limits
Expand All @@ -29,6 +27,8 @@
from mrpro.data.traj_calculators.KTrajectoryCalculator import KTrajectoryCalculator
from mrpro.data.traj_calculators.KTrajectoryIsmrmrd import KTrajectoryIsmrmrd

RotationOrTensor = TypeVar('RotationOrTensor', bound=torch.Tensor | Rotation)

KDIM_SORT_LABELS = (
'k1',
'k2',
Expand Down Expand Up @@ -63,7 +63,9 @@


@dataclasses.dataclass(slots=True, frozen=True)
class KData(KDataSplitMixin, KDataRearrangeMixin, KDataSelectMixin, KDataRemoveOsMixin, MoveDataMixin):
class KData(
MoveDataMixin,
):
"""MR raw data / k-space data class."""

header: KHeader
Expand Down Expand Up @@ -366,3 +368,287 @@ def compress_coils(
).permute(*np.argsort(permute_order))

return type(self)(self.header.clone(), kdata_coil_compressed, self.traj.clone())

def rearrange_k2_k1_into_k1(self: Self) -> Self:
"""Rearrange kdata from (... k2 k1 ...) to (... 1 (k2 k1) ...).

Parameters
----------
kdata
K-space data (other coils k2 k1 k0)

Returns
-------
K-space data (other coils 1 (k2 k1) k0)
"""
# Rearrange data
kdat = rearrange(self.data, '... coils k2 k1 k0->... coils 1 (k2 k1) k0')

# Rearrange trajectory
ktraj = rearrange(self.traj.as_tensor(), 'dim ... k2 k1 k0-> dim ... 1 (k2 k1) k0')

# Create new header with correct shape
kheader = copy.deepcopy(self.header)

# Update shape of acquisition info index
kheader.acq_info.apply_(
lambda field: rearrange_acq_info_fields(field, 'other k2 k1 ... -> other 1 (k2 k1) ...')
)

return type(self)(kheader, kdat, type(self.traj).from_tensor(ktraj))

def remove_readout_os(self: Self) -> Self:
"""Remove any oversampling along the readout (k0) direction [GAD]_.

Returns a copy of the data.

Parameters
----------
kdata
K-space data

Returns
-------
Copy of K-space data with oversampling removed.

Raises
------
ValueError
If the recon matrix along x is larger than the encoding matrix along x.

References
----------
.. [GAD] Gadgetron https://github.com/gadgetron/gadgetron-python
"""
from mrpro.operators.FastFourierOp import FastFourierOp
fzimmermann89 marked this conversation as resolved.
Show resolved Hide resolved

# Ratio of k0/x between encoded and recon space
x_ratio = self.header.recon_matrix.x / self.header.encoding_matrix.x
if x_ratio == 1:
# If the encoded and recon space is the same we don't have to do anything
return self
elif x_ratio > 1:
raise ValueError('Recon matrix along x should be equal or larger than encoding matrix along x.')

# Starting and end point of image after removing oversampling
start_cropped_readout = (self.header.encoding_matrix.x - self.header.recon_matrix.x) // 2
end_cropped_readout = start_cropped_readout + self.header.recon_matrix.x

def crop_readout(data_to_crop: torch.Tensor) -> torch.Tensor:
# returns a cropped copy
return data_to_crop[..., start_cropped_readout:end_cropped_readout].clone()

# Transform to image space along readout, crop to reconstruction matrix size and transform back
fourier_k0_op = FastFourierOp(dim=(-1,))
(cropped_data,) = fourier_k0_op(crop_readout(*fourier_k0_op.H(self.data)))

# Adapt trajectory
ks = [self.traj.kz, self.traj.ky, self.traj.kx]
# only cropped ks that are not broadcasted/singleton along k0
cropped_ks = [crop_readout(k) if k.shape[-1] > 1 else k.clone() for k in ks]
cropped_traj = KTrajectory(cropped_ks[0], cropped_ks[1], cropped_ks[2])

# Adapt header parameters
header = copy.deepcopy(self.header)
header.acq_info.center_sample -= start_cropped_readout
header.acq_info.number_of_samples[:] = cropped_data.shape[-1]
header.encoding_matrix.x = cropped_data.shape[-1]

header.acq_info.discard_post = (header.acq_info.discard_post * x_ratio).to(torch.int32)
header.acq_info.discard_pre = (header.acq_info.discard_pre * x_ratio).to(torch.int32)

return type(self)(header, cropped_data, cropped_traj)

def select_other_subset(
self: Self,
subset_idx: torch.Tensor,
subset_label: Literal['average', 'slice', 'contrast', 'phase', 'repetition', 'set'],
) -> Self:
"""Select a subset from the other dimension of KData.

Parameters
----------
kdata
K-space data (other coils k2 k1 k0)
subset_idx
Index which elements of the other subset to use, e.g. phase 0,1,2 and 5
subset_label
Name of the other label, e.g. phase

Returns
-------
K-space data (other_subset coils k2 k1 k0)

Raises
------
ValueError
If the subset indices are not available in the data
"""
# Make a copy such that the original kdata.header remains the same
kheader = copy.deepcopy(self.header)
ktraj = self.traj.as_tensor()

# Verify that the subset_idx is available
label_idx = getattr(kheader.acq_info.idx, subset_label)
if not all(el in torch.unique(label_idx) for el in subset_idx):
raise ValueError('Subset indices are outside of the available index range')

# Find subset index in acq_info index
other_idx = torch.cat([torch.where(idx == label_idx[:, 0, 0])[0] for idx in subset_idx], dim=0)

# Adapt header
kheader.acq_info.apply_(
lambda field: field[other_idx, ...] if isinstance(field, torch.Tensor | Rotation) else field
)

# Select data
kdat = self.data[other_idx, ...]

# Select ktraj
if ktraj.shape[1] > 1:
ktraj = ktraj[:, other_idx, ...]

return type(self)(kheader, kdat, type(self.traj).from_tensor(ktraj))

def _split_k2_or_k1_into_other(
self,
split_idx: torch.Tensor,
other_label: Literal['average', 'slice', 'contrast', 'phase', 'repetition', 'set'],
split_dir: Literal['k2', 'k1'],
) -> Self:
"""Based on an index tensor, split the data in e.g. phases.

Parameters
----------
split_idx
2D index describing the k2 or k1 points in each block to be moved to the other dimension
(other_split, k1_per_split) or (other_split, k2_per_split)
other_label
Label of other dimension, e.g. repetition, phase
split_dir
Dimension to split, either 'k1' or 'k2'

Returns
-------
K-space data with new shape
((other other_split) coils k2 k1_per_split k0) or ((other other_split) coils k2_per_split k1 k0)

Raises
------
ValueError
Already existing "other_label" can only be of length 1
"""
# Number of other
n_other = split_idx.shape[0]

# Verify that the specified label of the other dimension is unused
if getattr(self.header.encoding_limits, other_label).length > 1:
raise ValueError(f'{other_label} is already used to encode different parts of the scan.')

# Set-up splitting
if split_dir == 'k1':
# Split along k1 dimensions
def split_data_traj(dat_traj: torch.Tensor) -> torch.Tensor:
return dat_traj[:, :, :, split_idx, :]

def split_acq_info(acq_info: RotationOrTensor) -> RotationOrTensor:
# cast due to https://github.com/python/mypy/issues/10817
return cast(RotationOrTensor, acq_info[:, :, split_idx, ...])

# Rearrange other_split and k1 dimension
rearrange_pattern_data = 'other coils k2 other_split k1 k0->(other other_split) coils k2 k1 k0'
rearrange_pattern_traj = 'dim other k2 other_split k1 k0->dim (other other_split) k2 k1 k0'
rearrange_pattern_acq_info = 'other k2 other_split k1 ... -> (other other_split) k2 k1 ...'

elif split_dir == 'k2':
# Split along k2 dimensions
def split_data_traj(dat_traj: torch.Tensor) -> torch.Tensor:
return dat_traj[:, :, split_idx, :, :]

def split_acq_info(acq_info: RotationOrTensor) -> RotationOrTensor:
return cast(RotationOrTensor, acq_info[:, split_idx, ...])

# Rearrange other_split and k1 dimension
rearrange_pattern_data = 'other coils other_split k2 k1 k0->(other other_split) coils k2 k1 k0'
rearrange_pattern_traj = 'dim other other_split k2 k1 k0->dim (other other_split) k2 k1 k0'
rearrange_pattern_acq_info = 'other other_split k2 k1 ... -> (other other_split) k2 k1 ...'

else:
raise ValueError('split_dir has to be "k1" or "k2"')

# Split data
kdat = rearrange(split_data_traj(self.data), rearrange_pattern_data)

# First we need to make sure the other dimension is the same as data then we can split the trajectory
ktraj = self.traj.as_tensor()
# Verify that other dimension of trajectory is 1 or matches data
if ktraj.shape[1] > 1 and ktraj.shape[1] != self.data.shape[0]:
raise ValueError(f'other dimension of trajectory has to be 1 or match data ({self.data.shape[0]})')
elif ktraj.shape[1] == 1 and self.data.shape[0] > 1:
ktraj = repeat(ktraj, 'dim other k2 k1 k0->dim (other_data other) k2 k1 k0', other_data=self.data.shape[0])
ktraj = rearrange(split_data_traj(ktraj), rearrange_pattern_traj)

# Create new header with correct shape
kheader = self.header.clone()

# Update shape of acquisition info index
kheader.acq_info.apply_(
lambda field: rearrange_acq_info_fields(split_acq_info(field), rearrange_pattern_acq_info)
if isinstance(field, Rotation | torch.Tensor)
else field
)

# Update other label limits and acquisition info
setattr(kheader.encoding_limits, other_label, Limits(min=0, max=n_other - 1, center=0))

# acq_info for new other dimensions
acq_info_other_split = repeat(
torch.linspace(0, n_other - 1, n_other), 'other-> other k2 k1', k2=kdat.shape[-3], k1=kdat.shape[-2]
)
setattr(kheader.acq_info.idx, other_label, acq_info_other_split)

return type(self)(kheader, kdat, type(self.traj).from_tensor(ktraj))

def split_k1_into_other(
self: Self,
split_idx: torch.Tensor,
other_label: Literal['average', 'slice', 'contrast', 'phase', 'repetition', 'set'],
) -> Self:
"""Based on an index tensor, split the data in e.g. phases.

Parameters
----------
kdata
K-space data (other coils k2 k1 k0)
split_idx
2D index describing the k1 points in each block to be moved to other dimension (other_split, k1_per_split)
other_label
Label of other dimension, e.g. repetition, phase

Returns
-------
K-space data with new shape ((other other_split) coils k2 k1_per_split k0)
"""
return self._split_k2_or_k1_into_other(split_idx, other_label, split_dir='k1')

def split_k2_into_other(
self: Self,
split_idx: torch.Tensor,
other_label: Literal['average', 'slice', 'contrast', 'phase', 'repetition', 'set'],
) -> Self:
"""Based on an index tensor, split the data in e.g. phases.

Parameters
----------
kdata
K-space data (other coils k2 k1 k0)
split_idx
2D index describing the k2 points in each block to be moved to other dimension (other_split, k2_per_split)
other_label
Label of other dimension, e.g. repetition, phase

Returns
-------
K-space data with new shape ((other other_split) coils k2_per_split k1 k0)
"""
return self._split_k2_or_k1_into_other(split_idx, other_label, split_dir='k2')
Loading
Loading