Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] 5909 add PadOrCropListDataCollate #5933

Draft
wants to merge 2 commits into
base: dev
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
SpatialCrop,
SpatialPad,
)
from .croppad.batch import PadListDataCollate
from .croppad.batch import PadListDataCollate, PadOrCropListDataCollate
from .croppad.dictionary import (
BorderPadd,
BorderPadD,
Expand Down
120 changes: 118 additions & 2 deletions monai/transforms/croppad/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@
from monai.data.utils import list_data_collate
from monai.transforms.croppad.array import CenterSpatialCrop, SpatialPad
from monai.transforms.inverse import InvertibleTransform
from monai.utils.enums import Method, PytorchPadMode, TraceKeys
from monai.utils.enums import Method, PytorchPadMode, TraceKeys, DataCollateMode

__all__ = ["PadListDataCollate"]
__all__ = ["PadListDataCollate", "PadOrCropListDataCollate"]


def replace_element(to_replace, batch, idx, key_or_idx):
Expand Down Expand Up @@ -136,3 +136,119 @@ def inverse(data: dict) -> dict[Hashable, np.ndarray]:
with cropping.trace_transform(False):
d[key] = cropping(d[key]) # fallback to image size
return d

class PadOrCropListDataCollate(InvertibleTransform):
    """
    This class enhances `PadListDataCollate` via supporting pad (to maximal sizes), crop (to minimal sizes) and
    resize (by pad or crop) to specified sizes.
    This transform is useful if some of the applied transforms generate batch data of
    different sizes.

    This can be used on both list and dictionary data.
    Note that in the case of the dictionary data, it may add the transform information to the list of invertible
    transforms if the input batch has different spatial shapes, so the static method `inverse` needs to be called
    before inverting other transforms.

    Note that normally, a user won't explicitly use the `__call__` method. Rather this would be passed to the
    `DataLoader`. This means that `__call__` handles data as it comes out of a `DataLoader`, containing the batch
    dimension. However, the `inverse` operates on dictionaries containing images of shape `C,H,W,[D]`. This
    asymmetry is necessary so that we can pass the inverse through multiprocessing.

    Args:
        mode: available modes: {``"pad"``, ``"crop"``, ``"resize"``}.
        spatial_size: the spatial size of output data after padding or crop, only used when ``mode="resize"``.
            If has non-positive values, the corresponding size of input image will be used (no padding).
        pad_method: padding method (see :py:class:`monai.transforms.SpatialPad`), used by the
            ``"pad"`` and ``"resize"`` modes.
        pad_mode: padding mode (see :py:class:`monai.transforms.SpatialPad`), used by the
            ``"pad"`` and ``"resize"`` modes.
        pad_kwargs: other arguments for the `np.pad` or `torch.pad` function.
            note that `np.pad` treats channel dimension as the first dimension.

    Raises:
        ValueError: when ``mode`` is not one of ``"pad"``, ``"crop"`` or ``"resize"``.

    """

    def __init__(
        self,
        mode: str = DataCollateMode.PAD,
        spatial_size: Sequence[int] | int = -1,
        pad_method: str = Method.SYMMETRIC,
        pad_mode: str = PytorchPadMode.CONSTANT,
        **pad_kwargs,
    ) -> None:
        self.mode = mode
        if self.mode == DataCollateMode.RESIZE:
            # local import: ResizeWithPadOrCrop is not among this module's visible top-level imports.
            from monai.transforms.croppad.array import ResizeWithPadOrCrop

            self.resizer = ResizeWithPadOrCrop(spatial_size=spatial_size, method=pad_method, mode=pad_mode, **pad_kwargs)
        elif self.mode == DataCollateMode.PAD:
            self.pad_method = pad_method
            self.pad_mode = pad_mode
            self.pad_kwargs = pad_kwargs
        elif self.mode != DataCollateMode.CROP:
            raise ValueError(f"mode should be 'pad', 'crop' or 'resize', got {self.mode}.")

    def __call__(self, batch: Any):
        """
        Harmonize the spatial sizes of array-like items across the batch, then collate.

        Args:
            batch: batch of data to pad/crop-collate, either a list of dicts or a list of lists/tuples.

        Returns:
            the collated batch produced by :py:func:`monai.data.utils.list_data_collate`.
        """
        # data is either list of dicts or list of lists
        is_list_of_dicts = isinstance(batch[0], dict)
        # loop over items inside of each element in a batch
        batch_item = tuple(batch[0].keys()) if is_list_of_dicts else range(len(batch[0]))
        for key_or_idx in batch_item:
            # collect the spatial shapes (all dims after the channel dim) of this key across the batch
            shapes = []
            for elem in batch:
                if not isinstance(elem[key_or_idx], (torch.Tensor, np.ndarray)):
                    break
                shapes.append(elem[key_or_idx].shape[1:])
            # len > 0 if objects were arrays, else skip as no padding to be done
            if not shapes:
                continue
            if self.mode == DataCollateMode.RESIZE:
                # resize mode always applies, even when the shapes already match the target
                transform = self.resizer
            else:
                # calculate max and min size of each dimension
                max_shape, min_shape = np.array(shapes).max(axis=0), np.array(shapes).min(axis=0)
                # If all same size, skip
                if np.all(min_shape == max_shape):
                    continue
                if self.mode == DataCollateMode.PAD:
                    transform = SpatialPad(
                        spatial_size=max_shape, method=self.pad_method, mode=self.pad_mode, **self.pad_kwargs
                    )
                else:
                    transform = CenterSpatialCrop(roi_size=min_shape)
            for idx, batch_i in enumerate(batch):
                orig_size = batch_i[key_or_idx].shape[1:]
                to_replace = transform(batch_i[key_or_idx])
                batch = replace_element(to_replace, batch, idx, key_or_idx)

                # If we have a dictionary of data, append to list
                # transform info is re-added with self.push_transform to ensure one info dict per transform.
                if is_list_of_dicts:
                    self.push_transform(
                        batch[idx],
                        key_or_idx,
                        orig_size=orig_size,
                        extra_info=self.pop_transform(batch[idx], key_or_idx, check=False),
                    )

        # After padding/cropping, use default list collator
        return list_data_collate(batch)

    @staticmethod
    def inverse(data: dict) -> dict[Hashable, np.ndarray]:
        """
        Invert the most recent `PadOrCropListDataCollate` operation recorded on each key of ``data``.

        NOTE(review): the inverse center-crops back to the recorded original size, which undoes padding but
        cannot restore content removed by the ``"crop"`` mode — confirm this limitation is acceptable.

        Args:
            data: a dictionary of per-item data (images of shape `C,H,W,[D]`), as produced by decollating a batch.

        Raises:
            RuntimeError: when ``data`` is not a mapping.
        """
        if not isinstance(data, Mapping):
            raise RuntimeError("Inverse can only currently be applied on dictionaries.")

        d = dict(data)
        for key in d:
            transforms = None
            if isinstance(d[key], MetaTensor):
                transforms = d[key].applied_operations
            else:
                transform_key = InvertibleTransform.trace_key(key)
                if transform_key in d:
                    transforms = d[transform_key]
            if not transforms or not isinstance(transforms[-1], dict):
                continue
            # only pop/invert if the most recent recorded transform is this collate transform
            if transforms[-1].get(TraceKeys.CLASS_NAME) == PadOrCropListDataCollate.__name__:
                xform = transforms.pop()
                cropping = CenterSpatialCrop(xform.get(TraceKeys.ORIG_SIZE, -1))
                with cropping.trace_transform(False):
                    d[key] = cropping(d[key])  # fallback to image size
        return d
1 change: 1 addition & 0 deletions monai/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
TraceKeys,
TransformBackends,
UpsampleMode,
DataCollateMode,
Weight,
WSIPatchKeys,
)
Expand Down
11 changes: 11 additions & 0 deletions monai/utils/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
"SplineMode",
"InterpolateMode",
"UpsampleMode",
"DataCollateMode",
"BlendMode",
"PytorchPadMode",
"NdimageMode",
Expand Down Expand Up @@ -176,6 +177,16 @@ class UpsampleMode(StrEnum):
PIXELSHUFFLE = "pixelshuffle"


class DataCollateMode(StrEnum):
    """
    Strategy used to make batch items the same spatial size before collating:

    - ``PAD``: pad every item up to the per-dimension maximal size in the batch.
    - ``CROP``: center-crop every item down to the per-dimension minimal size in the batch.
    - ``RESIZE``: pad or crop every item to a user-specified spatial size.

    See also: :py:class:`monai.transforms.PadOrCropListDataCollate`
    """

    PAD = "pad"
    CROP = "crop"
    RESIZE = "resize"


class BlendMode(StrEnum):
"""
See also: :py:class:`monai.data.utils.compute_importance_map`
Expand Down