diff --git a/docs/images/lazy_resampling_apply_pending_example.svg b/docs/images/lazy_resampling_apply_pending_example.svg
new file mode 100644
index 0000000000..2e23a10e82
--- /dev/null
+++ b/docs/images/lazy_resampling_apply_pending_example.svg
@@ -0,0 +1 @@
+
diff --git a/docs/images/lazy_resampling_homogeneous_matrices.svg b/docs/images/lazy_resampling_homogeneous_matrices.svg
new file mode 100644
index 0000000000..2a20b120da
--- /dev/null
+++ b/docs/images/lazy_resampling_homogeneous_matrices.svg
@@ -0,0 +1 @@
+
diff --git a/docs/images/lazy_resampling_lazy_example_1.svg b/docs/images/lazy_resampling_lazy_example_1.svg
new file mode 100644
index 0000000000..6554635809
--- /dev/null
+++ b/docs/images/lazy_resampling_lazy_example_1.svg
@@ -0,0 +1 @@
+
diff --git a/docs/images/lazy_resampling_none_example.svg b/docs/images/lazy_resampling_none_example.svg
new file mode 100644
index 0000000000..ca83fa5449
--- /dev/null
+++ b/docs/images/lazy_resampling_none_example.svg
@@ -0,0 +1 @@
+
diff --git a/docs/images/lazy_resampling_trad_example_1.svg b/docs/images/lazy_resampling_trad_example_1.svg
new file mode 100644
index 0000000000..5d29bb08f2
--- /dev/null
+++ b/docs/images/lazy_resampling_trad_example_1.svg
@@ -0,0 +1 @@
+
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 9f8c3cb7ec..2af6c6f6f9 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -69,6 +69,7 @@ Technical documentation is available at `docs.monai.io <https://docs.monai.io>`_
   :caption: Specifications

   bundle_intro
+  lazy_resampling

Model Zoo
---------
diff --git a/docs/source/lazy_resampling.rst b/docs/source/lazy_resampling.rst
new file mode 100644
index 0000000000..7b809965f3
--- /dev/null
+++ b/docs/source/lazy_resampling.rst
@@ -0,0 +1,273 @@
+.. _lazy_resampling:
+
+:github_url: https://github.com/Project-MONAI/MONAI
+
+Lazy Resampling
+===============
+
+.. toctree::
+   :maxdepth: 2
+
+Introduction
+^^^^^^^^^^^^
+
+Lazy Resampling is a new feature introduced in MONAI 1.2. This feature is still experimental at this time and it is
+possible that behaviour and APIs will change in upcoming releases.
+
+Lazy resampling reworks the way that preprocessing is performed. It can provide significant benefits over
+traditional preprocessing pipelines, improving:
+
+* pipeline execution time
+* pipeline memory usage on CPU or GPU
+* image and segmentation quality, by reducing incidental noise and artifacts caused by resampling
+
+It does this by adopting the methods used in computer graphics pipelines, in which transformations to objects
+in a scene are combined by composing together a sequence of "homogeneous matrices".
+
+Rather than executing each transform in isolation, potentially requiring the data to be resampled to make a new
+tensor each time, transforms whose operations can be described in terms of homogeneous transforms do not execute
+immediately. Instead, they create a "pending operation", which is added to a list of operations that will be fused
+together and carried out at the point that they are required.
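+
+As a brief illustration, here is a minimal sketch of enabling lazy resampling for a whole pipeline by passing
+``lazy=True`` to ``Compose`` (the transform arguments shown are placeholder values for illustration only)::
+
+    from monai.transforms import Compose, Orientationd, RandRotated, Spacingd
+
+    pipeline = Compose(
+        [
+            Spacingd(keys=["img", "seg"], pixdim=(1.0, 1.0, 1.0)),
+            Orientationd(keys=["img", "seg"], axcodes="RAS"),
+            RandRotated(keys=["img", "seg"], prob=0.5, range_x=0.5),
+        ],
+        lazy=True,
+    )
+
+    # each spatial transform adds a pending operation; the operations are
+    # fused into a single resample when the data is actually required
+    # ('inputs' is a data dictionary prepared elsewhere)
+    outputs = pipeline(inputs)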
+
+
+How Lazy Resampling changes preprocessing
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+In order to understand the difference between traditional pipelines and lazy pipelines, it is best to look at an example
+pipeline and compare the two execution strategies:
+
+
+Traditional execution
++++++++++++++++++++++
+
+With traditional resampling, found in MONAI and many other preprocessing libraries, you typically define a sequence
+of transforms and pass them to a ``Compose`` object, such as :class:`monai.transforms.compose.Compose`.
+
+Example::
+
+    transforms = [
+        Spacingd(keys=["img", "seg"], ...),
+        Orientationd(keys=["img", "seg"], ...),
+        RandSpatialCropd(keys=["img", "seg"], ...),
+        RandRotate90d(keys=["img", "seg"], ...),
+        RandRotated(keys=["img", "seg"], ...),
+        RandZoomd(keys=["img", "seg"], ...),
+        RandGaussianNoised(keys="img", ...),
+    ]
+    pipeline = Compose(transforms)
+
+    # elsewhere this will be called many times (such as in a Dataset instance)
+    outputs = pipeline(inputs)
+
+
+The following will then happen when we call ``pipeline(inputs)``:
+
+1. ``Spacingd`` is called and interpolates the data samples
+2. ``Orientationd`` permutes the data samples so that their spatial dimensions are reorganised
+3. ``RandSpatialCropd`` crops a random patch of the data samples, throwing away the rest of the data in the process
+4. ``RandRotate90d`` has a chance of performing a tensor-based rotation of the data samples
+5. ``RandRotated`` has a chance of performing a full resample of the data samples
+6. ``RandZoomd`` has a chance of performing an interpolation of the data samples
+7. ``RandGaussianNoised`` has a chance of adding noise to ``img``
+
+.. figure:: ../images/lazy_resampling_trad_example_1.svg
+
+   Figure showing traditional pipeline execution. Tensors (the boxes in the main body of the image) are passed through
+   the pipeline, and the state of their `applied_operations` property is shown at each step. Tensors with a thick red
+   border have undergone some kind of resample operation at that stage.
+
+Overall, there are up to three occasions where the data is either interpolated or resampled through spatial transforms
+(``Spacingd``, ``RandRotated`` and ``RandZoomd``). Furthermore, because ``RandSpatialCropd`` throws data away, the
+output data samples might show padding values in regions where the original data could have supplied real values.
+
+Each of these operations takes time and memory, but, as we can see in the example above, also creates resampling
+artifacts and can even destroy data in the resulting data samples.
+
+Lazy execution
+++++++++++++++
+
+Lazy resampling works very differently. When you execute the same pipeline with `lazy=True`, the following happens:
+
+#. ``Spacingd`` is executed lazily. It puts a description of the operation that it wants to perform onto a list of
+   pending operations
+#. ``Orientationd`` is executed lazily. It adds a description of its own operation to the pending operation list so
+   now there are 2 pending operations
+#. ``RandSpatialCropd`` is executed lazily. It adds a description of its own operation to the pending
+   operation list so now there are 3 pending operations
+#. ``RandRotate90d`` is executed lazily. It adds a description of its own operation to the pending operation
+   list so now there are 4 pending operations
+#. ``RandRotated`` is executed lazily. It adds a description of its own operation to the pending operation
+   list so now there are 5 pending operations
+#. 
``RandZoomd`` is executed lazily. It adds a description of its own operation to the pending operation
   list so now there are 6 pending operations
+
+   #. [Spacingd, Orientationd, RandSpatialCropd, RandRotate90d, RandRotated, RandZoomd] are all on the pending
+      operations list but have yet to be carried out on the data
+#. ``RandGaussianNoised`` is not a lazy transform. It is now time for the pending operations to be evaluated. Their
+   descriptions are mathematically composited together to determine the operation that results from all of them being
+   carried out. This is then applied in a single resample operation. Once that is done, ``RandGaussianNoised`` operates
+   on the resulting data
+
+.. figure:: ../images/lazy_resampling_lazy_example_1.svg
+
+   Figure showing lazy pipeline execution. We show the state of the `pending_operations` and `applied_operations`
+   properties of the tensor as it is processed by the pipeline. Thick red borders indicate some kind of resampling
+   operation has taken place at that step. Lazy resampling performs far fewer of these operations.
+
+The single resampling operation has less noise induced by resampling, as it only occurs once in this pipeline rather
+than three times in the traditional pipeline. More importantly, although the crop describes an operation to keep only a
+subset of the data sample, the crop is not performed until after the spatial transforms are completed, which means that
+all of the data sample that is within bounds is preserved and is part of the resulting output.
+
+
+Composing homogeneous matrices
+++++++++++++++++++++++++++++++
+
+.. image:: ../images/lazy_resampling_homogeneous_matrices.svg
+
+
+Although a full treatment of homogeneous matrices is outside the scope of this document, a brief overview of them is
+useful to understand the mechanics of lazy resampling. Homogeneous matrices are used in computer graphics to describe
+operations in cartesian space in a unified (homogeneous) fashion. Rotation, scaling, translation, and skewing are
+amongst the operations that can be performed. Homogeneous matrices have the interesting property that they can be
+composited together, thus describing the result of a sequence of operations. Note that ordering is important;
+`scale -> rotate -> translate` gives a very different result to `translate -> rotate -> scale`.
+
+The ability to composite homogeneous matrices together allows a sequence of operations to be carried out as a single
+operation, which is the key mechanism by which lazy resampling functions.
+
+
+API changes
+^^^^^^^^^^^
+
+A number of new arguments have been added to existing classes, which we'll go over in detail here. In particular,
+we'll focus on :class:`Compose` and :class:`LazyTrait`/:class:`LazyTransform`,
+and the way that they interact with each other.
+
+
+Compose
++++++++
+
+:class:`Compose` gains a number of new arguments that can be used to control
+resampling behaviour. Each of them is covered in its own section:
+
+
+lazy
+""""
+
+``lazy`` controls whether execution is carried out in a lazy manner or not. It can take one of three values:
+
+* `lazy=False` forces the pipeline to be executed in the standard way, with every transform applied immediately
+* `lazy=True` forces the pipeline to be executed lazily. Every transform that implements
+  :class:`LazyTrait` (or inherits
+  :class:`LazyTransform`) will be executed lazily
+* `lazy=None` means that the pipeline can execute lazily, but only on transforms that have their own `lazy` property
+  set to True. 
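+
+As a hedged sketch of how these three values interact (the transform arguments are placeholders, and we assume that
+``Spacingd`` and ``RandRotated`` are lazy transforms accepting the ``lazy`` argument described below)::
+
+    from monai.transforms import Compose, RandRotated, Spacingd
+
+    pipeline = Compose(
+        [
+            Spacingd(keys=["img"], pixdim=(1.0, 1.0, 1.0), lazy=True),
+            RandRotated(keys=["img"], prob=0.5, range_x=0.5, lazy=False),
+        ],
+        lazy=None,
+    )
+
+    # with lazy=None, Compose honors each transform's own flag: Spacingd is
+    # executed lazily, while RandRotated is applied immediately
+    outputs = pipeline(inputs)
+
+Passing ``lazy=True`` or ``lazy=False`` to ``Compose`` instead would override both of the per-transform flags.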
+
+
+overrides
+"""""""""
+
+``overrides`` allows the user to specify parameter overrides that are applied to transforms when they are
+executed lazily. This parameter is primarily provided so that you can run a pipeline lazily without having to modify
+fields like ``mode`` and ``padding_mode`` on the transforms themselves.
+When executing dictionary-based transforms, you provide a dictionary containing overrides for each key, as follows. You
+can omit keys that don't require overrides:
+
+.. code-block::
+
+    {
+        "image": {"mode": "bilinear"},
+        "label": {"padding_mode": "zeros"}
+    }
+
+
+log_stats
+"""""""""
+
+Logging of transform execution is provided if you wish to understand exactly how your pipelines execute. It can take a
+``bool`` or ``str`` value, and is False by default, which disables logging. Otherwise, you can enable it by passing it
+the name of a logger that you wish to use (note, you don't have to construct the logger beforehand).
+
+
+LazyTrait / LazyTransform
++++++++++++++++++++++++++
+
+Many transforms now implement either `LazyTrait` or
+`LazyTransform`. Doing so marks the transform for lazy execution. Lazy
+transforms have the following in common:
+
+
+``__init__`` has a ``lazy`` argument
+""""""""""""""""""""""""""""""""""""
+
+``lazy`` is a ``bool`` value that can be passed to the initialiser when a lazy transform is instantiated. It
+indicates whether the transform should execute lazily. Note that this value can be overridden at call time by
+passing ``lazy`` to ``__call__``. ``lazy`` is ``False`` by default.
+
+
+``__call__`` has a ``lazy`` argument
+""""""""""""""""""""""""""""""""""""
+
+``lazy`` is an optional ``bool`` value that can be passed at call time to override the behaviour defined during
+initialisation. It has a default value of ``None``. If it is not ``None``, then this value is used instead of
+``self.lazy``. This allows the calling :class:`Compose` instance to override
+default values rather than having to set it on every lazy transform (unless the user sets
+:class:`Compose.lazy` to ``None``).
+
+
+lazy property
+"""""""""""""
+
+The lazy property allows you to get or set the lazy status of a lazy transform after constructing it.
+
+
+requires_current_data property (get only)
+"""""""""""""""""""""""""""""""""""""""""
+
+The ``requires_current_data`` property indicates that a transform makes use of the data in one or more of the tensors
+that it is passed during its execution. Such transforms therefore require that the tensors be up to date, even if
+the transform itself is executing lazily. This is required for transforms such as ``CropForeground[d]``,
+``RandCropByPosNegLabel[d]``, and ``RandCropByLabelClasses[d]``. This property is implemented to return ``False`` on
+``LazyTransform`` and must be overridden to return ``True`` by transforms that check data values when executing.
+
+
+Controlling laziness
+^^^^^^^^^^^^^^^^^^^^
+
+There are two ways that a user can exercise more fine-grained control over laziness. One is to make use of ``lazy=None``
+when initialising or calling ``Compose`` instances. The other is to use the ``ApplyPending[d]`` transforms. These
+techniques can be freely mixed and matched.
+
+
+Using ``lazy=None``
++++++++++++++++++++
+
+``lazy=None`` tells ``Compose`` to honor the ``lazy`` flags set on each lazy transform. These are set to ``False`` by
+default, so the user must set ``lazy=True`` on the transforms that they wish to execute lazily.
+
+
+``lazy=None`` example:
+""""""""""""""""""""""
+
+.. 
figure:: ../images/lazy_resampling_none_example.svg
+
+   Figure showing the effect of using ``lazy=False`` when ``Compose`` is being executed with ``lazy=None``. Note
+   the additional resamples that occur due to ``RandRotate90d`` being executed in a non-lazy fashion.
+
+
+Using ``ApplyPending[d]``
++++++++++++++++++++++++++
+
+``ApplyPending[d]`` causes all pending operations to be applied before the following transform, regardless of whether
+the following transform is a lazy transform or is configured to execute lazily.
+
+
+``ApplyPending`` Example:
+"""""""""""""""""""""""""
+
+.. figure:: ../images/lazy_resampling_apply_pending_example.svg
+
+   Figure showing the use of :class:`ApplyPendingd` to cause
+   resampling to occur in the middle of a chain of lazy transforms.
diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst
index e045a7e741..fe17fa4efe 100644
--- a/docs/source/transforms.rst
+++ b/docs/source/transforms.rst
@@ -958,6 +958,17 @@ MRI Transforms
     :special-members: __call__


+Lazy
+^^^^
+
+`ApplyPending`
+""""""""""""""
+
+.. autoclass:: ApplyPending
+    :members:
+    :special-members: __call__
+
+
 Utility
 ^^^^^^^

@@ -1912,6 +1923,17 @@ Smooth Field (Dict)
     :special-members: __call__


+Lazy (Dict)
+^^^^^^^^^^^
+
+`ApplyPendingd`
+"""""""""""""""
+
+.. autoclass:: ApplyPendingd
+    :members:
+    :special-members: __call__
+
+
 Utility (Dict)
 ^^^^^^^^^^^^^^

@@ -2211,9 +2233,3 @@ Utilities

 .. automodule:: monai.transforms.utils_pytorch_numpy_unification
     :members:
-
-Lazy
-----
-.. automodule:: monai.transforms.lazy
-    :members:
-    :imported-members:
diff --git a/monai/apps/detection/transforms/dictionary.py b/monai/apps/detection/transforms/dictionary.py
index a692a42369..52b1a7d15d 100644
--- a/monai/apps/detection/transforms/dictionary.py
+++ b/monai/apps/detection/transforms/dictionary.py
@@ -43,7 +43,7 @@
 from monai.data.box_utils import COMPUTE_DTYPE, BoxMode, clip_boxes_to_image
 from monai.data.meta_tensor import MetaTensor, get_track_meta
 from monai.data.utils import orientation_ras_lps
-from monai.transforms import Flip, RandFlip, RandRotate90d, RandZoom, Rotate90, SpatialCrop, Zoom
+from monai.transforms import Flip, RandFlip, RandZoom, Rotate90, SpatialCrop, Zoom
 from monai.transforms.inverse import InvertibleTransform
 from monai.transforms.transform import MapTransform, Randomizable, RandomizableTransform
 from monai.transforms.utils import generate_pos_neg_label_crop_centers, map_binary_to_indices
@@ -1291,7 +1291,7 @@ def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch
         return d


-class RandRotateBox90d(RandRotate90d):
+class RandRotateBox90d(RandomizableTransform, MapTransform, InvertibleTransform):
     """
     With probability `prob`, input boxes and images are rotated by 90 degrees in the plane specified by
     `spatial_axes`. 
@@ -1323,7 +1323,13 @@ def __init__( ) -> None: self.image_keys = ensure_tuple(image_keys) self.box_keys = ensure_tuple(box_keys) - super().__init__(self.image_keys + self.box_keys, prob, max_k, spatial_axes, allow_missing_keys) + + MapTransform.__init__(self, self.image_keys + self.box_keys, allow_missing_keys) + RandomizableTransform.__init__(self, prob) + + self.max_k = max_k + self.spatial_axes = spatial_axes + self._rand_k = 0 self.box_ref_image_keys = ensure_tuple_rep(box_ref_image_keys, len(self.box_keys)) def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Mapping[Hashable, torch.Tensor]: @@ -1364,6 +1370,10 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Mapping[Hashable, t self.push_transform(d[key], extra_info=xform) return d + def randomize(self, data: Any | None = None) -> None: + self._rand_k = self.R.randint(self.max_k) + 1 + super().randomize(None) + def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: d = dict(data) if self._rand_k % 4 == 0: diff --git a/monai/apps/reconstruction/transforms/dictionary.py b/monai/apps/reconstruction/transforms/dictionary.py index 11454b0b6b..c166740768 100644 --- a/monai/apps/reconstruction/transforms/dictionary.py +++ b/monai/apps/reconstruction/transforms/dictionary.py @@ -20,8 +20,8 @@ from monai.apps.reconstruction.transforms.array import EquispacedKspaceMask, RandomKspaceMask from monai.config import DtypeLike, KeysCollection from monai.config.type_definitions import NdarrayOrTensor +from monai.transforms import InvertibleTransform from monai.transforms.croppad.array import SpatialCrop -from monai.transforms.croppad.dictionary import Cropd from monai.transforms.intensity.array import NormalizeIntensity from monai.transforms.transform import MapTransform, RandomizableTransform from monai.utils import FastMRIKeys @@ -190,7 +190,7 @@ def set_random_state( return self -class ReferenceBasedSpatialCropd(Cropd): +class ReferenceBasedSpatialCropd(MapTransform, InvertibleTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.SpatialCrop`. 
This is similar to :py:class:`monai.transforms.SpatialCropd` which is a @@ -213,7 +213,7 @@ class ReferenceBasedSpatialCropd(Cropd): """ def __init__(self, keys: KeysCollection, ref_key: str, allow_missing_keys: bool = False) -> None: - super().__init__(keys, cropper=None, allow_missing_keys=allow_missing_keys) # type: ignore + MapTransform.__init__(self, keys, allow_missing_keys) self.ref_key = ref_key def __call__(self, data: Mapping[Hashable, Tensor]) -> dict[Hashable, Tensor]: diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 75cbec5607..84817d17b0 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -230,7 +230,9 @@ from .inverse_batch_transform import BatchInverseTransform, Decollated, DecollateD, DecollateDict from .io.array import SUPPORTED_READERS, LoadImage, SaveImage from .io.dictionary import LoadImaged, LoadImageD, LoadImageDict, SaveImaged, SaveImageD, SaveImageDict -from .lazy.functional import apply_transforms +from .lazy.array import ApplyPending +from .lazy.dictionary import ApplyPendingd, ApplyPendingD, ApplyPendingDict +from .lazy.functional import apply_pending from .lazy.utils import combine_transforms, resample from .meta_utility.dictionary import ( FromMetaTensord, diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py index ec6ec6a0fe..d6b948ea71 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -22,13 +22,13 @@ import numpy as np import monai -import monai.transforms as mt from monai.apps.utils import get_logger from monai.config import NdarrayOrTensor from monai.transforms.inverse import InvertibleTransform -from monai.transforms.traits import ThreadUnsafe # For backwards compatibility (so this still works: from monai.transforms.compose import MapTransform) +from monai.transforms.lazy.functional import apply_pending_transforms +from monai.transforms.traits import ThreadUnsafe from monai.transforms.transform import ( # noqa: F401 LazyTransform, MapTransform, @@ -37,85 +37,11 @@ Transform, apply_transform, ) -from monai.utils import MAX_SEED, TraceKeys, ensure_tuple, get_seed, to_tuple_of_dictionaries +from monai.utils import MAX_SEED, TraceKeys, TraceStatusKeys, ensure_tuple, get_seed logger = get_logger(__name__) -__all__ = ["Compose", "OneOf", "RandomOrder", "evaluate_with_overrides", "SomeOf"] - - -def evaluate_with_overrides( - data, - upcoming, - lazy_evaluation: bool | None = False, - overrides: dict | None = None, - override_keys: Sequence[str] | None = None, - verbose: bool = False, -): - """ - The previously applied transform may have been lazily applied to MetaTensor `data` and - made `data.has_pending_operations` equals to True. Given the upcoming transform ``upcoming``, - this function determines whether `data.pending_operations` should be evaluated. If so, it will - evaluate the lazily applied transforms. - - Currently, the conditions for evaluation are: - - - ``lazy_evaluation`` is ``True``, AND - - the data is a ``MetaTensor`` and has pending operations, AND - - the upcoming transform is an instance of ``Identity`` or ``IdentityD`` or ``None``. - - The returned `data` will then be ready for the ``upcoming`` transform. - - Args: - data: data to be evaluated. - upcoming: the upcoming transform. - lazy_evaluation: whether to evaluate the pending operations. - override: keyword arguments to apply transforms. - override_keys: to which the override arguments are used when apply transforms. 
- verbose: whether to print debugging info when evaluate MetaTensor with pending operations. - - """ - if not lazy_evaluation: - return data # eager evaluation - overrides = (overrides or {}).copy() - if isinstance(data, monai.data.MetaTensor): - if data.has_pending_operations and ( - (upcoming is None) - or (isinstance(upcoming, mt.Identity)) - or (isinstance(upcoming, mt.Identityd) and override_keys in upcoming.keys) - ): - data, _ = mt.apply_transforms(data, None, overrides=overrides) - if verbose: - next_name = "final output" if upcoming is None else f"'{upcoming.__class__.__name__}'" - logger.info(f"Evaluated - '{override_keys}' - up-to-date for - {next_name}") - elif verbose: - logger.info( - f"Lazy - '{override_keys}' - upcoming: '{upcoming.__class__.__name__}'" - f"- pending {len(data.pending_operations)}" - ) - return data - override_keys = ensure_tuple(override_keys) - if isinstance(data, dict): - if isinstance(upcoming, MapTransform): - applied_keys = {k for k in data if k in upcoming.keys} - if not applied_keys: - return data - else: - applied_keys = set(data.keys()) - - keys_to_override = {k for k in applied_keys if k in override_keys} - # generate a list of dictionaries with the appropriate override value per key - dict_overrides = to_tuple_of_dictionaries(overrides, override_keys) - for k in data: - if k in keys_to_override: - dict_for_key = dict_overrides[override_keys.index(k)] - data[k] = evaluate_with_overrides(data[k], upcoming, lazy_evaluation, dict_for_key, k, verbose) - else: - data[k] = evaluate_with_overrides(data[k], upcoming, lazy_evaluation, None, k, verbose) - - if isinstance(data, (list, tuple)): - return [evaluate_with_overrides(v, upcoming, lazy_evaluation, overrides, override_keys, verbose) for v in data] - return data +__all__ = ["Compose", "OneOf", "RandomOrder", "SomeOf", "execute_compose"] def execute_compose( @@ -125,12 +51,10 @@ def execute_compose( unpack_items: bool = False, start: int = 0, end: int | None = None, - lazy_evaluation: bool = False, + lazy: bool | None = False, overrides: dict | None = None, - override_keys: Sequence[str] | None = None, threading: bool = False, - log_stats: bool = False, - verbose: bool = False, + log_stats: bool | str = False, ) -> NdarrayOrTensor | Sequence[NdarrayOrTensor] | Mapping[Any, NdarrayOrTensor]: """ ``execute_compose`` provides the implementation that the ``Compose`` class uses to execute a sequence @@ -146,28 +70,22 @@ def execute_compose( unpack_items: whether to unpack input `data` with `*` as parameters for the callable function of transform. defaults to `False`. start: the index of the first transform to be executed. If not set, this defaults to 0 - end: the index after the last transform to be exectued. If set, the transform at index-1 + end: the index after the last transform to be executed. If set, the transform at index-1 is the last transform that is executed. If this is not set, it defaults to len(transforms) - lazy_evaluation: whether to enable lazy evaluation for lazy transforms. If False, transforms will be + lazy: whether to enable :ref:`lazy evaluation` for lazy transforms. If False, transforms will be carried out on a transform by transform basis. If True, all lazy transforms will be executed by accumulating changes and resampling as few times as possible. - A `monai.transforms.Identity[D]` transform in the pipeline will trigger the evaluation of - the pending operations and make the primary data up-to-date. 
overrides: this optional parameter allows you to specify a dictionary of parameters that should be overridden when executing a pipeline. These each parameter that is compatible with a given transform is then applied - to that transform before it is executed. Note that overrides are currently only applied when lazy_evaluation - is True. If lazy_evaluation is False they are ignored. - currently supported args are: - {``"mode"``, ``"padding_mode"``, ``"dtype"``, ``"align_corners"``, ``"resample_mode"``, ``device``}, - please see also :py:func:`monai.transforms.lazy.apply_transforms` for more details. - override_keys: this optional parameter specifies the keys to which ``overrides`` are to be applied. If - ``overrides`` is set, ``override_keys`` must also be set. + to that transform before it is executed. Note that overrides are currently only applied when + :ref:`lazy evaluation` is enabled for the pipeline or a given transform. If lazy is False + they are ignored. Currently supported args are: + {``"mode"``, ``"padding_mode"``, ``"dtype"``, ``"align_corners"``, ``"resample_mode"``, ``device``}. threading: whether executing is happening in a threaded environment. If set, copies are made of transforms that have the ``RandomizedTrait`` interface. - log_stats: whether to log the detailed information of data and applied transform when error happened, - for NumPy array and PyTorch Tensor, log the data shape and value range, - for other metadata, log the values directly. default to `False`. - verbose: whether to print debugging info when lazy_evaluation=True. + log_stats: this optional parameter allows you to specify a logger by name for logging of pipeline execution. + Setting this to False disables logging. Setting it to True enables logging to the default loggers. + Setting a string overrides the logger name to which logging is performed. Returns: A tensorlike, sequence of tensorlikes or dict of tensorlists containing the result of running @@ -176,6 +94,8 @@ def execute_compose( end_ = len(transforms) if end is None else end if start is None: raise ValueError(f"'start' ({start}) cannot be None") + if start < 0: + raise ValueError(f"'start' ({start}) cannot be less than 0") if start > end_: raise ValueError(f"'start' ({start}) must be less than 'end' ({end_})") if end_ > len(transforms): @@ -188,18 +108,10 @@ def execute_compose( for _transform in transforms[start:end]: if threading: _transform = deepcopy(_transform) if isinstance(_transform, ThreadUnsafe) else _transform - data = evaluate_with_overrides( - data, - _transform, - lazy_evaluation=lazy_evaluation, - overrides=overrides, - override_keys=override_keys, - verbose=verbose, + data = apply_transform( + _transform, data, map_items, unpack_items, lazy=lazy, overrides=overrides, log_stats=log_stats ) - data = apply_transform(_transform, data, map_items, unpack_items, log_stats) - data = evaluate_with_overrides( - data, None, lazy_evaluation=lazy_evaluation, overrides=overrides, override_keys=override_keys, verbose=verbose - ) + data = apply_pending_transforms(data, None, overrides, logger_name=log_stats) return data @@ -272,35 +184,24 @@ class Compose(Randomizable, InvertibleTransform): them are called on the labels. Lazy resampling: + Lazy resampling is an experimental feature introduced in 1.2. Its purpose is to reduce the number of resample operations that must be carried out when executing a pipeline of transforms. 
This can provide significant performance improvements in
-        terms of pipeline executing speed and memory usage, but can also significantly
+        terms of pipeline execution speed and memory usage, and can also significantly
         reduce the loss of information that occurs when performing a number of spatial
         resamples in succession.

-        Lazy resampling can be thought of as acting in a similar fashion to the `Affine` & `RandAffine`
-        transforms, in that they allow several spatial transform operations can be specified and carried out with
-        a single resample step. Unlike these transforms, however, lazy resampling can operate on any set of
-        transforms specified in any ordering. The user is free to mix monai transforms with transforms from other
-        libraries; lazy resampling will determine the minimum number of resample steps required in order to
-        execute the pipeline.
-
-        Lazy resampling works with monai `Dataset` classes that provide caching and persistence. However, if you
-        are implementing your own caching dataset implementation and wish to make use of lazy resampling, you
-        should ensure that you fully execute the part of the pipeline that generates the data to be cached
-        before caching it. This is quite simply done however, as shown by the following example.
-
-        Example:
-            # run the part of the pipeline that needs to be cached
-            data = self.transform(data, end=self.post_cache_index)
-
-            # ---
+        Lazy resampling can be enabled or disabled through the ``lazy`` parameter, either by
+        specifying it at initialisation time or overriding it at call time.

-            # fetch the data from the cache and run the rest of the pipeline
-            data = get_data_from_my_cache(data)
-            data = self.transform(data, start=self.post_cache_index)
+        * False (default): Don't perform any lazy resampling
+        * None: Perform lazy resampling based on the 'lazy' properties of the transform instances.
+        * True: Always perform lazy resampling if possible. This will ignore the ``lazy`` properties
+          of the transform instances.

+        Please see the :ref:`Lazy Resampling topic` for more details of this feature
+        and examples of its use.

    Args:
        transforms: sequence of callables.
@@ -308,24 +209,19 @@ class Compose(Randomizable, InvertibleTransform):
            defaults to `True`.
        unpack_items: whether to unpack input `data` with `*` as parameters for the callable function of transform.
            defaults to `False`.
-        log_stats: whether to log the detailed information of data and applied transform when error happened,
-            for NumPy array and PyTorch Tensor, log the data shape and value range,
-            for other metadata, log the values directly. default to `False`.
-        lazy_evaluation: whether to enable lazy evaluation for lazy transforms. If False, transforms will be
-            carried out on a transform by transform basis. If True, all lazy transforms will
-            be executed by accumulating changes and resampling as few times as possible.
-            A `monai.transforms.Identity[D]` transform in the pipeline will trigger the evaluation of
-            the pending operations and make the primary data up-to-date.
+        log_stats: this optional parameter allows you to specify a logger by name for logging of pipeline execution.
+            Setting this to False disables logging. Setting it to True enables logging to the default loggers.
+            Setting a string overrides the logger name to which logging is performed.
+        lazy: whether to enable :ref:`Lazy Resampling` for lazy transforms. If False, transforms will
+            be carried out on a transform by transform basis. 
If True, all lazy transforms will be executed by + accumulating changes and resampling as few times as possible. If lazy is None, `Compose` will + perform lazy execution on lazy transforms that have their `lazy` property set to True. overrides: this optional parameter allows you to specify a dictionary of parameters that should be overridden when executing a pipeline. These each parameter that is compatible with a given transform is then applied - to that transform before it is executed. Note that overrides are currently only applied when lazy_evaluation - is True. If lazy_evaluation is False they are ignored. - currently supported args are: - {``"mode"``, ``"padding_mode"``, ``"dtype"``, ``"align_corners"``, ``"resample_mode"``, ``device``}, - please see also :py:func:`monai.transforms.lazy.apply_transforms` for more details. - override_keys: this optional parameter specifies the keys to which ``overrides`` are to be applied. If - ``overrides`` is set, ``override_keys`` must also be set. - verbose: whether to print debugging info when lazy_evaluation=True. + to that transform before it is executed. Note that overrides are currently only applied when + :ref:`Lazy Resampling` is enabled for the pipeline or a given transform. If lazy is False + they are ignored. Currently supported args are: + {``"mode"``, ``"padding_mode"``, ``"dtype"``, ``"align_corners"``, ``"resample_mode"``, ``device``}. """ def __init__( @@ -333,11 +229,9 @@ def __init__( transforms: Sequence[Callable] | Callable | None = None, map_items: bool = True, unpack_items: bool = False, - log_stats: bool = False, - lazy_evaluation: bool | None = None, + log_stats: bool | str = False, + lazy: bool | None = False, overrides: dict | None = None, - override_keys: Sequence[str] | None = None, - verbose: bool = False, ) -> None: if transforms is None: transforms = [] @@ -346,16 +240,8 @@ def __init__( self.unpack_items = unpack_items self.log_stats = log_stats self.set_random_state(seed=get_seed()) - - self.lazy_evaluation = lazy_evaluation + self.lazy = lazy self.overrides = overrides - self.override_keys = override_keys - self.verbose = verbose - - if self.lazy_evaluation is not None: - for t in self.flatten().transforms: # TODO: test Compose of Compose/OneOf - if isinstance(t, LazyTransform): - t.lazy_evaluation = self.lazy_evaluation def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> Compose: super().set_random_state(seed=seed, state=state) @@ -432,52 +318,56 @@ def __len__(self): """Return number of transformations.""" return len(self.flatten().transforms) - def evaluate_with_overrides(self, input_, upcoming_xform): - """ - Args: - input_: input data to be transformed. 
- upcoming_xform: a transform used to determine whether to evaluate with override - """ - return evaluate_with_overrides( + def __call__(self, input_, start=0, end=None, threading=False, lazy: bool | None = None): + result = execute_compose( input_, - upcoming_xform, - lazy_evaluation=self.lazy_evaluation, - overrides=self.overrides, - override_keys=self.override_keys, - verbose=self.verbose, - ) - - def __call__(self, input_, start=0, end=None, threading=False): - return execute_compose( - input_, - self.transforms, + transforms=self.transforms, start=start, end=end, map_items=self.map_items, unpack_items=self.unpack_items, - lazy_evaluation=self.lazy_evaluation, # type: ignore + lazy=self.lazy, # type: ignore overrides=self.overrides, - override_keys=self.override_keys, threading=threading, log_stats=self.log_stats, - verbose=self.verbose, ) + return result + def inverse(self, data): + self._raise_if_not_invertible(data) + invertible_transforms = [t for t in self.flatten().transforms if isinstance(t, InvertibleTransform)] if not invertible_transforms: warnings.warn("inverse has been called but no invertible transforms have been supplied") + if self.lazy is not False: + warnings.warn( + f"'lazy' is set to {self.lazy} but lazy execution is not supported when inverting. " + f"'lazy' has been overridden to False for the call to inverse" + ) # loop backwards over transforms for t in reversed(invertible_transforms): - if isinstance(t, LazyTransform) and t.lazy_evaluation: - warnings.warn( - f"inversing {t.__class__.__name__} lazily may not implemented" - "please set `lazy_evaluation=False` before calling inverse." - ) - data = apply_transform(t.inverse, data, self.map_items, self.unpack_items, self.log_stats) + data = apply_transform( + t.inverse, data, self.map_items, self.unpack_items, lazy=False, log_stats=self.log_stats + ) return data + @staticmethod + def _raise_if_not_invertible(data: Any): + from monai.transforms.utils import has_status_keys + + invertible, reasons = has_status_keys( + data, TraceStatusKeys.PENDING_DURING_APPLY, "Pending operations while applying an operation" + ) + + if invertible is False: + if reasons is not None: + reason_text = "\n".join(reasons) + raise RuntimeError(f"Unable to run inverse on 'data' for the following reasons:\n{reason_text}") + else: + raise RuntimeError("Unable to run inverse on 'data'; no reason logged in trace data") + class OneOf(Compose): """ @@ -492,24 +382,19 @@ class OneOf(Compose): defaults to `True`. unpack_items: whether to unpack input `data` with `*` as parameters for the callable function of transform. defaults to `False`. - log_stats: whether to log the detailed information of data and applied transform when error happened, - for NumPy array and PyTorch Tensor, log the data shape and value range, - for other metadata, log the values directly. default to `False`. - lazy_evaluation: whether to enable lazy evaluation for lazy transforms. If True, all lazy transforms will - be executed by accumulating changes and resampling as few times as possible. If False, transforms will be - carried out on a transform by transform basis. - A `monai.transforms.Identity[D]` transform in the pipeline will trigger the evaluation of - the pending operations and make the primary data up-to-date. + log_stats: this optional parameter allows you to specify a logger by name for logging of pipeline execution. + Setting this to False disables logging. Setting it to True enables logging to the default loggers. 
+ Setting a string overrides the logger name to which logging is performed. + lazy: whether to enable :ref:`Lazy Resampling` for lazy transforms. If False, transforms will + be carried out on a transform by transform basis. If True, all lazy transforms will be executed by + accumulating changes and resampling as few times as possible. If lazy is None, `Compose` will + perform lazy execution on lazy transforms that have their `lazy` property set to True. overrides: this optional parameter allows you to specify a dictionary of parameters that should be overridden when executing a pipeline. These each parameter that is compatible with a given transform is then applied - to that transform before it is executed. Note that overrides are currently only applied when lazy_evaluation - is True. If lazy_evaluation is False they are ignored. - currently supported args are: - {``"mode"``, ``"padding_mode"``, ``"dtype"``, ``"align_corners"``, ``"resample_mode"``, ``device``}, - please see also :py:func:`monai.transforms.lazy.apply_transforms` for more details. - override_keys: this optional parameter specifies the keys to which ``overrides`` are to be applied. If - ``overrides`` is set, ``override_keys`` must also be set. - verbose: whether to print debugging info when lazy_evaluation=True. + to that transform before it is executed. Note that overrides are currently only applied when + :ref:`Lazy Resampling` is enabled for the pipeline or a given transform. If lazy is False + they are ignored. Currently supported args are: + {``"mode"``, ``"padding_mode"``, ``"dtype"``, ``"align_corners"``, ``"resample_mode"``, ``device``}. """ def __init__( @@ -518,15 +403,11 @@ def __init__( weights: Sequence[float] | float | None = None, map_items: bool = True, unpack_items: bool = False, - log_stats: bool = False, - lazy_evaluation: bool | None = None, + log_stats: bool | str = False, + lazy: bool | None = False, overrides: dict | None = None, - override_keys: Sequence[str] | None = None, - verbose: bool = False, ) -> None: - super().__init__( - transforms, map_items, unpack_items, log_stats, lazy_evaluation, overrides, override_keys, verbose - ) + super().__init__(transforms, map_items, unpack_items, log_stats, lazy, overrides) if len(self.transforms) == 0: weights = [] elif weights is None or isinstance(weights, float): @@ -537,6 +418,7 @@ def __init__( f"got {len(weights)} and {len(self.transforms)}." 
)
        self.weights = ensure_tuple(self._normalize_probabilities(weights))
+        self.log_stats = log_stats

    def _normalize_probabilities(self, weights):
        if len(weights) == 0:
@@ -565,7 +447,12 @@ def flatten(self):
                weights.append(w)
        return OneOf(transforms, weights, self.map_items, self.unpack_items)

-    def __call__(self, data, start=0, end=None, threading=False):
+    def __call__(self, data, start=0, end=None, threading=False, lazy: str | bool | None = None):
+        if start != 0:
+            raise ValueError(f"OneOf requires 'start' parameter to be 0 (start set to {start})")
+        if end is not None:
+            raise ValueError(f"OneOf requires 'end' parameter to be None (end set to {end})")
+
        if len(self.transforms) == 0:
            return data

@@ -575,11 +462,14 @@ def __call__(self, data, start=0, end=None, threading=False):
        data = execute_compose(
            data,
            [_transform],
-            map_items=self.map_items,
-            unpack_items=self.unpack_items,
            start=start,
            end=end,
+            map_items=self.map_items,
+            unpack_items=self.unpack_items,
+            lazy=self.lazy,  # type: ignore
+            overrides=self.overrides,
            threading=threading,
+            log_stats=self.log_stats,
        )

        # if the data is a mapping (dictionary), append the OneOf transform to the end
@@ -625,24 +515,19 @@ class RandomOrder(Compose):
            defaults to `True`.
        unpack_items: whether to unpack input `data` with `*` as parameters for the callable function of transform.
            defaults to `False`.
-        log_stats: whether to log the detailed information of data and applied transform when error happened,
-            for NumPy array and PyTorch Tensor, log the data shape and value range,
-            for other metadata, log the values directly. default to `False`.
-        lazy_evaluation: whether to enable lazy evaluation for lazy transforms. If True, all lazy transforms will
-            be executed by accumulating changes and resampling as few times as possible. If False, transforms will be
-            carried out on a transform by transform basis.
-            A `monai.transforms.Identity[D]` transform in the pipeline will trigger the evaluation of
-            the pending operations and make the primary data up-to-date.
+        log_stats: this optional parameter allows you to specify a logger by name for logging of pipeline execution.
+            Setting this to False disables logging. Setting it to True enables logging to the default loggers.
+            Setting a string overrides the logger name to which logging is performed.
+        lazy: whether to enable :ref:`Lazy Resampling` for lazy transforms. If False, transforms will
+            be carried out on a transform by transform basis. If True, all lazy transforms will be executed by
+            accumulating changes and resampling as few times as possible. If lazy is None, `Compose` will
+            perform lazy execution on lazy transforms that have their `lazy` property set to True.
        overrides: this optional parameter allows you to specify a dictionary of parameters that should be overridden
            when executing a pipeline. These each parameter that is compatible with a given transform is then applied
-            to that transform before it is executed. Note that overrides are currently only applied when lazy_evaluation
-            is True. If lazy_evaluation is False they are ignored.
-            currently supported args are:
-            {``"mode"``, ``"padding_mode"``, ``"dtype"``, ``"align_corners"``, ``"resample_mode"``, ``device``},
-            please see also :py:func:`monai.transforms.lazy.apply_transforms` for more details.
-        override_keys: this optional parameter specifies the keys to which ``overrides`` are to be applied. If
-            ``overrides`` is set, ``override_keys`` must also be set.
-        verbose: whether to print debugging info when lazy_evaluation=True. 
+            to that transform before it is executed. Note that overrides are currently only applied when
+            :ref:`Lazy Resampling` is enabled for the pipeline or a given transform. If lazy is False
+            they are ignored. Currently supported args are:
+            {``"mode"``, ``"padding_mode"``, ``"dtype"``, ``"align_corners"``, ``"resample_mode"``, ``device``}.
    """

    def __init__(
@@ -650,30 +535,35 @@ def __init__(
        transforms: Sequence[Callable] | Callable | None = None,
        map_items: bool = True,
        unpack_items: bool = False,
-        log_stats: bool = False,
-        lazy_evaluation: bool | None = None,
+        log_stats: bool | str = False,
+        lazy: bool | None = False,
        overrides: dict | None = None,
-        override_keys: Sequence[str] | None = None,
-        verbose: bool = False,
    ) -> None:
-        super().__init__(
-            transforms, map_items, unpack_items, log_stats, lazy_evaluation, overrides, override_keys, verbose
-        )
+        super().__init__(transforms, map_items, unpack_items, log_stats, lazy, overrides)
+        self.log_stats = log_stats
+
+    def __call__(self, input_, start=0, end=None, threading=False, lazy: bool | None = None):
+        if start != 0:
+            raise ValueError(f"RandomOrder requires 'start' parameter to be 0 (start set to {start})")
+        if end is not None:
+            raise ValueError(f"RandomOrder requires 'end' parameter to be None (end set to {end})")

-    def __call__(self, input_, start=0, end=None, threading=False):
        if len(self.transforms) == 0:
            return input_
+
        num = len(self.transforms)
        applied_order = self.R.permutation(range(num))
        input_ = execute_compose(
            input_,
            [self.transforms[ind] for ind in applied_order],
-            map_items=self.map_items,
-            unpack_items=self.unpack_items,
            start=start,
            end=end,
+            map_items=self.map_items,
+            unpack_items=self.unpack_items,
+            lazy=self.lazy,
            threading=threading,
+            log_stats=self.log_stats,
        )

        # if the data is a mapping (dictionary), append the RandomOrder transform to the end
@@ -708,7 +598,7 @@ def inverse(self, data):
        for o in reversed(applied_order):
            if isinstance(self.transforms[o], InvertibleTransform):
                data = apply_transform(
-                    self.transforms[o].inverse, data, self.map_items, self.unpack_items, self.log_stats
+                    self.transforms[o].inverse, data, self.map_items, self.unpack_items, log_stats=self.log_stats
                )

        return data
@@ -727,14 +617,24 @@ class SomeOf(Compose):
            Defaults to `True`.
        unpack_items: whether to unpack input `data` with `*` as parameters for the callable function of transform.
            Defaults to `False`.
-        log_stats: whether to log the detailed information of data and applied transform when error happened,
-            for NumPy array and PyTorch Tensor, log the data shape and value range,
-            for other metadata, log the values directly. Default to `False`.
+        log_stats: this optional parameter allows you to specify a logger by name for logging of pipeline execution.
+            Setting this to False disables logging. Setting it to True enables logging to the default loggers.
+            Setting a string overrides the logger name to which logging is performed.
        num_transforms: a 2-tuple, int, or None. The 2-tuple specifies the minimum and maximum (inclusive) number of
            transforms to sample at each iteration. If an int is given, the lower and upper bounds are set equal.
            None sets it to `len(transforms)`. Default to `None`.
        replace: whether to sample with replacement. Defaults to `False`.
        weights: weights to use in for sampling transforms. Will be normalized to 1. Default: None (uniform).
+        lazy: whether to enable :ref:`Lazy Resampling` for lazy transforms. If False, transforms will
+            be carried out on a transform by transform basis. 
If True, all lazy transforms will be executed by
+            accumulating changes and resampling as few times as possible. If lazy is None, `Compose` will
+            perform lazy execution on lazy transforms that have their `lazy` property set to True.
+        overrides: this optional parameter allows you to specify a dictionary of parameters that should be overridden
+            when executing a pipeline. Each parameter that is compatible with a given transform is then applied
+            to that transform before it is executed. Note that overrides are currently only applied when
+            :ref:`Lazy Resampling` is enabled for the pipeline or a given transform. If lazy is False
+            they are ignored. Currently supported args are:
+            {``"mode"``, ``"padding_mode"``, ``"dtype"``, ``"align_corners"``, ``"resample_mode"``, ``device``}.
    """

    def __init__(
@@ -742,16 +642,18 @@ def __init__(
        transforms: Sequence[Callable] | Callable | None = None,
        map_items: bool = True,
        unpack_items: bool = False,
-        log_stats: bool = False,
-        *,
+        log_stats: bool | str = False,
        num_transforms: int | tuple[int, int] | None = None,
        replace: bool = False,
        weights: list[int] | None = None,
+        lazy: bool | None = False,
+        overrides: dict | None = None,
    ) -> None:
-        super().__init__(transforms, map_items, unpack_items, log_stats)
+        super().__init__(transforms, map_items, unpack_items, log_stats=log_stats, lazy=lazy, overrides=overrides)
        self.min_num_transforms, self.max_num_transforms = self._ensure_valid_num_transforms(num_transforms)
        self.replace = replace
        self.weights = self._normalize_probabilities(weights)
+        self.log_stats = log_stats

    def _ensure_valid_num_transforms(self, num_transforms: int | tuple[int, int] | None) -> tuple:
        if (
@@ -805,7 +707,12 @@ def _normalize_probabilities(self, weights):

        return ensure_tuple(list(weights))

-    def __call__(self, data, start=0, end=None, threading=False):
+    def __call__(self, data, start=0, end=None, threading=False, lazy: bool | None = None):
+        if start != 0:
+            raise ValueError(f"SomeOf requires 'start' parameter to be 0 (start set to {start})")
+        if end is not None:
+            raise ValueError(f"SomeOf requires 'end' parameter to be None (end set to {end})")
+
        if len(self.transforms) == 0:
            return data

@@ -815,11 +722,14 @@ def __call__(self, data, start=0, end=None, threading=False):
        data = execute_compose(
            data,
            [self.transforms[a] for a in applied_order],
-            map_items=self.map_items,
-            unpack_items=self.unpack_items,
            start=start,
            end=end,
+            map_items=self.map_items,
+            unpack_items=self.unpack_items,
+            lazy=self.lazy,
+            overrides=self.overrides,
            threading=threading,
+            log_stats=self.log_stats,
        )
        if isinstance(data, monai.data.MetaTensor):
            self.push_transform(data, extra_info={"applied_order": applied_order})
@@ -852,10 +762,9 @@ def inverse(self, data):

        # loop backwards over transforms
        for o in reversed(applied_order):
-            transform = self.transforms[o]
-            if isinstance(transform, InvertibleTransform):
+            if isinstance(self.transforms[o], InvertibleTransform):
                data = apply_transform(
-                    self.transforms[o].inverse, data, self.map_items, self.unpack_items, self.log_stats
+                    self.transforms[o].inverse, data, self.map_items, self.unpack_items, log_stats=self.log_stats
                )

        return data
diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py
index 8cfd2c70ef..740ea9d8f5 100644
--- a/monai/transforms/croppad/array.py
+++ b/monai/transforms/croppad/array.py
@@ -88,6 +88,9 @@ class Pad(InvertibleTransform, LazyTransform):
    `torch.nn.functional.pad` is used unless the mode or kwargs are not available in torch, in which case 
`np.pad` will be used. + This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic` + for more information. + Args: to_pad: the amount to pad in each dimension (including the channel) [(low_H, high_H), (low_W, high_W), ...]. if None, must provide in the `__call__` at runtime. @@ -98,6 +101,7 @@ class Pad(InvertibleTransform, LazyTransform): See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html requires pytorch >= 1.10 for best compatibility. + lazy: a flag to indicate whether this transform should execute lazily or not. Defaults to False. kwargs: other arguments for the `np.pad` or `torch.pad` function. note that `np.pad` treats channel dimension as the first dimension. @@ -106,8 +110,13 @@ class Pad(InvertibleTransform, LazyTransform): backend = [TransformBackends.TORCH, TransformBackends.NUMPY] def __init__( - self, to_pad: tuple[tuple[int, int]] | None = None, mode: str = PytorchPadMode.CONSTANT, **kwargs + self, + to_pad: tuple[tuple[int, int]] | None = None, + mode: str = PytorchPadMode.CONSTANT, + lazy: bool = False, + **kwargs, ) -> None: + LazyTransform.__init__(self, lazy) self.to_pad = to_pad self.mode = mode self.kwargs = kwargs @@ -124,7 +133,12 @@ def compute_pad_width(self, spatial_shape: Sequence[int]) -> tuple[tuple[int, in raise NotImplementedError(f"subclass {self.__class__.__name__} must implement this method.") def __call__( # type: ignore[override] - self, img: torch.Tensor, to_pad: tuple[tuple[int, int]] | None = None, mode: str | None = None, **kwargs + self, + img: torch.Tensor, + to_pad: tuple[tuple[int, int]] | None = None, + mode: str | None = None, + lazy: bool | None = None, + **kwargs, ) -> torch.Tensor: """ Args: @@ -137,6 +151,7 @@ def __call__( # type: ignore[override] One of the listed string values or a user supplied function. Defaults to ``"constant"``. See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html + lazy: a flag to override the lazy behaviour for this call, if set. Defaults to None. kwargs: other arguments for the `np.pad` or `torch.pad` function. note that `np.pad` treats channel dimension as the first dimension. @@ -150,7 +165,8 @@ def __call__( # type: ignore[override] kwargs_.update(kwargs) img_t = convert_to_tensor(data=img, track_meta=get_track_meta()) - return pad_func(img_t, to_pad_, self.get_transform_info(), mode_, **kwargs_) + lazy_ = self.lazy if lazy is None else lazy + return pad_func(img_t, to_pad_, self.get_transform_info(), mode_, lazy_, **kwargs_) def inverse(self, data: MetaTensor) -> MetaTensor: transform = self.pop_transform(data) @@ -170,6 +186,9 @@ class SpatialPad(Pad): """ Performs padding to the data, symmetric for all sides or all on one side for each dimension. + This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic` + for more information. + Args: spatial_size: the spatial size of output data after padding, if a dimension of the input data size is larger than the pad size, will not pad that dimension. @@ -184,6 +203,7 @@ class SpatialPad(Pad): One of the listed string values or a user supplied function. Defaults to ``"constant"``. See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html + lazy: a flag to indicate whether this transform should execute lazily or not. Defaults to False. 
kwargs: other arguments for the `np.pad` or `torch.pad` function. note that `np.pad` treats channel dimension as the first dimension. @@ -194,11 +214,12 @@ def __init__( spatial_size: Sequence[int] | int | tuple[tuple[int, ...] | int, ...], method: str = Method.SYMMETRIC, mode: str = PytorchPadMode.CONSTANT, + lazy: bool = False, **kwargs, ) -> None: self.spatial_size = spatial_size self.method: Method = look_up_option(method, Method) - super().__init__(mode=mode, **kwargs) + super().__init__(mode=mode, lazy=lazy, **kwargs) def compute_pad_width(self, spatial_shape: Sequence[int]) -> tuple[tuple[int, int]]: """ @@ -223,6 +244,9 @@ class BorderPad(Pad): """ Pad the input data by adding specified borders to every dimension. + This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic` + for more information. + Args: spatial_border: specified size for every spatial border. Any -ve values will be set to 0. It can be 3 shapes: @@ -240,14 +264,17 @@ class BorderPad(Pad): One of the listed string values or a user supplied function. Defaults to ``"constant"``. See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html + lazy: a flag to indicate whether this transform should execute lazily or not. Defaults to False. kwargs: other arguments for the `np.pad` or `torch.pad` function. note that `np.pad` treats channel dimension as the first dimension. """ - def __init__(self, spatial_border: Sequence[int] | int, mode: str = PytorchPadMode.CONSTANT, **kwargs) -> None: + def __init__( + self, spatial_border: Sequence[int] | int, mode: str = PytorchPadMode.CONSTANT, lazy: bool = False, **kwargs + ) -> None: self.spatial_border = spatial_border - super().__init__(mode=mode, **kwargs) + super().__init__(mode=mode, lazy=lazy, **kwargs) def compute_pad_width(self, spatial_shape: Sequence[int]) -> tuple[tuple[int, int]]: spatial_border = ensure_tuple(self.spatial_border) @@ -274,12 +301,20 @@ def compute_pad_width(self, spatial_shape: Sequence[int]) -> tuple[tuple[int, in class DivisiblePad(Pad): """ Pad the input data, so that the spatial sizes are divisible by `k`. + + This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic` + for more information. """ backend = SpatialPad.backend def __init__( - self, k: Sequence[int] | int, mode: str = PytorchPadMode.CONSTANT, method: str = Method.SYMMETRIC, **kwargs + self, + k: Sequence[int] | int, + mode: str = PytorchPadMode.CONSTANT, + method: str = Method.SYMMETRIC, + lazy: bool = False, + **kwargs, ) -> None: """ Args: @@ -294,6 +329,7 @@ def __init__( https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html method: {``"symmetric"``, ``"end"``} Pad image symmetrically on every side or only pad at the end sides. Defaults to ``"symmetric"``. + lazy: a flag to indicate whether this transform should execute lazily or not. Defaults to False. kwargs: other arguments for the `np.pad` or `torch.pad` function. note that `np.pad` treats channel dimension as the first dimension. @@ -301,7 +337,7 @@ def __init__( """ self.k = k self.method: Method = Method(method) - super().__init__(mode=mode, **kwargs) + super().__init__(mode=mode, lazy=lazy, **kwargs) def compute_pad_width(self, spatial_shape: Sequence[int]) -> tuple[tuple[int, int]]: new_size = compute_divisible_spatial_size(spatial_shape=spatial_shape, k=self.k) @@ -313,10 +349,18 @@ class Crop(InvertibleTransform, LazyTransform): """ Perform crop operations on the input image. 
+ This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic` + for more information. + + Args: + lazy: a flag to indicate whether this transform should execute lazily or not. Defaults to False. """ backend = [TransformBackends.TORCH] + def __init__(self, lazy: bool = False): + LazyTransform.__init__(self, lazy) + @staticmethod def compute_slices( roi_center: Sequence[int] | NdarrayOrTensor | None = None, @@ -370,7 +414,9 @@ def compute_slices( [slice(int(s), int(e)) for s, e in zip(roi_start_t.tolist(), roi_end_t.tolist())] ) - def __call__(self, img: torch.Tensor, slices: tuple[slice, ...]) -> torch.Tensor: # type: ignore[override] + def __call__( # type: ignore[override] + self, img: torch.Tensor, slices: tuple[slice, ...], lazy: bool | None = None + ) -> torch.Tensor: """ Apply the transform to `img`, assuming `img` is channel-first and slicing doesn't apply to the channel dim. @@ -384,7 +430,8 @@ def __call__(self, img: torch.Tensor, slices: tuple[slice, ...]) -> torch.Tensor slices_ = list([slice(None)] + slices_[:sd]) img_t: MetaTensor = convert_to_tensor(data=img, track_meta=get_track_meta()) - return crop_func(img_t, tuple(slices_), self.get_transform_info()) + lazy_ = self.lazy if lazy is None else lazy + return crop_func(img_t, tuple(slices_), lazy_, self.get_transform_info()) def inverse(self, img: MetaTensor) -> MetaTensor: transform = self.pop_transform(img) @@ -408,6 +455,9 @@ class SpatialCrop(Crop): - a list of slices for each spatial dimension (allows for use of negative indexing and `None`) - a spatial center and size - the start and end coordinates of the ROI + + This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic` + for more information. """ def __init__( @@ -417,6 +467,7 @@ def __init__( roi_start: Sequence[int] | NdarrayOrTensor | None = None, roi_end: Sequence[int] | NdarrayOrTensor | None = None, roi_slices: Sequence[slice] | None = None, + lazy: bool = False, ) -> None: """ Args: @@ -427,18 +478,22 @@ def __init__( roi_end: voxel coordinates for end of the crop ROI, if a coordinate is out of image, use the end coordinate of image. roi_slices: list of slices for each of the spatial dimensions. + lazy: a flag to indicate whether this transform should execute lazily or not. Defaults to False. + """ + super().__init__(lazy) self.slices = self.compute_slices( roi_center=roi_center, roi_size=roi_size, roi_start=roi_start, roi_end=roi_end, roi_slices=roi_slices ) - def __call__(self, img: torch.Tensor) -> torch.Tensor: # type: ignore[override] + def __call__(self, img: torch.Tensor, lazy: bool | None = None) -> torch.Tensor: # type: ignore[override] """ Apply the transform to `img`, assuming `img` is channel-first and slicing doesn't apply to the channel dim. """ - return super().__call__(img=img, slices=ensure_tuple(self.slices)) + lazy_ = self.lazy if lazy is None else lazy + return super().__call__(img=img, slices=ensure_tuple(self.slices), lazy=lazy_) class CenterSpatialCrop(Crop): @@ -448,15 +503,20 @@ class CenterSpatialCrop(Crop): So the cropped result may be smaller than the expected ROI, and the cropped results of several images may not have exactly the same shape. + This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic` + for more information. + Args: roi_size: the spatial size of the crop region e.g. [224,224,128] if a dimension of ROI size is larger than image size, will not crop that dimension of the image. 
If its components have non-positive values, the corresponding size of input image will be used. for example: if the spatial size of input data is [40, 40, 40] and `roi_size=[32, 64, -1]`, the spatial size of output data will be [32, 40, 40]. + lazy: a flag to indicate whether this transform should execute lazily or not. Defaults to False. """ - def __init__(self, roi_size: Sequence[int] | int) -> None: + def __init__(self, roi_size: Sequence[int] | int, lazy: bool = False) -> None: + super().__init__(lazy=lazy) self.roi_size = roi_size def compute_slices(self, spatial_size: Sequence[int]) -> tuple[slice]: # type: ignore[override] @@ -464,15 +524,17 @@ def compute_slices(self, spatial_size: Sequence[int]) -> tuple[slice]: # type: roi_center = [i // 2 for i in spatial_size] return super().compute_slices(roi_center=roi_center, roi_size=roi_size) - def __call__(self, img: torch.Tensor) -> torch.Tensor: # type: ignore[override] + def __call__(self, img: torch.Tensor, lazy: bool | None = None) -> torch.Tensor: # type: ignore[override] """ Apply the transform to `img`, assuming `img` is channel-first and slicing doesn't apply to the channel dim. """ + lazy_ = self.lazy if lazy is None else lazy return super().__call__( img=img, slices=self.compute_slices(img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:]), + lazy=lazy_, ) @@ -480,21 +542,26 @@ class CenterScaleCrop(Crop): """ Crop at the center of image with specified scale of ROI size. + This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic` + for more information. + Args: roi_scale: specifies the expected scale of image size to crop. e.g. [0.3, 0.4, 0.5] or a number for all dims. If its components have non-positive values, will use `1.0` instead, which means the input image size. - + lazy: a flag to indicate whether this transform should execute lazily or not. Defaults to False. """ - def __init__(self, roi_scale: Sequence[float] | float): + def __init__(self, roi_scale: Sequence[float] | float, lazy: bool = False): + super().__init__(lazy=lazy) self.roi_scale = roi_scale - def __call__(self, img: torch.Tensor) -> torch.Tensor: # type: ignore[override] + def __call__(self, img: torch.Tensor, lazy: bool | None = None) -> torch.Tensor: # type: ignore[override] img_size = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:] ndim = len(img_size) roi_size = [ceil(r * s) for r, s in zip(ensure_tuple_rep(self.roi_scale, ndim), img_size)] - cropper = CenterSpatialCrop(roi_size=roi_size) - return super().__call__(img=img, slices=cropper.compute_slices(img_size)) + lazy_ = self.lazy if lazy is None else lazy + cropper = CenterSpatialCrop(roi_size=roi_size, lazy=lazy_) + return super().__call__(img=img, slices=cropper.compute_slices(img_size), lazy=lazy_) class RandSpatialCrop(Randomizable, Crop): @@ -506,6 +573,9 @@ class RandSpatialCrop(Randomizable, Crop): will not crop that dimension. So the cropped result may be smaller than the expected ROI, and the cropped results of several images may not have exactly the same shape. + This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic` + for more information. + Args: roi_size: if `random_size` is True, it specifies the minimum crop region. if `random_size` is False, it specifies the expected ROI size to crop. e.g. [224, 224, 128] @@ -519,6 +589,7 @@ class RandSpatialCrop(Randomizable, Crop): random_center: crop at random position as center or the image center. 
random_size: crop with random size or specific size ROI. if True, the actual size is sampled from `randint(roi_size, max_roi_size + 1)`. + lazy: a flag to indicate whether this transform should execute lazily or not. Defaults to False. """ @deprecated_arg_default("random_size", True, False, since="1.1", replaced="1.3") @@ -528,7 +599,9 @@ def __init__( max_roi_size: Sequence[int] | int | None = None, random_center: bool = True, random_size: bool = True, + lazy: bool = False, ) -> None: + super().__init__(lazy) self.roi_size = roi_size self.max_roi_size = max_roi_size self.random_center = random_center @@ -547,7 +620,7 @@ def randomize(self, img_size: Sequence[int]) -> None: valid_size = get_valid_patch_size(img_size, self._size) self._slices = get_random_patch(img_size, valid_size, self.R) - def __call__(self, img: torch.Tensor, randomize: bool = True) -> torch.Tensor: # type: ignore + def __call__(self, img: torch.Tensor, randomize: bool = True, lazy: bool | None = None) -> torch.Tensor: # type: ignore """ Apply the transform to `img`, assuming `img` is channel-first and slicing doesn't apply to the channel dim. @@ -558,10 +631,11 @@ def __call__(self, img: torch.Tensor, randomize: bool = True) -> torch.Tensor: self.randomize(img_size) if self._size is None: raise RuntimeError("self._size not specified.") + lazy_ = self.lazy if lazy is None else lazy if self.random_center: - return super().__call__(img=img, slices=self._slices) - cropper = CenterSpatialCrop(self._size) - return super().__call__(img=img, slices=cropper.compute_slices(img_size)) + return super().__call__(img=img, slices=self._slices, lazy=lazy_) + cropper = CenterSpatialCrop(self._size, lazy=lazy_) + return super().__call__(img=img, slices=cropper.compute_slices(img_size), lazy=lazy_) class RandScaleCrop(RandSpatialCrop): @@ -571,6 +645,9 @@ class RandScaleCrop(RandSpatialCrop): center or at the image center. And allows to set the minimum and maximum scale of image size to limit the randomly generated ROI. + This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic` + for more information. + Args: roi_scale: if `random_size` is True, it specifies the minimum crop size: `roi_scale * image spatial size`. if `random_size` is False, it specifies the expected scale of image size to crop. e.g. [0.3, 0.4, 0.5]. @@ -583,6 +660,7 @@ class RandScaleCrop(RandSpatialCrop): random_size: crop with random size or specified size ROI by `roi_scale * image spatial size`. if True, the actual size is sampled from `randint(roi_scale * image spatial size, max_roi_scale * image spatial size + 1)`. + lazy: a flag to indicate whether this transform should execute lazily or not. Defaults to False. 
""" @deprecated_arg_default("random_size", True, False, since="1.1", replaced="1.3") @@ -592,8 +670,11 @@ def __init__( max_roi_scale: Sequence[float] | float | None = None, random_center: bool = True, random_size: bool = True, + lazy: bool = False, ) -> None: - super().__init__(roi_size=-1, max_roi_size=None, random_center=random_center, random_size=random_size) + super().__init__( + roi_size=-1, max_roi_size=None, random_center=random_center, random_size=random_size, lazy=lazy + ) self.roi_scale = roi_scale self.max_roi_scale = max_roi_scale @@ -609,14 +690,15 @@ def randomize(self, img_size: Sequence[int]) -> None: self.get_max_roi_size(img_size) super().randomize(img_size) - def __call__(self, img: torch.Tensor, randomize: bool = True) -> torch.Tensor: # type: ignore + def __call__(self, img: torch.Tensor, randomize: bool = True, lazy: bool | None = None) -> torch.Tensor: # type: ignore """ Apply the transform to `img`, assuming `img` is channel-first and slicing doesn't apply to the channel dim. """ self.get_max_roi_size(img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:]) - return super().__call__(img=img, randomize=randomize) + lazy_ = self.lazy if lazy is None else lazy + return super().__call__(img=img, randomize=randomize, lazy=lazy_) class RandSpatialCropSamples(Randomizable, TraceableTransform, LazyTransform, MultiSampleTrait): @@ -630,6 +712,9 @@ class RandSpatialCropSamples(Randomizable, TraceableTransform, LazyTransform, Mu will not crop that dimension. So the cropped result may be smaller than the expected ROI, and the cropped results of several images may not have exactly the same shape. + This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic` + for more information. + Args: roi_size: if `random_size` is True, it specifies the minimum crop region. if `random_size` is False, it specifies the expected ROI size to crop. e.g. [224, 224, 128] @@ -644,6 +729,7 @@ class RandSpatialCropSamples(Randomizable, TraceableTransform, LazyTransform, Mu random_center: crop at random position as center or the image center. random_size: crop with random size or specific size ROI. The actual size is sampled from `randint(roi_size, img_size)`. + lazy: a flag to indicate whether this transform should execute lazily or not. Defaults to False. Raises: ValueError: When ``num_samples`` is nonpositive. 
@@ -660,11 +746,13 @@ def __init__( max_roi_size: Sequence[int] | int | None = None, random_center: bool = True, random_size: bool = True, + lazy: bool = False, ) -> None: + LazyTransform.__init__(self, lazy) if num_samples < 1: raise ValueError(f"num_samples must be positive, got {num_samples}.") self.num_samples = num_samples - self.cropper = RandSpatialCrop(roi_size, max_roi_size, random_center, random_size) + self.cropper = RandSpatialCrop(roi_size, max_roi_size, random_center, random_size, lazy) def set_random_state( self, seed: int | None = None, state: np.random.RandomState | None = None @@ -673,25 +761,26 @@ def set_random_state( self.cropper.set_random_state(seed, state) return self - @LazyTransform.lazy_evaluation.setter # type: ignore - def lazy_evaluation(self, value: bool) -> None: - self._lazy_evaluation = value - self.cropper.lazy_evaluation = value + @LazyTransform.lazy.setter # type: ignore + def lazy(self, value: bool) -> None: + self._lazy = value + self.cropper.lazy = value def randomize(self, data: Any | None = None) -> None: pass - def __call__(self, img: torch.Tensor) -> list[torch.Tensor]: + def __call__(self, img: torch.Tensor, lazy: bool | None = None) -> list[torch.Tensor]: """ Apply the transform to `img`, assuming `img` is channel-first and cropping doesn't change the channel dim. """ ret = [] + lazy_ = self.lazy if lazy is None else lazy for i in range(self.num_samples): - cropped = self.cropper(img) + cropped = self.cropper(img, lazy=lazy_) if get_track_meta(): cropped.meta[Key.PATCH_INDEX] = i # type: ignore - self.push_transform(cropped, replace=True) # track as this class instead of RandSpatialCrop + self.push_transform(cropped, replace=True, lazy=lazy_) # track as this class instead of RandSpatialCrop ret.append(cropped) return ret @@ -726,6 +815,9 @@ def threshold_at_one(x): [3, 2], [2, 1]]] + This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic` + for more information. + """ def __init__( @@ -737,6 +829,7 @@ def __init__( return_coords: bool = False, k_divisible: Sequence[int] | int = 1, mode: str = PytorchPadMode.CONSTANT, + lazy: bool = False, **pad_kwargs, ) -> None: """ @@ -757,22 +850,28 @@ def __init__( One of the listed string values or a user supplied function. Defaults to ``"constant"``. See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html + lazy: a flag to indicate whether this transform should execute lazily or not. Defaults to False. pad_kwargs: other arguments for the `np.pad` or `torch.pad` function. note that `np.pad` treats channel dimension as the first dimension. 
""" + LazyTransform.__init__(self, lazy) self.select_fn = select_fn self.channel_indices = ensure_tuple(channel_indices) if channel_indices is not None else None self.margin = margin self.allow_smaller = allow_smaller self.return_coords = return_coords self.k_divisible = k_divisible - self.padder = Pad(mode=mode, **pad_kwargs) + self.padder = Pad(mode=mode, lazy=lazy, **pad_kwargs) - @Crop.lazy_evaluation.setter # type: ignore - def lazy_evaluation(self, _val: bool): - self._lazy_evaluation = _val - self.padder.lazy_evaluation = _val + @Crop.lazy.setter # type: ignore + def lazy(self, _val: bool): + self._lazy = _val + self.padder.lazy = _val + + @property + def requires_current_data(self): + return False def compute_bounding_box(self, img: torch.Tensor) -> tuple[np.ndarray, np.ndarray]: """ @@ -794,14 +893,20 @@ def compute_bounding_box(self, img: torch.Tensor) -> tuple[np.ndarray, np.ndarra return box_start_, box_end_ def crop_pad( - self, img: torch.Tensor, box_start: np.ndarray, box_end: np.ndarray, mode: str | None = None, **pad_kwargs + self, + img: torch.Tensor, + box_start: np.ndarray, + box_end: np.ndarray, + mode: str | None = None, + lazy: bool = False, + **pad_kwargs, ) -> torch.Tensor: """ Crop and pad based on the bounding box. """ slices = self.compute_slices(roi_start=box_start, roi_end=box_end) - cropped = super().__call__(img=img, slices=slices) + cropped = super().__call__(img=img, slices=slices, lazy=lazy) pad_to_start = np.maximum(-box_start, 0) pad_to_end = np.maximum( box_end - np.asarray(img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:]), 0 @@ -810,11 +915,11 @@ def crop_pad( pad_width = BorderPad(spatial_border=pad).compute_pad_width( cropped.peek_pending_shape() if isinstance(cropped, MetaTensor) else cropped.shape[1:] ) - ret = self.padder.__call__(img=cropped, to_pad=pad_width, mode=mode, **pad_kwargs) + ret = self.padder.__call__(img=cropped, to_pad=pad_width, mode=mode, lazy=lazy, **pad_kwargs) # combine the traced cropping and padding into one transformation # by taking the padded info and placing it in a key inside the crop info. if get_track_meta() and isinstance(ret, MetaTensor): - if not self.lazy_evaluation: + if not lazy: ret.applied_operations[-1][TraceKeys.EXTRA_INFO]["pad_info"] = ret.applied_operations.pop() else: pad_info = ret.pending_operations.pop() @@ -826,19 +931,21 @@ def crop_pad( orig_size=crop_info.get(TraceKeys.ORIG_SIZE), sp_size=pad_info[LazyAttr.SHAPE], affine=crop_info[LazyAttr.AFFINE] @ pad_info[LazyAttr.AFFINE], + lazy=lazy, extra_info=extra, ) return ret def __call__( # type: ignore[override] - self, img: torch.Tensor, mode: str | None = None, **pad_kwargs + self, img: torch.Tensor, mode: str | None = None, lazy: bool | None = None, **pad_kwargs ) -> torch.Tensor: """ Apply the transform to `img`, assuming `img` is channel-first and slicing doesn't change the channel dim. """ box_start, box_end = self.compute_bounding_box(img) - cropped = self.crop_pad(img, box_start, box_end, mode, **pad_kwargs) + lazy_ = self.lazy if lazy is None else lazy + cropped = self.crop_pad(img, box_start, box_end, mode, lazy=lazy_, **pad_kwargs) if self.return_coords: return cropped, box_start, box_end # type: ignore[return-value] @@ -859,6 +966,9 @@ class RandWeightedCrop(Randomizable, TraceableTransform, LazyTransform, MultiSam """ Samples a list of `num_samples` image patches according to the provided `weight_map`. + This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic` + for more information. 
+ Args: spatial_size: the spatial size of the image patch e.g. [224, 224, 128]. If its components have non-positive values, the corresponding size of `img` will be used. @@ -866,13 +976,19 @@ class RandWeightedCrop(Randomizable, TraceableTransform, LazyTransform, MultiSam weight_map: weight map used to generate patch samples. The weights must be non-negative. Each element denotes a sampling weight of the spatial location. 0 indicates no sampling. It should be a single-channel array in shape, for example, `(1, spatial_dim_0, spatial_dim_1, ...)`. + lazy: a flag to indicate whether this transform should execute lazily or not. Defaults to False. """ backend = SpatialCrop.backend def __init__( - self, spatial_size: Sequence[int] | int, num_samples: int = 1, weight_map: NdarrayOrTensor | None = None + self, + spatial_size: Sequence[int] | int, + num_samples: int = 1, + weight_map: NdarrayOrTensor | None = None, + lazy: bool = False, ): + LazyTransform.__init__(self, lazy) self.spatial_size = ensure_tuple(spatial_size) self.num_samples = int(num_samples) self.weight_map = weight_map @@ -883,12 +999,16 @@ def randomize(self, weight_map: NdarrayOrTensor) -> None: spatial_size=self.spatial_size, w=weight_map[0], n_samples=self.num_samples, r_state=self.R ) # using only the first channel as weight map - @LazyTransform.lazy_evaluation.setter # type: ignore - def lazy_evaluation(self, _val: bool): - self._lazy_evaluation = _val + @LazyTransform.lazy.setter # type: ignore + def lazy(self, _val: bool): + self._lazy = _val def __call__( - self, img: torch.Tensor, weight_map: NdarrayOrTensor | None = None, randomize: bool = True + self, + img: torch.Tensor, + weight_map: NdarrayOrTensor | None = None, + randomize: bool = True, + lazy: bool | None = None, ) -> list[torch.Tensor]: """ Args: @@ -897,6 +1017,7 @@ def __call__( Each element denotes a sampling weight of the spatial location. 0 indicates no sampling. It should be a single-channel array in shape, for example, `(1, spatial_dim_0, spatial_dim_1, ...)` randomize: whether to execute random operations, default to `True`. + lazy: a flag to override the lazy behaviour for this call, if set. Defaults to None. Returns: A list of image patches @@ -915,15 +1036,15 @@ def __call__( _spatial_size = fall_back_tuple(self.spatial_size, img_shape) results: list[torch.Tensor] = [] + lazy_ = self.lazy if lazy is None else lazy for i, center in enumerate(self.centers): - cropper = SpatialCrop(roi_center=center, roi_size=_spatial_size) - cropper.lazy_evaluation = self.lazy_evaluation + cropper = SpatialCrop(roi_center=center, roi_size=_spatial_size, lazy=lazy_) cropped = cropper(img) if get_track_meta(): ret_: MetaTensor = cropped # type: ignore ret_.meta[Key.PATCH_INDEX] = i ret_.meta["crop_center"] = center - self.push_transform(ret_, replace=True) + self.push_transform(ret_, replace=True, lazy=lazy_) results.append(cropped) return results @@ -947,6 +1068,9 @@ class RandCropByPosNegLabel(Randomizable, TraceableTransform, LazyTransform, Mul And if the crop ROI is partly out of the image, will automatically adjust the crop center to ensure the valid crop ROI. + This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic` + for more information. + Args: spatial_size: the spatial size of the crop region e.g. [224, 224, 128]. if a dimension of ROI size is larger than image size, will not crop that dimension of the image. 
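For multi-sample croppers such as ``RandWeightedCrop``, each returned patch now carries its own pending crop rather than a copied sub-volume. A small sketch, assuming an illustrative weight map that only permits centres in the middle of the image::

    import torch
    from monai.data import MetaTensor
    from monai.transforms import RandWeightedCrop

    img = MetaTensor(torch.rand(1, 64, 64))
    weights = torch.zeros(1, 64, 64)
    weights[0, 24:40, 24:40] = 1.0  # sample crop centres only from this region

    cropper = RandWeightedCrop(spatial_size=(32, 32), num_samples=4, lazy=True)
    patches = cropper(img, weight_map=weights)
    for patch in patches:
        # each patch records its crop as a pending operation
        print(patch.meta["crop_center"], len(patch.pending_operations))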
@@ -976,6 +1100,7 @@ class RandCropByPosNegLabel(Randomizable, TraceableTransform, LazyTransform, Mul allow_smaller: if `False`, an exception will be raised if the image is smaller than the requested ROI in any dimension. If `True`, any smaller dimensions will be set to match the cropped size (i.e., no cropping in that dimension). + lazy: a flag to indicate whether this transform should execute lazily or not. Defaults to False. Raises: ValueError: When ``pos`` or ``neg`` are negative. @@ -997,7 +1122,9 @@ def __init__( fg_indices: NdarrayOrTensor | None = None, bg_indices: NdarrayOrTensor | None = None, allow_smaller: bool = False, + lazy: bool = False, ) -> None: + LazyTransform.__init__(self, lazy) self.spatial_size = spatial_size self.label = label if pos < 0 or neg < 0: @@ -1044,9 +1171,13 @@ def randomize( self.allow_smaller, ) - @LazyTransform.lazy_evaluation.setter # type: ignore - def lazy_evaluation(self, _val: bool): - self._lazy_evaluation = _val + @LazyTransform.lazy.setter # type: ignore + def lazy(self, _val: bool): + self._lazy = _val + + @property + def requires_current_data(self): + return False def __call__( self, @@ -1056,6 +1187,7 @@ def __call__( fg_indices: NdarrayOrTensor | None = None, bg_indices: NdarrayOrTensor | None = None, randomize: bool = True, + lazy: bool | None = None, ) -> list[torch.Tensor]: """ Args: @@ -1070,6 +1202,7 @@ def __call__( bg_indices: background indices to randomly select crop centers, need to provide `fg_indices` and `bg_indices` together. randomize: whether to execute the random operations, default to `True`. + lazy: a flag to override the lazy behaviour for this call, if set. Defaults to None. """ if image is None: @@ -1082,15 +1215,15 @@ def __call__( if self.centers is not None: img_shape = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:] roi_size = fall_back_tuple(self.spatial_size, default=img_shape) + lazy_ = self.lazy if lazy is None else lazy for i, center in enumerate(self.centers): - cropper = SpatialCrop(roi_center=center, roi_size=roi_size) - cropper.lazy_evaluation = self.lazy_evaluation + cropper = SpatialCrop(roi_center=center, roi_size=roi_size, lazy=lazy_) cropped = cropper(img) if get_track_meta(): ret_: MetaTensor = cropped # type: ignore ret_.meta[Key.PATCH_INDEX] = i ret_.meta["crop_center"] = center - self.push_transform(ret_, replace=True) + self.push_transform(ret_, replace=True, lazy=lazy_) results.append(cropped) return results @@ -1134,6 +1267,9 @@ class RandCropByLabelClasses(Randomizable, TraceableTransform, LazyTransform, Mu And if the crop ROI is partly out of the image, will automatically adjust the crop center to ensure the valid crop ROI. + This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic` + for more information. + Args: spatial_size: the spatial size of the crop region e.g. [224, 224, 128]. if a dimension of ROI size is larger than image size, will not crop that dimension of the image. @@ -1159,7 +1295,7 @@ class RandCropByLabelClasses(Randomizable, TraceableTransform, LazyTransform, Mu warn: if `True` prints a warning if a class is not present in the label. max_samples_per_class: maximum length of indices to sample in each class to reduce memory consumption. Default is None, no subsampling. - + lazy: a flag to indicate whether this transform should execute lazily or not. Defaults to False. 
""" backend = SpatialCrop.backend @@ -1177,7 +1313,9 @@ def __init__( allow_smaller: bool = False, warn: bool = True, max_samples_per_class: int | None = None, + lazy: bool = False, ) -> None: + LazyTransform.__init__(self, lazy) self.spatial_size = spatial_size self.ratios = ratios self.label = label @@ -1215,9 +1353,13 @@ def randomize( self.spatial_size, self.num_samples, _shape, indices_, self.ratios, self.R, self.allow_smaller, self.warn ) - @LazyTransform.lazy_evaluation.setter # type: ignore - def lazy_evaluation(self, _val: bool): - self._lazy_evaluation = _val + @LazyTransform.lazy.setter # type: ignore + def lazy(self, _val: bool): + self._lazy = _val + + @property + def requires_current_data(self): + return False def __call__( self, @@ -1226,6 +1368,7 @@ def __call__( image: torch.Tensor | None = None, indices: list[NdarrayOrTensor] | None = None, randomize: bool = True, + lazy: bool | None = None, ) -> list[torch.Tensor]: """ Args: @@ -1236,7 +1379,7 @@ def __call__( use ``image > image_threshold`` to select the centers only in valid region. if None, use `self.image`. indices: list of indices for every class in the image, used to randomly select crop centers. randomize: whether to execute the random operations, default to `True`. - + lazy: a flag to override the lazy behaviour for this call, if set. Defaults to None. """ if image is None: image = self.image @@ -1248,15 +1391,15 @@ def __call__( if self.centers is not None: img_shape = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:] roi_size = fall_back_tuple(self.spatial_size, default=img_shape) + lazy_ = self.lazy if lazy is None else lazy for i, center in enumerate(self.centers): - cropper = SpatialCrop(roi_center=tuple(center), roi_size=roi_size) - cropper.lazy_evaluation = self.lazy_evaluation + cropper = SpatialCrop(roi_center=tuple(center), roi_size=roi_size, lazy=lazy_) cropped = cropper(img) if get_track_meta(): ret_: MetaTensor = cropped # type: ignore ret_.meta[Key.PATCH_INDEX] = i ret_.meta["crop_center"] = center - self.push_transform(ret_, replace=True) + self.push_transform(ret_, replace=True, lazy=lazy_) results.append(cropped) return results @@ -1269,6 +1412,9 @@ class ResizeWithPadOrCrop(InvertibleTransform, LazyTransform): When the dimension is smaller than the target size, do symmetric padding along that dim. When the dimension is larger than the target size, do central cropping along that dim. + This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic` + for more information. + Args: spatial_size: the spatial size of output data after padding or crop. If has non-positive values, the corresponding size of input image will be used (no padding). @@ -1282,6 +1428,7 @@ class ResizeWithPadOrCrop(InvertibleTransform, LazyTransform): https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html pad_kwargs: other arguments for the `np.pad` or `torch.pad` function. note that `np.pad` treats channel dimension as the first dimension. + lazy: a flag to indicate whether this transform should execute lazily or not. Defaults to False. 
""" @@ -1292,18 +1439,22 @@ def __init__( spatial_size: Sequence[int] | int, method: str = Method.SYMMETRIC, mode: str = PytorchPadMode.CONSTANT, + lazy: bool = False, **pad_kwargs, ): - self.padder = SpatialPad(spatial_size=spatial_size, method=method, mode=mode, **pad_kwargs) - self.cropper = CenterSpatialCrop(roi_size=spatial_size) + LazyTransform.__init__(self, lazy) + self.padder = SpatialPad(spatial_size=spatial_size, method=method, mode=mode, lazy=lazy, **pad_kwargs) + self.cropper = CenterSpatialCrop(roi_size=spatial_size, lazy=lazy) - @LazyTransform.lazy_evaluation.setter # type: ignore - def lazy_evaluation(self, val: bool): - self.padder.lazy_evaluation = val - self.cropper.lazy_evaluation = val - self._lazy_evaluation = val + @LazyTransform.lazy.setter # type: ignore + def lazy(self, val: bool): + self.padder.lazy = val + self.cropper.lazy = val + self._lazy = val - def __call__(self, img: torch.Tensor, mode: str | None = None, **pad_kwargs) -> torch.Tensor: # type: ignore + def __call__( # type: ignore[override] + self, img: torch.Tensor, mode: str | None = None, lazy: bool | None = None, **pad_kwargs + ) -> torch.Tensor: """ Args: img: data to pad or crop, assuming `img` is channel-first and @@ -1314,20 +1465,22 @@ def __call__(self, img: torch.Tensor, mode: str | None = None, **pad_kwargs) -> One of the listed string values or a user supplied function. Defaults to ``"constant"``. See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html - kwargs: other arguments for the `np.pad` or `torch.pad` function. + lazy: a flag to override the lazy behaviour for this call, if set. Defaults to None. + pad_kwargs: other arguments for the `np.pad` or `torch.pad` function. note that `np.pad` treats channel dimension as the first dimension. 
""" - ret = self.padder(self.cropper(img), mode=mode, **pad_kwargs) + lazy_ = self.lazy if lazy is None else lazy + ret = self.padder(self.cropper(img, lazy_), mode=mode, lazy=lazy_, **pad_kwargs) # remove the individual info and combine if get_track_meta(): ret_: MetaTensor = ret # type: ignore - if not self.lazy_evaluation: + if not lazy_: pad_info = ret_.applied_operations.pop() crop_info = ret_.applied_operations.pop() orig_size = crop_info.get(TraceKeys.ORIG_SIZE) self.push_transform( - ret_, orig_size=orig_size, extra_info={"pad_info": pad_info, "crop_info": crop_info} + ret_, orig_size=orig_size, extra_info={"pad_info": pad_info, "crop_info": crop_info}, lazy=lazy_ ) else: pad_info = ret_.pending_operations.pop() @@ -1339,6 +1492,7 @@ def __call__(self, img: torch.Tensor, mode: str | None = None, **pad_kwargs) -> sp_size=pad_info[LazyAttr.SHAPE], affine=crop_info[LazyAttr.AFFINE] @ pad_info[LazyAttr.AFFINE], extra_info={"pad_info": pad_info, "crop_info": crop_info}, + lazy=lazy_, ) return ret diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 8e9b6b2f1e..6dc8f10c32 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -47,7 +47,7 @@ SpatialPad, ) from monai.transforms.inverse import InvertibleTransform -from monai.transforms.traits import MultiSampleTrait +from monai.transforms.traits import LazyTrait, MultiSampleTrait from monai.transforms.transform import LazyTransform, MapTransform, Randomizable from monai.transforms.utils import is_positive from monai.utils import MAX_SEED, Method, PytorchPadMode, deprecated_arg_default, ensure_tuple_rep @@ -114,6 +114,8 @@ class Padd(MapTransform, InvertibleTransform, LazyTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Pad`. + This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic` + for more information. """ backend = Pad.backend @@ -124,6 +126,7 @@ def __init__( padder: Pad, mode: SequenceStr = PytorchPadMode.CONSTANT, allow_missing_keys: bool = False, + lazy: bool = False, ) -> None: """ Args: @@ -138,22 +141,34 @@ def __init__( https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html It also can be a sequence of string, each element corresponds to a key in ``keys``. allow_missing_keys: don't raise exception if key is missing. - + lazy: a flag to indicate whether this transform should execute lazily or not. Defaults to False. 
""" - super().__init__(keys, allow_missing_keys) + MapTransform.__init__(self, keys, allow_missing_keys) + LazyTransform.__init__(self, lazy) + if lazy is True and not isinstance(padder, LazyTrait): + raise ValueError("'padder' must inherit LazyTrait if lazy is True " f"'padder' is of type({type(padder)})") self.padder = padder self.mode = ensure_tuple_rep(mode, len(self.keys)) - @LazyTransform.lazy_evaluation.setter # type: ignore - def lazy_evaluation(self, value: bool) -> None: - self._lazy_evaluation = value + @LazyTransform.lazy.setter # type: ignore + def lazy(self, value: bool) -> None: + self._lazy = value if isinstance(self.padder, LazyTransform): - self.padder.lazy_evaluation = value + self.padder.lazy = value - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: + def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None) -> dict[Hashable, torch.Tensor]: d = dict(data) + lazy_ = self.lazy if lazy is None else lazy + if lazy_ is True and not isinstance(self.padder, LazyTrait): + raise ValueError( + "'self.padder' must inherit LazyTrait if lazy is True " f"'self.padder' is of type({type(self.padder)}" + ) for key, m in self.key_iterator(d, self.mode): - d[key] = self.padder(d[key], mode=m) + if isinstance(self.padder, LazyTrait): + d[key] = self.padder(d[key], mode=m, lazy=lazy_) + else: + d[key] = self.padder(d[key], mode=m) + return d def inverse(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTensor]: @@ -168,6 +183,8 @@ class SpatialPadd(Padd): Dictionary-based wrapper of :py:class:`monai.transforms.SpatialPad`. Performs padding to the data, symmetric for all sides or all on one side for each dimension. + This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic` + for more information. """ def __init__( @@ -177,6 +194,7 @@ def __init__( method: str = Method.SYMMETRIC, mode: SequenceStr = PytorchPadMode.CONSTANT, allow_missing_keys: bool = False, + lazy: bool = False, **kwargs, ) -> None: """ @@ -198,18 +216,23 @@ def __init__( https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html It also can be a sequence of string, each element corresponds to a key in ``keys``. allow_missing_keys: don't raise exception if key is missing. + lazy: a flag to indicate whether this transform should execute lazily or not. Defaults to False. kwargs: other arguments for the `np.pad` or `torch.pad` function. note that `np.pad` treats channel dimension as the first dimension. """ - padder = SpatialPad(spatial_size, method, **kwargs) - super().__init__(keys, padder=padder, mode=mode, allow_missing_keys=allow_missing_keys) + LazyTransform.__init__(self, lazy) + padder = SpatialPad(spatial_size, method, lazy=lazy, **kwargs) + Padd.__init__(self, keys, padder=padder, mode=mode, allow_missing_keys=allow_missing_keys) class BorderPadd(Padd): """ Pad the input data by adding specified borders to every dimension. Dictionary-based wrapper of :py:class:`monai.transforms.BorderPad`. + + This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic` + for more information. 
""" backend = BorderPad.backend @@ -220,6 +243,7 @@ def __init__( spatial_border: Sequence[int] | int, mode: SequenceStr = PytorchPadMode.CONSTANT, allow_missing_keys: bool = False, + lazy: bool = False, **kwargs, ) -> None: """ @@ -245,18 +269,23 @@ def __init__( https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html It also can be a sequence of string, each element corresponds to a key in ``keys``. allow_missing_keys: don't raise exception if key is missing. + lazy: a flag to indicate whether this transform should execute lazily or not. Defaults to False. kwargs: other arguments for the `np.pad` or `torch.pad` function. note that `np.pad` treats channel dimension as the first dimension. """ - padder = BorderPad(spatial_border=spatial_border, **kwargs) - super().__init__(keys, padder=padder, mode=mode, allow_missing_keys=allow_missing_keys) + LazyTransform.__init__(self, lazy) + padder = BorderPad(spatial_border=spatial_border, lazy=lazy, **kwargs) + Padd.__init__(self, keys, padder=padder, mode=mode, allow_missing_keys=allow_missing_keys) class DivisiblePadd(Padd): """ Pad the input data, so that the spatial sizes are divisible by `k`. Dictionary-based wrapper of :py:class:`monai.transforms.DivisiblePad`. + + This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic` + for more information. """ backend = DivisiblePad.backend @@ -268,6 +297,7 @@ def __init__( mode: SequenceStr = PytorchPadMode.CONSTANT, method: str = Method.SYMMETRIC, allow_missing_keys: bool = False, + lazy: bool = False, **kwargs, ) -> None: """ @@ -287,44 +317,51 @@ def __init__( method: {``"symmetric"``, ``"end"``} Pad image symmetrically on every side or only pad at the end sides. Defaults to ``"symmetric"``. allow_missing_keys: don't raise exception if key is missing. + lazy: a flag to indicate whether this transform should execute lazily or not. Defaults to False. kwargs: other arguments for the `np.pad` or `torch.pad` function. note that `np.pad` treats channel dimension as the first dimension. See also :py:class:`monai.transforms.SpatialPad` """ - padder = DivisiblePad(k=k, method=method, **kwargs) - super().__init__(keys, padder=padder, mode=mode, allow_missing_keys=allow_missing_keys) + LazyTransform.__init__(self, lazy) + padder = DivisiblePad(k=k, method=method, lazy=lazy, **kwargs) + Padd.__init__(self, keys, padder=padder, mode=mode, allow_missing_keys=allow_missing_keys) class Cropd(MapTransform, InvertibleTransform, LazyTransform): """ Dictionary-based wrapper of abstract class :py:class:`monai.transforms.Crop`. + This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic` + for more information. + Args: keys: keys of the corresponding items to be transformed. See also: :py:class:`monai.transforms.compose.MapTransform` cropper: crop transform for the input image. allow_missing_keys: don't raise exception if key is missing. - + lazy: a flag to indicate whether this transform should execute lazily or not. Defaults to False. 
""" backend = Crop.backend - def __init__(self, keys: KeysCollection, cropper: Crop, allow_missing_keys: bool = False): - super().__init__(keys, allow_missing_keys) + def __init__(self, keys: KeysCollection, cropper: Crop, allow_missing_keys: bool = False, lazy: bool = False): + MapTransform.__init__(self, keys, allow_missing_keys) + LazyTransform.__init__(self, lazy) self.cropper = cropper - @LazyTransform.lazy_evaluation.setter # type: ignore - def lazy_evaluation(self, value: bool) -> None: - self._lazy_evaluation = value + @LazyTransform.lazy.setter # type: ignore + def lazy(self, value: bool) -> None: + self._lazy = value if isinstance(self.cropper, LazyTransform): - self.cropper.lazy_evaluation = value + self.cropper.lazy = value - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: + def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None) -> dict[Hashable, torch.Tensor]: d = dict(data) + lazy_ = self.lazy if lazy is None else lazy for key in self.key_iterator(d): - d[key] = self.cropper(d[key]) # type: ignore + d[key] = self.cropper(d[key], lazy=lazy_) # type: ignore return d def inverse(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTensor]: @@ -338,18 +375,21 @@ class RandCropd(Cropd, Randomizable): """ Base class for random crop transform. + This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic` + for more information. + Args: keys: keys of the corresponding items to be transformed. See also: :py:class:`monai.transforms.compose.MapTransform` cropper: random crop transform for the input image. allow_missing_keys: don't raise exception if key is missing. - + lazy: a flag to indicate whether this transform should execute lazily or not. Defaults to False. 
""" backend = Crop.backend - def __init__(self, keys: KeysCollection, cropper: Crop, allow_missing_keys: bool = False): - super().__init__(keys, cropper=cropper, allow_missing_keys=allow_missing_keys) + def __init__(self, keys: KeysCollection, cropper: Crop, allow_missing_keys: bool = False, lazy: bool = False): + super().__init__(keys, cropper=cropper, allow_missing_keys=allow_missing_keys, lazy=lazy) def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> RandCropd: super().set_random_state(seed, state) @@ -361,13 +401,21 @@ def randomize(self, img_size: Sequence[int]) -> None: if isinstance(self.cropper, Randomizable): self.cropper.randomize(img_size) - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: + def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None) -> dict[Hashable, torch.Tensor]: d = dict(data) # the first key must exist to execute random operations first_item = d[self.first_key(d)] self.randomize(first_item.peek_pending_shape() if isinstance(first_item, MetaTensor) else first_item.shape[1:]) + lazy_ = self.lazy if lazy is None else lazy + if lazy_ is True and not isinstance(self.cropper, LazyTrait): + raise ValueError( + "'self.cropper' must inherit LazyTrait if lazy is True " + f"'self.cropper' is of type({type(self.cropper)}" + ) for key in self.key_iterator(d): kwargs = {"randomize": False} if isinstance(self.cropper, Randomizable) else {} + if isinstance(self.cropper, LazyTrait): + kwargs["lazy"] = lazy_ d[key] = self.cropper(d[key], **kwargs) # type: ignore return d @@ -385,6 +433,9 @@ class SpatialCropd(Cropd): - a list of slices for each spatial dimension (allows for use of -ve indexing and `None`) - a spatial center and size - the start and end coordinates of the ROI + + This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic` + for more information. """ def __init__( @@ -396,6 +447,7 @@ def __init__( roi_end: Sequence[int] | None = None, roi_slices: Sequence[slice] | None = None, allow_missing_keys: bool = False, + lazy: bool = False, ) -> None: """ Args: @@ -409,10 +461,10 @@ def __init__( use the end coordinate of image. roi_slices: list of slices for each of the spatial dimensions. allow_missing_keys: don't raise exception if key is missing. - + lazy: a flag to indicate whether this transform should execute lazily or not. Defaults to False. """ - cropper = SpatialCrop(roi_center, roi_size, roi_start, roi_end, roi_slices) - super().__init__(keys, cropper=cropper, allow_missing_keys=allow_missing_keys) + cropper = SpatialCrop(roi_center, roi_size, roi_start, roi_end, roi_slices, lazy=lazy) + super().__init__(keys, cropper=cropper, allow_missing_keys=allow_missing_keys, lazy=lazy) class CenterSpatialCropd(Cropd): @@ -422,6 +474,9 @@ class CenterSpatialCropd(Cropd): So the cropped result may be smaller than the expected ROI, and the cropped results of several images may not have exactly the same shape. + This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic` + for more information. + Args: keys: keys of the corresponding items to be transformed. See also: monai.transforms.MapTransform @@ -431,11 +486,14 @@ class CenterSpatialCropd(Cropd): for example: if the spatial size of input data is [40, 40, 40] and `roi_size=[32, 64, -1]`, the spatial size of output data will be [32, 40, 40]. allow_missing_keys: don't raise exception if key is missing. 
+ lazy: a flag to indicate whether this transform should execute lazily or not. Defaults to False. """ - def __init__(self, keys: KeysCollection, roi_size: Sequence[int] | int, allow_missing_keys: bool = False) -> None: - cropper = CenterSpatialCrop(roi_size) - super().__init__(keys, cropper=cropper, allow_missing_keys=allow_missing_keys) + def __init__( + self, keys: KeysCollection, roi_size: Sequence[int] | int, allow_missing_keys: bool = False, lazy: bool = False + ) -> None: + cropper = CenterSpatialCrop(roi_size, lazy=lazy) + super().__init__(keys, cropper=cropper, allow_missing_keys=allow_missing_keys, lazy=lazy) class CenterScaleCropd(Cropd): @@ -444,19 +502,27 @@ class CenterScaleCropd(Cropd): Note: as using the same scaled ROI to crop, all the input data specified by `keys` should have the same spatial shape. + This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic` + for more information. + Args: keys: keys of the corresponding items to be transformed. See also: monai.transforms.MapTransform roi_scale: specifies the expected scale of image size to crop. e.g. [0.3, 0.4, 0.5] or a number for all dims. If its components have non-positive values, will use `1.0` instead, which means the input image size. allow_missing_keys: don't raise exception if key is missing. + lazy: a flag to indicate whether this transform should execute lazily or not. Defaults to False. """ def __init__( - self, keys: KeysCollection, roi_scale: Sequence[float] | float, allow_missing_keys: bool = False + self, + keys: KeysCollection, + roi_scale: Sequence[float] | float, + allow_missing_keys: bool = False, + lazy: bool = False, ) -> None: - cropper = CenterScaleCrop(roi_scale) - super().__init__(keys, cropper=cropper, allow_missing_keys=allow_missing_keys) + cropper = CenterScaleCrop(roi_scale, lazy=lazy) + super().__init__(keys, cropper=cropper, allow_missing_keys=allow_missing_keys, lazy=lazy) class RandSpatialCropd(RandCropd): @@ -470,6 +536,9 @@ class RandSpatialCropd(RandCropd): will not crop that dimension. So the cropped result may be smaller than the expected ROI, and the cropped results of several images may not have exactly the same shape. + This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic` + for more information. + Args: keys: keys of the corresponding items to be transformed. See also: monai.transforms.MapTransform @@ -487,6 +556,7 @@ class RandSpatialCropd(RandCropd): if True, the actual size is sampled from: `randint(roi_scale * image spatial size, max_roi_scale * image spatial size + 1)`. allow_missing_keys: don't raise exception if key is missing. + lazy: a flag to indicate whether this transform should execute lazily or not. Defaults to False. """ @deprecated_arg_default("random_size", True, False, since="1.1", replaced="1.3") @@ -498,9 +568,10 @@ def __init__( random_center: bool = True, random_size: bool = True, allow_missing_keys: bool = False, + lazy: bool = False, ) -> None: - cropper = RandSpatialCrop(roi_size, max_roi_size, random_center, random_size) - super().__init__(keys, cropper=cropper, allow_missing_keys=allow_missing_keys) + cropper = RandSpatialCrop(roi_size, max_roi_size, random_center, random_size, lazy=lazy) + super().__init__(keys, cropper=cropper, allow_missing_keys=allow_missing_keys, lazy=lazy) class RandScaleCropd(RandCropd): @@ -511,6 +582,9 @@ class RandScaleCropd(RandCropd): And allows to set the minimum and maximum scale of image size to limit the randomly generated ROI. 
Suppose all the expected fields specified by `keys` have same shape. + This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic` + for more information. + Args: keys: keys of the corresponding items to be transformed. See also: monai.transforms.MapTransform @@ -526,6 +600,7 @@ class RandScaleCropd(RandCropd): if True, the actual size is sampled from: `randint(roi_scale * image spatial size, max_roi_scale * image spatial size + 1)`. allow_missing_keys: don't raise exception if key is missing. + lazy: a flag to indicate whether this transform should execute lazily or not. Defaults to False. """ @deprecated_arg_default("random_size", True, False, since="1.1", replaced="1.3") @@ -537,9 +612,10 @@ def __init__( random_center: bool = True, random_size: bool = True, allow_missing_keys: bool = False, + lazy: bool = False, ) -> None: - cropper = RandScaleCrop(roi_scale, max_roi_scale, random_center, random_size) - super().__init__(keys, cropper=cropper, allow_missing_keys=allow_missing_keys) + cropper = RandScaleCrop(roi_scale, max_roi_scale, random_center, random_size, lazy=lazy) + super().__init__(keys, cropper=cropper, allow_missing_keys=allow_missing_keys, lazy=lazy) class RandSpatialCropSamplesd(Randomizable, MapTransform, LazyTransform, MultiSampleTrait): @@ -555,6 +631,9 @@ class RandSpatialCropSamplesd(Randomizable, MapTransform, LazyTransform, MultiSa will not crop that dimension. So the cropped result may be smaller than the expected ROI, and the cropped results of several images may not have exactly the same shape. + This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic` + for more information. + Args: keys: keys of the corresponding items to be transformed. See also: monai.transforms.MapTransform @@ -572,6 +651,7 @@ class RandSpatialCropSamplesd(Randomizable, MapTransform, LazyTransform, MultiSa random_size: crop with random size or specific size ROI. The actual size is sampled from `randint(roi_size, img_size)`. allow_missing_keys: don't raise exception if key is missing. + lazy: a flag to indicate whether this transform should execute lazily or not. Defaults to False. Raises: ValueError: When ``num_samples`` is nonpositive. 
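The dictionary wrappers above forward ``lazy`` to the array transform they hold, and their ``lazy`` property setters keep the two in sync, so every key accumulates matching pending operations. A minimal sketch with illustrative keys and sizes::

    import torch
    from monai.data import MetaTensor
    from monai.transforms import CenterSpatialCropd, SpatialPadd

    data = {
        "img": MetaTensor(torch.rand(1, 48, 48, 48)),
        "seg": MetaTensor(torch.rand(1, 48, 48, 48)),
    }
    pad = SpatialPadd(keys=["img", "seg"], spatial_size=(64, 64, 64), lazy=True)
    crop = CenterSpatialCropd(keys=["img", "seg"], roi_size=(32, 32, 32), lazy=True)

    out = crop(pad(data))
    # both keys carry the same two pending operations, to be fused downstream
    print(len(out["img"].pending_operations), len(out["seg"].pending_operations))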
@@ -590,19 +670,25 @@ def __init__( random_center: bool = True, random_size: bool = True, allow_missing_keys: bool = False, + lazy: bool = False, ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) - self.cropper = RandSpatialCropSamples(roi_size, num_samples, max_roi_size, random_center, random_size) + LazyTransform.__init__(self, lazy) + self.cropper = RandSpatialCropSamples( + roi_size, num_samples, max_roi_size, random_center, random_size, lazy=lazy + ) - @LazyTransform.lazy_evaluation.setter # type: ignore - def lazy_evaluation(self, value: bool) -> None: - self._lazy_evaluation = value - self.cropper.lazy_evaluation = value + @LazyTransform.lazy.setter # type: ignore + def lazy(self, value: bool) -> None: + self._lazy = value + self.cropper.lazy = value def randomize(self, data: Any | None = None) -> None: self.sub_seed = self.R.randint(MAX_SEED, dtype="uint32") - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> list[dict[Hashable, torch.Tensor]]: + def __call__( + self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None + ) -> list[dict[Hashable, torch.Tensor]]: ret: list[dict[Hashable, torch.Tensor]] = [dict(data) for _ in range(self.cropper.num_samples)] # deep copy all the unmodified data for i in range(self.cropper.num_samples): @@ -611,9 +697,11 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> list[dict[Hashable, # for each key we reset the random state to ensure crops are the same self.randomize() + + lazy_ = self.lazy if lazy is None else lazy for key in self.key_iterator(dict(data)): self.cropper.set_random_state(seed=self.sub_seed) - for i, im in enumerate(self.cropper(data[key])): + for i, im in enumerate(self.cropper(data[key], lazy=lazy_)): ret[i][key] = im return ret @@ -629,6 +717,9 @@ class CropForegroundd(Cropd): - Select label > 0 in the third channel of a One-Hot label field as the foreground to crop all `keys` fields. Users can define arbitrary function to select expected foreground from the whole source image or specified channels. And it can also add margin to every dim of the bounding box of foreground object. + + This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic` + for more information. """ def __init__( @@ -644,6 +735,7 @@ def __init__( start_coord_key: str = "foreground_start_coord", end_coord_key: str = "foreground_end_coord", allow_missing_keys: bool = False, + lazy: bool = False, **pad_kwargs, ) -> None: """ @@ -670,6 +762,7 @@ def __init__( start_coord_key: key to record the start coordinate of spatial bounding box for foreground. end_coord_key: key to record the end coordinate of spatial bounding box for foreground. allow_missing_keys: don't raise exception if key is missing. + lazy: a flag to indicate whether this transform should execute lazily or not. Defaults to False. pad_kwargs: other arguments for the `np.pad` or `torch.pad` function. note that `np.pad` treats channel dimension as the first dimension. 
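``CropForegroundd``, whose updated constructor body follows, still has to compute the bounding box from the actual content of ``source_key`` (hence ``requires_current_data`` returning ``True`` in a later hunk), but the resulting crop and pad are recorded lazily. A sketch with an illustrative foreground block::

    import torch
    from monai.data import MetaTensor
    from monai.transforms import CropForegroundd

    data = {
        "img": MetaTensor(torch.zeros(1, 32, 32)),
        "seg": MetaTensor(torch.zeros(1, 32, 32)),
    }
    data["img"][0, 8:24, 8:24] = 1.0  # a 16x16 foreground block

    crop = CropForegroundd(keys=["img", "seg"], source_key="img", margin=2, lazy=True)
    out = crop(data)
    print(out["img"].peek_pending_shape())  # box plus margin, e.g. (20, 20)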
@@ -683,17 +776,22 @@ def __init__( margin=margin, allow_smaller=allow_smaller, k_divisible=k_divisible, + lazy=lazy, **pad_kwargs, ) - super().__init__(keys, cropper=cropper, allow_missing_keys=allow_missing_keys) + super().__init__(keys, cropper=cropper, allow_missing_keys=allow_missing_keys, lazy=lazy) self.mode = ensure_tuple_rep(mode, len(self.keys)) - @LazyTransform.lazy_evaluation.setter # type: ignore - def lazy_evaluation(self, value: bool) -> None: - self._lazy_evaluation = value - self.cropper.lazy_evaluation = value + @LazyTransform.lazy.setter # type: ignore + def lazy(self, value: bool) -> None: + self._lazy = value + self.cropper.lazy = value + + @property + def requires_current_data(self): + return True - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: + def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None) -> dict[Hashable, torch.Tensor]: d = dict(data) self.cropper: CropForeground box_start, box_end = self.cropper.compute_bounding_box(img=d[self.source_key]) @@ -701,8 +799,10 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torc d[self.start_coord_key] = box_start # type: ignore if self.end_coord_key is not None: d[self.end_coord_key] = box_end # type: ignore + + lazy_ = self.lazy if lazy is None else lazy for key, m in self.key_iterator(d, self.mode): - d[key] = self.cropper.crop_pad(img=d[key], box_start=box_start, box_end=box_end, mode=m) + d[key] = self.cropper.crop_pad(img=d[key], box_start=box_start, box_end=box_end, mode=m, lazy=lazy_) return d @@ -710,6 +810,9 @@ class RandWeightedCropd(Randomizable, MapTransform, LazyTransform, MultiSampleTr """ Samples a list of `num_samples` image patches according to the provided `weight_map`. + This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic` + for more information. + Args: keys: keys of the corresponding items to be transformed. See also: :py:class:`monai.transforms.compose.MapTransform` @@ -719,6 +822,7 @@ class RandWeightedCropd(Randomizable, MapTransform, LazyTransform, MultiSampleTr If its components have non-positive values, the corresponding size of `img` will be used. num_samples: number of samples (image patches) to take in the returned list. allow_missing_keys: don't raise exception if key is missing. + lazy: a flag to indicate whether this transform should execute lazily or not. Defaults to False. 
See Also: :py:class:`monai.transforms.RandWeightedCrop` @@ -733,10 +837,12 @@ def __init__( spatial_size: Sequence[int] | int, num_samples: int = 1, allow_missing_keys: bool = False, + lazy: bool = False, ): MapTransform.__init__(self, keys, allow_missing_keys) + LazyTransform.__init__(self, lazy) self.w_key = w_key - self.cropper = RandWeightedCrop(spatial_size, num_samples) + self.cropper = RandWeightedCrop(spatial_size, num_samples, lazy=lazy) def set_random_state( self, seed: int | None = None, state: np.random.RandomState | None = None @@ -748,12 +854,14 @@ def set_random_state( def randomize(self, weight_map: NdarrayOrTensor) -> None: self.cropper.randomize(weight_map) - @LazyTransform.lazy_evaluation.setter # type: ignore - def lazy_evaluation(self, value: bool) -> None: - self._lazy_evaluation = value - self.cropper.lazy_evaluation = value + @LazyTransform.lazy.setter # type: ignore + def lazy(self, value: bool) -> None: + self._lazy = value + self.cropper.lazy = value - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> list[dict[Hashable, torch.Tensor]]: + def __call__( + self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None + ) -> list[dict[Hashable, torch.Tensor]]: # output starts as empty list of dictionaries ret: list = [dict(data) for _ in range(self.cropper.num_samples)] # deep copy all the unmodified data @@ -762,8 +870,9 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> list[dict[Hashable, ret[i][key] = deepcopy(data[key]) self.randomize(weight_map=data[self.w_key]) + lazy_ = self.lazy if lazy is None else lazy for key in self.key_iterator(data): - for i, im in enumerate(self.cropper(data[key], randomize=False)): + for i, im in enumerate(self.cropper(data[key], randomize=False, lazy=lazy_)): ret[i][key] = im return ret @@ -783,6 +892,9 @@ class RandCropByPosNegLabeld(Randomizable, MapTransform, LazyTransform, MultiSam And if the crop ROI is partly out of the image, will automatically adjust the crop center to ensure the valid crop ROI. + This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic` + for more information. + Args: keys: keys of the corresponding items to be transformed. See also: :py:class:`monai.transforms.compose.MapTransform` @@ -813,6 +925,7 @@ class RandCropByPosNegLabeld(Randomizable, MapTransform, LazyTransform, MultiSam the requested ROI in any dimension. If `True`, any smaller dimensions will be set to match the cropped size (i.e., no cropping in that dimension). allow_missing_keys: don't raise exception if key is missing. + lazy: a flag to indicate whether this transform should execute lazily or not. Defaults to False. Raises: ValueError: When ``pos`` or ``neg`` are negative. 
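Even in lazy mode, ``RandCropByPosNegLabeld`` must read the real label values to choose positive/negative crop centres, which is why the following hunk adds a ``requires_current_data`` property returning ``True``; only the per-sample crops themselves are deferred. A sketch with an illustrative sparse label::

    import torch
    from monai.data import MetaTensor
    from monai.transforms import RandCropByPosNegLabeld

    data = {
        "img": MetaTensor(torch.rand(1, 48, 48, 48)),
        "label": MetaTensor((torch.rand(1, 48, 48, 48) > 0.95).float()),
    }
    cropper = RandCropByPosNegLabeld(
        keys=["img", "label"], label_key="label",
        spatial_size=(24, 24, 24), pos=1, neg=1, num_samples=2, lazy=True,
    )
    samples = cropper(data)  # a list of two dictionaries
    print(samples[0]["img"].peek_pending_shape())  # (24, 24, 24), crop pending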
@@ -836,8 +949,10 @@ def __init__( bg_indices_key: str | None = None, allow_smaller: bool = False, allow_missing_keys: bool = False, + lazy: bool = False, ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) + LazyTransform.__init__(self, lazy) self.label_key = label_key self.image_key = image_key self.fg_indices_key = fg_indices_key @@ -849,6 +964,7 @@ def __init__( num_samples=num_samples, image_threshold=image_threshold, allow_smaller=allow_smaller, + lazy=lazy, ) def set_random_state( @@ -867,12 +983,18 @@ def randomize( ) -> None: self.cropper.randomize(label=label, fg_indices=fg_indices, bg_indices=bg_indices, image=image) - @LazyTransform.lazy_evaluation.setter # type: ignore - def lazy_evaluation(self, value: bool) -> None: - self._lazy_evaluation = value - self.cropper.lazy_evaluation = value + @LazyTransform.lazy.setter # type: ignore + def lazy(self, value: bool) -> None: + self._lazy = value + self.cropper.lazy = value - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> list[dict[Hashable, torch.Tensor]]: + @property + def requires_current_data(self): + return True + + def __call__( + self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None + ) -> list[dict[Hashable, torch.Tensor]]: d = dict(data) fg_indices = d.pop(self.fg_indices_key, None) bg_indices = d.pop(self.bg_indices_key, None) @@ -886,8 +1008,9 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> list[dict[Hashable, for key in set(d.keys()).difference(set(self.keys)): ret[i][key] = deepcopy(d[key]) + lazy_ = self.lazy if lazy is None else lazy for key in self.key_iterator(d): - for i, im in enumerate(self.cropper(d[key], randomize=False)): + for i, im in enumerate(self.cropper(d[key], randomize=False, lazy=lazy_)): ret[i][key] = im return ret @@ -936,6 +1059,9 @@ class RandCropByLabelClassesd(Randomizable, MapTransform, LazyTransform, MultiSa And if the crop ROI is partly out of the image, will automatically adjust the crop center to ensure the valid crop ROI. + This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic` + for more information. + Args: keys: keys of the corresponding items to be transformed. See also: :py:class:`monai.transforms.compose.MapTransform` @@ -964,7 +1090,7 @@ class RandCropByLabelClassesd(Randomizable, MapTransform, LazyTransform, MultiSa warn: if `True` prints a warning if a class is not present in the label. max_samples_per_class: maximum length of indices in each class to reduce memory consumption. Default is None, no subsampling. - + lazy: a flag to indicate whether this transform should execute lazily or not. Defaults to False. 
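With ``lazy=True``, the sampler above still inspects the label to choose crop centers, but each resulting crop is only recorded as a pending operation. A brief sketch, assuming metadata tracking is enabled (the default)::

    import torch
    from monai.data import MetaTensor
    from monai.transforms import RandCropByPosNegLabeld

    data = {
        "img": MetaTensor(torch.rand(1, 32, 32)),
        "seg": MetaTensor((torch.rand(1, 32, 32) > 0.8).float()),
    }
    xform = RandCropByPosNegLabeld(
        keys=["img", "seg"], label_key="seg", spatial_size=(16, 16), num_samples=2, lazy=True
    )
    samples = xform(data)  # a list of two dicts, one per sample
    print(len(samples[0]["img"].pending_operations))  # 1: the crop is queued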
""" backend = RandCropByLabelClasses.backend @@ -984,8 +1110,10 @@ def __init__( allow_missing_keys: bool = False, warn: bool = True, max_samples_per_class: int | None = None, + lazy: bool = False, ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) + LazyTransform.__init__(self, lazy) self.label_key = label_key self.image_key = image_key self.indices_key = indices_key @@ -998,6 +1126,7 @@ def __init__( allow_smaller=allow_smaller, warn=warn, max_samples_per_class=max_samples_per_class, + lazy=lazy, ) def set_random_state( @@ -1012,12 +1141,16 @@ def randomize( ) -> None: self.cropper.randomize(label=label, indices=indices, image=image) - @LazyTransform.lazy_evaluation.setter # type: ignore - def lazy_evaluation(self, value: bool) -> None: - self._lazy_evaluation = value - self.cropper.lazy_evaluation = value + @LazyTransform.lazy.setter # type: ignore + def lazy(self, value: bool) -> None: + self._lazy = value + self.cropper.lazy = value + + @property + def requires_current_data(self): + return True - def __call__(self, data: Mapping[Hashable, Any]) -> list[dict[Hashable, torch.Tensor]]: + def __call__(self, data: Mapping[Hashable, Any], lazy: bool | None = None) -> list[dict[Hashable, torch.Tensor]]: d = dict(data) self.randomize(d.get(self.label_key), d.pop(self.indices_key, None), d.get(self.image_key)) # type: ignore @@ -1028,8 +1161,9 @@ def __call__(self, data: Mapping[Hashable, Any]) -> list[dict[Hashable, torch.Te for key in set(d.keys()).difference(set(self.keys)): ret[i][key] = deepcopy(d[key]) + lazy_ = self.lazy if lazy is None else lazy for key in self.key_iterator(d): - for i, im in enumerate(self.cropper(d[key], randomize=False)): + for i, im in enumerate(self.cropper(d[key], randomize=False, lazy=lazy_)): ret[i][key] = im return ret @@ -1038,6 +1172,9 @@ class ResizeWithPadOrCropd(Padd): """ Dictionary-based wrapper of :py:class:`monai.transforms.ResizeWithPadOrCrop`. + This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic` + for more information. + Args: keys: keys of the corresponding items to be transformed. See also: monai.transforms.MapTransform @@ -1053,6 +1190,7 @@ class ResizeWithPadOrCropd(Padd): allow_missing_keys: don't raise exception if key is missing. method: {``"symmetric"``, ``"end"``} Pad image symmetrically on every side or only pad at the end sides. Defaults to ``"symmetric"``. + lazy: a flag to indicate whether this transform should execute lazily or not. Defaults to False. pad_kwargs: other arguments for the `np.pad` or `torch.pad` function. note that `np.pad` treats channel dimension as the first dimension. 
@@ -1065,10 +1203,13 @@ def __init__( mode: SequenceStr = PytorchPadMode.CONSTANT, allow_missing_keys: bool = False, method: str = Method.SYMMETRIC, + lazy: bool = False, **pad_kwargs, ) -> None: - padcropper = ResizeWithPadOrCrop(spatial_size=spatial_size, method=method, **pad_kwargs) - super().__init__(keys, padder=padcropper, mode=mode, allow_missing_keys=allow_missing_keys) # type: ignore + padcropper = ResizeWithPadOrCrop(spatial_size=spatial_size, method=method, **pad_kwargs, lazy=lazy) + super().__init__( + keys, padder=padcropper, mode=mode, allow_missing_keys=allow_missing_keys, lazy=lazy # type: ignore + ) class BoundingRectd(MapTransform): diff --git a/monai/transforms/croppad/functional.py b/monai/transforms/croppad/functional.py index e694edb737..783635e467 100644 --- a/monai/transforms/croppad/functional.py +++ b/monai/transforms/croppad/functional.py @@ -27,14 +27,7 @@ from monai.data.utils import to_affine_nd from monai.transforms.inverse import TraceableTransform from monai.transforms.utils import convert_pad_mode, create_translate -from monai.utils import ( - PytorchPadMode, - TraceKeys, - convert_to_dst_type, - convert_to_numpy, - convert_to_tensor, - ensure_tuple, -) +from monai.utils import PytorchPadMode, convert_to_dst_type, convert_to_numpy, convert_to_tensor, ensure_tuple __all__ = ["pad_nd", "pad_func", "crop_func", "crop_or_pad_nd"] @@ -161,11 +154,12 @@ def pad_func( to_pad: tuple[tuple[int, int]], transform_info: dict, mode: str = PytorchPadMode.CONSTANT, + lazy: bool = False, **kwargs, ) -> torch.Tensor: """ Functional implementation of padding a MetaTensor. This function operates eagerly or lazily according - to ``transform_info[TraceKeys.LAZY_EVALUATION]`` (default ``False``). + to ``lazy`` (default ``False``). `torch.nn.functional.pad` is used unless the mode or kwargs are not available in torch, in which case `np.pad` will be used. @@ -181,6 +175,8 @@ def pad_func( One of the listed string values or a user supplied function. Defaults to ``"constant"``. See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html + lazy: a flag indicating whether the operation should be performed in a lazy fashion or not. + transform_info: a dictionary with the relevant information pertaining to an applied transform. kwargs: other arguments for the `np.pad` or `torch.pad` function. note that `np.pad` treats channel dimension as the first dimension. """ @@ -205,24 +201,25 @@ def pad_func( extra_info=extra_info, orig_size=img_size, transform_info=transform_info, - lazy_evaluation=transform_info.get(TraceKeys.LAZY_EVALUATION, False), + lazy=lazy, ) out = convert_to_tensor(img.as_tensor() if isinstance(img, MetaTensor) else img, track_meta=get_track_meta()) - if transform_info.get(TraceKeys.LAZY_EVALUATION, False): + if lazy: return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info # type: ignore out = pad_nd(out, to_pad_list, mode, **kwargs) if do_pad else out out = convert_to_tensor(out, track_meta=get_track_meta()) return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out # type: ignore -def crop_func(img: torch.Tensor, slices: tuple[slice, ...], transform_info: dict) -> torch.Tensor: +def crop_func(img: torch.Tensor, slices: tuple[slice, ...], lazy: bool, transform_info: dict) -> torch.Tensor: """ Functional implementation of cropping a MetaTensor. 
This function operates eagerly or lazily according
-    to ``transform_info[TraceKeys.LAZY_EVALUATION]`` (default ``False``).
+    to ``lazy`` (default ``False``).

    Args:
        img: data to be transformed, assuming `img` is channel-first and cropping doesn't apply to the channel dim.
        slices: the crop slices computed based on specified `center & size` or `start & end` or `slices`.
+        lazy: a flag indicating whether the operation should be performed in a lazy fashion or not.
        transform_info: a dictionary with the relevant information pertaining to an applied transform.
    """
    img_size = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:]
@@ -243,10 +240,10 @@ def crop_func(img: torch.Tensor, slices: tuple[slice, ...], transform_info: dict
        extra_info=extra_info,
        orig_size=img_size,
        transform_info=transform_info,
-        lazy_evaluation=transform_info.get(TraceKeys.LAZY_EVALUATION, False),
+        lazy=lazy,
    )
    out = convert_to_tensor(img.as_tensor() if isinstance(img, MetaTensor) else img, track_meta=get_track_meta())
-    if transform_info.get(TraceKeys.LAZY_EVALUATION, False):
+    if lazy:
        return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info  # type: ignore
    out = out[slices]
    return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out  # type: ignore
diff --git a/monai/transforms/inverse.py b/monai/transforms/inverse.py
index 96b8e6b782..ade9034563 100644
--- a/monai/transforms/inverse.py
+++ b/monai/transforms/inverse.py
@@ -23,8 +23,17 @@
 from monai.data.meta_obj import MetaObj, get_track_meta
 from monai.data.meta_tensor import MetaTensor
 from monai.data.utils import to_affine_nd
-from monai.transforms.transform import LazyTransform, Transform
-from monai.utils import LazyAttr, MetaKeys, TraceKeys, convert_to_dst_type, convert_to_numpy, convert_to_tensor
+from monai.transforms.traits import InvertibleTrait
+from monai.transforms.transform import Transform
+from monai.utils import (
+    LazyAttr,
+    MetaKeys,
+    TraceKeys,
+    TraceStatusKeys,
+    convert_to_dst_type,
+    convert_to_numpy,
+    convert_to_tensor,
+)

__all__ = ["TraceableTransform", "InvertibleTransform"]

@@ -77,13 +86,7 @@ def trace_key(key: Hashable = None):

    @staticmethod
    def transform_info_keys():
        """The keys to store necessary info of an applied transform."""
-        return (
-            TraceKeys.CLASS_NAME,
-            TraceKeys.ID,
-            TraceKeys.TRACING,
-            TraceKeys.LAZY_EVALUATION,
-            TraceKeys.DO_TRANSFORM,
-        )
+        return (TraceKeys.CLASS_NAME, TraceKeys.ID, TraceKeys.TRACING, TraceKeys.DO_TRANSFORM)

    def get_transform_info(self) -> dict:
        """
@@ -93,7 +96,6 @@ def get_transform_info(self) -> dict:
            self.__class__.__name__,
            id(self),
            self.tracing,
-            self.lazy_evaluation if isinstance(self, LazyTransform) else False,
            self._do_transform if hasattr(self, "_do_transform") else True,
        )
        return dict(zip(self.transform_info_keys(), vals))
@@ -109,8 +111,8 @@ def push_transform(self, data, *args, **kwargs):
        set ``replace=True`` (default False) to rewrite the last transform info in
        applied_operations/pending_operations based on ``self.get_transform_info()``.
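The signature changes above make the eager/lazy decision an explicit argument at the call site rather than a value buried in ``transform_info``. A toy illustration of the branch shared by ``pad_func`` and ``crop_func`` (not MONAI code; the list of dicts is a stand-in for the real pending-operation metadata)::

    import torch

    def crop_like(img: torch.Tensor, slices: tuple, pending: list, lazy: bool = False):
        pending.append({"op": "crop", "slices": slices})  # the op is always recorded
        if lazy:
            return img       # lazy: return the input unchanged, the crop stays pending
        return img[slices]   # eager: index immediately

    t = torch.arange(16).reshape(4, 4)
    ops: list = []
    print(crop_like(t, (slice(0, 2), slice(0, 2)), ops, lazy=True).shape)   # (4, 4)
    print(crop_like(t, (slice(0, 2), slice(0, 2)), ops, lazy=False).shape)  # (2, 2)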
""" + lazy_eval = kwargs.get("lazy", False) transform_info = self.get_transform_info() - lazy_eval = transform_info.get(TraceKeys.LAZY_EVALUATION, False) do_transform = transform_info.get(TraceKeys.DO_TRANSFORM, True) kwargs = kwargs or {} replace = kwargs.pop("replace", False) # whether to rewrite the most recently pushed transform info @@ -125,9 +127,9 @@ def push_transform(self, data, *args, **kwargs): xform.update(transform_info) else: # lazy, replace=True, do_transform=False xform, extra = transform_info, {} - meta_obj = self.push_transform(data, transform_info=xform, lazy_evaluation=True, extra_info=extra) + meta_obj = self.push_transform(data, transform_info=xform, lazy=True, extra_info=extra) return data.copy_meta_from(meta_obj) - kwargs["lazy_evaluation"] = lazy_eval + kwargs["lazy"] = lazy_eval if "transform_info" in kwargs and isinstance(kwargs["transform_info"], dict): kwargs["transform_info"].update(transform_info) else: @@ -145,7 +147,7 @@ def track_transform_meta( extra_info: dict | None = None, orig_size: tuple | None = None, transform_info=None, - lazy_evaluation=False, + lazy=False, ): """ Update a stack of applied/pending transforms metadata of ``data``. @@ -163,7 +165,7 @@ def track_transform_meta( orig_size: sometimes during the inverse it is useful to know what the size of the original image was, in which case it can be supplied here. transform_info: info from self.get_transform_info(). - lazy_evaluation: whether to push the transform to pending_operations or applied_operations. + lazy: whether to push the transform to pending_operations or applied_operations. Returns: @@ -176,10 +178,10 @@ def track_transform_meta( if isinstance(data_t, MetaTensor): out_obj.copy_meta_from(data_t, keys=out_obj.__dict__.keys()) - if lazy_evaluation and (not get_track_meta()): + if lazy and (not get_track_meta()): warnings.warn("metadata is not tracked, please call 'set_track_meta(True)' if doing lazy evaluation.") - if not lazy_evaluation and affine is not None and isinstance(data_t, MetaTensor): + if not lazy and affine is not None and isinstance(data_t, MetaTensor): # not lazy evaluation, directly update the metatensor affine (don't push to the stack) orig_affine = data_t.peek_pending_affine() orig_affine = convert_to_dst_type(orig_affine, affine, dtype=torch.float64)[0] @@ -202,6 +204,10 @@ def track_transform_meta( info[TraceKeys.ORIG_SIZE] = data_t.peek_pending_shape() elif hasattr(data_t, "shape"): info[TraceKeys.ORIG_SIZE] = data_t.shape[1:] + + # add lazy status to the transform info + info[TraceKeys.LAZY] = lazy + # include extra_info if extra_info is not None: extra_info.pop(LazyAttr.SHAPE, None) @@ -209,7 +215,7 @@ def track_transform_meta( info[TraceKeys.EXTRA_INFO] = extra_info # push the transform info to the applied_operation or pending_operation stack - if lazy_evaluation: + if lazy: if sp_size is None: if LazyAttr.SHAPE not in info: info[LazyAttr.SHAPE] = info.get(TraceKeys.ORIG_SIZE, []) @@ -227,17 +233,18 @@ def track_transform_meta( if out_obj.pending_operations: transform_name = info.get(TraceKeys.CLASS_NAME, "") if isinstance(info, dict) else "" msg = ( - f"Applying transform {transform_name} to a MetaTensor with pending operations " - "is not supported (as this eventually changes the ordering of applied_operations when the pending " - f"operations are executed). Please clear the pending operations before transform {transform_name}." - f"\nPending operations: {[x.get(TraceKeys.CLASS_NAME) for x in out_obj.pending_operations]}." 
+ f"Transform {transform_name} has been applied to a MetaTensor with pending operations: " + f"{[x.get(TraceKeys.CLASS_NAME) for x in out_obj.pending_operations]}" ) + if key is not None: + msg += f" for key {key}" + pend = out_obj.pending_operations[-1] - if not isinstance(pend.get(TraceKeys.EXTRA_INFO), dict): - pend[TraceKeys.EXTRA_INFO] = dict(pend.get(TraceKeys.EXTRA_INFO, {})) - if not isinstance(info.get(TraceKeys.EXTRA_INFO), dict): - info[TraceKeys.EXTRA_INFO] = dict(info.get(TraceKeys.EXTRA_INFO, {})) - info[TraceKeys.EXTRA_INFO]["warn"] = pend[TraceKeys.EXTRA_INFO]["warn"] = msg + statuses = pend.get(TraceKeys.STATUSES, dict()) + messages = statuses.get(TraceStatusKeys.PENDING_DURING_APPLY, list()) + messages.append(msg) + statuses[TraceStatusKeys.PENDING_DURING_APPLY] = messages + info[TraceKeys.STATUSES] = statuses out_obj.push_applied_operation(info) if isinstance(data, Mapping): if not isinstance(data, dict): @@ -329,7 +336,7 @@ def trace_transform(self, to_trace: bool): self.tracing = prev -class InvertibleTransform(TraceableTransform): +class InvertibleTransform(TraceableTransform, InvertibleTrait): """Classes for invertible transforms. This class exists so that an ``invert`` method can be implemented. This allows, for diff --git a/monai/transforms/lazy/__init__.py b/monai/transforms/lazy/__init__.py index 02349dd0f2..1e97f89407 100644 --- a/monai/transforms/lazy/__init__.py +++ b/monai/transforms/lazy/__init__.py @@ -8,8 +8,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from __future__ import annotations - -from .functional import apply_transforms -from .utils import combine_transforms, resample diff --git a/monai/transforms/lazy/array.py b/monai/transforms/lazy/array.py new file mode 100644 index 0000000000..6e6797c6e1 --- /dev/null +++ b/monai/transforms/lazy/array.py @@ -0,0 +1,32 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from monai.transforms.traits import InvertibleTrait + +__all__ = ["ApplyPending"] + + +class ApplyPending(InvertibleTrait): + """ + ApplyPending can be inserted into a pipeline that is being executed lazily in order to ensure + resampling happens before the next transform. It doesn't do anything itself, but its presence + causes the pipeline to be executed as ApplyPending doesn't implement ```LazyTrait``. + + See ``Compose`` for a detailed explanation of the lazy resampling feature. + """ + + def __call__(self, data): + return data + + def inverse(self, data): + return data diff --git a/monai/transforms/lazy/dictionary.py b/monai/transforms/lazy/dictionary.py new file mode 100644 index 0000000000..384fc9bf98 --- /dev/null +++ b/monai/transforms/lazy/dictionary.py @@ -0,0 +1,44 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from __future__ import annotations + +from monai.config import KeysCollection +from monai.transforms.traits import InvertibleTrait +from monai.transforms.transform import MapTransform + +__all__ = ["ApplyPendingd", "ApplyPendingD", "ApplyPendingDict"] + + +class ApplyPendingd(InvertibleTrait, MapTransform): + """ + ApplyPendingd can be inserted into a pipeline that is being executed lazily in order + to ensure resampling happens before the next transform. It doesn't do anything itself, + but its presence causes the pipeline to be executed as it doesn't implement ``LazyTrait`` + + See ``Compose`` for a detailed explanation of the lazy resampling feature. + + Args: + keys: the keys for tensors that should have their pending transforms executed + """ + + def __init__(self, keys: KeysCollection): + super().__init__(keys) + + def __call__(self, data): + return data + + def inverse(self, data): + return data + + +ApplyPendingD = ApplyPendingDict = ApplyPendingd diff --git a/monai/transforms/lazy/functional.py b/monai/transforms/lazy/functional.py index 22c74cef8a..6b95027832 100644 --- a/monai/transforms/lazy/functional.py +++ b/monai/transforms/lazy/functional.py @@ -11,10 +11,12 @@ from __future__ import annotations -from typing import Any +from typing import Any, Mapping, Sequence import torch +from monai.apps.utils import get_logger +from monai.config import NdarrayOrTensor from monai.data.meta_tensor import MetaTensor from monai.data.utils import to_affine_nd from monai.transforms.lazy.utils import ( @@ -24,19 +26,196 @@ kwargs_from_pending, resample, ) +from monai.transforms.traits import LazyTrait +from monai.transforms.transform import MapTransform from monai.utils import LazyAttr, look_up_option -__all__ = ["apply_transforms"] +__all__ = ["apply_pending_transforms", "apply_pending_transforms_in_order", "apply_pending"] __override_keywords = {"mode", "padding_mode", "dtype", "align_corners", "resample_mode", "device"} -def apply_transforms( - data: torch.Tensor | MetaTensor, pending: list | None = None, overrides: dict | None = None, **kwargs: Any +def _log_pending_info( + transform: Any, + data: Any, + activity: str, + *, + lazy: bool | None = None, + key: str | None = None, + logger_name: bool | str = False, ): + if logger_name is False: + return + logger_name = logger_name if isinstance(logger_name, str) else "apply_pending_transforms" + logger = get_logger(logger_name) + + tcname = type(transform).__name__ + if isinstance(transform, LazyTrait): + tlazy = f", transform.lazy: {transform.lazy}" + if lazy is not None and lazy != transform.lazy: + tlazy += " (overridden)" + else: + tlazy = ", transform is not lazy" + + msg = f"{activity} - lazy: {lazy}, {{key_msg}}pending: {{pcount}}, upcoming '{tcname}'{tlazy}" + + if isinstance(transform, MapTransform): + transform_keys = transform.keys if key is None else (key,) + for k in transform_keys: + if k in data: + pcount = len(data[k].pending_operations) if isinstance(data[k], MetaTensor) else 0 + logger.info(msg.format(pcount=pcount, key_msg=f"key: '{k}', ")) + else: + pcount = len(data.pending_operations) if isinstance(data, 
+        logger.info(msg.format(pcount=pcount, key_msg="" if key is None else f"key: '{key}', "))
+
+
+def _log_applied_info(data: Any, key=None, logger_name: bool | str = False):
+    if logger_name is False:
+        return
+    logger_name = logger_name if isinstance(logger_name, str) else "apply_pending_transforms"
+    logger = get_logger(logger_name)
+
+    key_str = "" if key is None else f"key: '{key}', "
+    logger.info(f"Pending transforms applied: {key_str}applied_operations: {len(data.applied_operations)}")
+
+
+def apply_pending_transforms(
+    data: NdarrayOrTensor | Sequence[Any | NdarrayOrTensor] | Mapping[Any, NdarrayOrTensor],
+    keys: tuple | None,
+    overrides: dict | None = None,
+    logger_name: bool | str = False,
+):
+    """
+    apply_pending_transforms is called with either a tensor or a dictionary, some entries of which contain
+    tensors.
+
+    When operating on a dictionary of tensors, the 'keys' parameter determines what tensors should be checked.
+    If 'keys' is not set, all keys of 'data' are considered.
+
+    This method optionally takes a set of overrides that can be used to change specific parameters on the
+    transform pipeline. See ``Compose`` for more details. This method takes a logger_name that can be used
+    to override the default logger, to provide telemetry during the execution of pending transforms.
+
+    This method is intended primarily for use by ``execute_compose`` and other methods that handle the
+    underlying execution of transform pipelines. You should not need to use it in the general case, unless
+    you are developing functionality to perform such operations.
+
+    Args:
+        data: a ``torch.Tensor`` or ``MetaTensor``, or dictionary of tensors.
+        keys: an optional tuple of keys that filters the keys on 'data' if it is a dict
+        overrides: An optional dictionary that specifies parameters that can be used to override transform
+            arguments when they are called. When 'data' is a dict, this dictionary should contain a dictionary
+            of overrides for each key that needs them
+        logger_name: An optional name for a logger to be used when applying pending transforms. If False,
+            logging is suppressed.
+    Returns:
+        an object of the same type as data if pending transforms were applied, or 'data' if they were not
+    """
+    if isinstance(data, list):
+        return [apply_pending_transforms(d, keys, overrides, logger_name) for d in data]
+    if isinstance(data, tuple):
+        return tuple(apply_pending_transforms(d, keys, overrides, logger_name) for d in data)
+
+    if isinstance(data, dict):
+        # get the keys from 'data' for metatensors with pending operations. If 'keys' is set, select
+        # only data keys that are in 'keys'
+        active_keys = [k for k in data.keys() if keys is None or k in keys]
+        keys_to_update = [k for k in active_keys if isinstance(data[k], MetaTensor) and data[k].has_pending_operations]
+
+        if len(keys_to_update) > 0:
+            rdata = dict(data)
+
+            for k in keys_to_update:
+                overrides_ = None if overrides is None else overrides.get(k, None)
+                rdata[k], _ = apply_pending(data[k], overrides=overrides_)
+                _log_applied_info(rdata[k], key=k, logger_name=logger_name)
+
+            return rdata
+    else:
+        if isinstance(data, MetaTensor) and data.has_pending_operations:
+            rdata, _ = apply_pending(data, overrides=overrides)
+            _log_applied_info(rdata, logger_name=logger_name)
+            return rdata
+
+    return data
+
+
+def apply_pending_transforms_in_order(
+    transform, data, lazy: bool | None = None, overrides: dict | None = None, logger_name: bool | str = False
+):
+    """
+    This method causes "in order" processing of pending transforms to occur.
+    "In order" processing of pending transforms ensures that all pending transforms have been applied to the
+    tensor before a non-lazy transform (or lazy transform that is executing non-lazily) is carried out.
+    It ensures that no operations will be added to a metatensor's applied_operations while there are outstanding
+    pending_operations. Note that there is only one mechanism for executing lazy resampling at present but this
+    is expected to change in future releases.
+
+    Evaluation of pending transforms is performed under the following circumstances:
+    * If the transform is a lazy transform and:
+      * the transform checks data as part of its execution, or
+      * the transform is not executing lazily
+    * If the transform is an ApplyPending[d] transform
+    * If the transform is not a lazy transform
+
+    This method is designed to be used only in the context of implementing lazy resampling functionality. In general
+    you should not need to interact with or use this method directly, and its API may change without warning between
+    releases. See the :ref:`Lazy Resampling topic` for more information about lazy resampling.
+
+    Args:
+        transform: a transform that should be evaluated to determine whether pending transforms should be applied
+        data: a tensor / MetaTensor, or dictionary containing tensors / MetaTensors whose pending transforms may
+            need to be applied
+        lazy: The lazy mode that is being applied (this can be False, True or None)
+        overrides: An optional dictionary containing overrides to be applied to the pending transforms when they
+            are lazily executed. If data is a dict, it should contain a dictionary of overrides for each key that
+            needs them
+        logger_name: An optional name for a logger to be used when applying pending transforms. If False,
+            logging is suppressed.
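Per the rules above, an ``ApplyPending[d]`` instance always triggers evaluation of whatever has accumulated so far, which makes it a convenient way to force a resample point mid-pipeline. A short sketch, assuming the import path shown below and MONAI 1.2's ``Compose(..., lazy=True)``::

    import torch
    from monai.data import MetaTensor
    from monai.transforms import Compose, Flipd, Rotate90d
    from monai.transforms.lazy.dictionary import ApplyPendingd

    data = {"img": MetaTensor(torch.rand(1, 16, 16))}
    pipeline = Compose(
        [
            Flipd(keys="img", spatial_axis=0),
            ApplyPendingd(keys="img"),  # not a LazyTrait: pending ops are applied here
            Rotate90d(keys="img"),
        ],
        lazy=True,
    )
    out = pipeline(data)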
+    Returns:
+        an object of the same type as data if pending transforms were applied, or 'data' if they were not
+
+    """
+    from monai.transforms.lazy.dictionary import ApplyPendingd
+
+    must_apply_pending = True
+    keys = transform.keys if isinstance(transform, ApplyPendingd) else None
+    if isinstance(transform, LazyTrait) and not transform.requires_current_data:
+        must_apply_pending = not (transform.lazy if lazy is None else lazy)
+
+    if must_apply_pending is True:
+        _log_pending_info(transform, data, "Apply pending transforms", lazy=lazy, logger_name=logger_name)
+        return apply_pending_transforms(data, keys, overrides, logger_name)
+
+    _log_pending_info(transform, data, "Accumulate pending transforms", lazy=lazy, logger_name=logger_name)
+    return data
+
+
+def apply_pending(data: torch.Tensor | MetaTensor, pending: list | None = None, overrides: dict | None = None):
     """
     This method applies pending transforms to `data` tensors.
-    Currently, only 2d and 3d input are supported.
+    Currently, only 2d and 3d inputs are supported.
+
+    This method is designed to be called by ``apply_pending_transforms`` and other methods / classes
+    that are part of the implementation of lazy resampling. In general, you should not need to call
+    this method unless you are directly developing custom lazy execution strategies.
+
+    It works by calculating the overall effect of the accumulated pending transforms. When it runs
+    out of pending transforms or when it finds incompatibilities between the accumulated pending
+    transform and the next pending transform, it then applies the accumulated transform in a call to
+    ``resample``.
+
+    Pending transforms are incompatible with each other if one or more of the arguments in the pending
+    transforms differ. These are parameters such as 'mode', 'padding_mode', 'dtype' and so forth. If
+    a pending transform doesn't have a given parameter, it is considered compatible with the
+    accumulated transform. If a subsequent transform has a parameter that is incompatible with
+    the accumulated transform (e.g. 'mode' of 'bilinear' vs. 'mode' of 'nearest'), an intermediate
+    resample will be performed and the accumulated transform reset to its starting state.
+
+    After resampling, the pending transforms are pushed to the ``applied_operations`` field of the
+    resulting MetaTensor. Note that if a plain torch.Tensor is passed to this method along with a list
+    of pending transforms, the resampled tensor will be wrapped in a MetaTensor before being returned.

     Args:
         data: A torch Tensor or a monai MetaTensor.
@@ -63,10 +242,8 @@ def apply_transforms(
       - device: device for resampling computation. Defaults to ``None``.
       - resample_mode: the mode of resampling, currently support ``"auto"``. Setting to other values will use the
        :py:class:`monai.transforms.SpatialResample` for resampling (instead of potentially crop/pad).
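A minimal usage sketch of ``apply_pending``, assuming metadata tracking is enabled; the override keys must come from the supported set listed below (``mode``, ``padding_mode``, and so on)::

    import torch
    from monai.data import MetaTensor
    from monai.transforms import Flip
    from monai.transforms.lazy.functional import apply_pending

    img = MetaTensor(torch.rand(1, 16, 16))
    img = Flip(spatial_axis=0, lazy=True)(img)  # one pending operation
    out, _ = apply_pending(img, overrides={"padding_mode": "zeros"})
    print(len(out.pending_operations), len(out.applied_operations))  # 0 1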
- """ overrides = (overrides or {}).copy() - overrides.update((kwargs or {}).copy()) for k in overrides: look_up_option(k, __override_keywords) # check existence of the key @@ -103,9 +280,11 @@ def apply_transforms( _cur_kwargs = cur_kwargs.copy() _cur_kwargs.update(override_kwargs) data = resample(data.to(device), cumulative_xform, _cur_kwargs) + next_matrix = affine_from_pending(p) if next_matrix.shape[0] == 3: next_matrix = to_affine_nd(3, next_matrix) + cumulative_xform = combine_transforms(cumulative_xform, next_matrix) cur_kwargs.update(new_kwargs) cur_kwargs.update(override_kwargs) diff --git a/monai/transforms/lazy/utils.py b/monai/transforms/lazy/utils.py index fa1bb6d48e..359559e319 100644 --- a/monai/transforms/lazy/utils.py +++ b/monai/transforms/lazy/utils.py @@ -224,6 +224,6 @@ def resample(data: torch.Tensor, matrix: NdarrayOrTensor, kwargs: dict | None = return img resampler = monai.transforms.SpatialResample(**init_kwargs) - resampler.lazy_evaluation = False # resampler is a lazytransform + resampler.lazy = False # resampler is a lazytransform with resampler.trace_transform(False): # don't track this transform in `img` return resampler(img=img, **call_kwargs) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 6fe433a0bc..cde51724f8 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -124,6 +124,9 @@ class SpatialResample(InvertibleTransform, LazyTransform): Internally this transform computes the affine transform matrix from ``src_affine`` to ``dst_affine``, by ``xform = linalg.solve(src_affine, dst_affine)``, and call ``monai.transforms.Affine`` with ``xform``. + + This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic` + for more information. """ backend = [TransformBackends.TORCH, TransformBackends.NUMPY, TransformBackends.CUPY] @@ -134,6 +137,7 @@ def __init__( padding_mode: str = GridSamplePadMode.BORDER, align_corners: bool = False, dtype: DtypeLike = np.float64, + lazy: bool = False, ): """ Args: @@ -152,7 +156,10 @@ def __init__( dtype: data type for resampling computation. Defaults to ``float64`` for best precision. If ``None``, use the data type of input data. To be compatible with other modules, the output data type is always ``float32``. + lazy: a flag to indicate whether this transform should execute lazily or not. + Defaults to False """ + LazyTransform.__init__(self, lazy=lazy) self.mode = mode self.padding_mode = padding_mode self.align_corners = align_corners @@ -167,6 +174,7 @@ def __call__( padding_mode: str | None = None, align_corners: bool | None = None, dtype: DtypeLike = None, + lazy: bool | None = None, ) -> torch.Tensor: """ Args: @@ -198,7 +206,9 @@ def __call__( dtype: data type for resampling computation. Defaults to ``self.dtype`` or ``np.float64`` (for best precision). If ``None``, use the data type of input data. To be compatible with other modules, the output data type is always `float32`. - + lazy: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy`` flag set + during initialization for this call. Defaults to None. The spatial rank is determined by the smallest among ``img.ndim -1``, ``len(src_affine) - 1``, and ``3``. 
When both ``monai.config.USE_COMPILED`` and ``align_corners`` are set to ``True``, @@ -210,8 +220,17 @@ def __call__( align_corners = align_corners if align_corners is not None else self.align_corners mode = mode if mode is not None else self.mode padding_mode = padding_mode if padding_mode is not None else self.padding_mode + lazy_ = self.lazy if lazy is None else lazy return spatial_resample( - img, dst_affine, spatial_size, mode, padding_mode, align_corners, dtype_pt, self.get_transform_info() + img, + dst_affine, + spatial_size, + mode, + padding_mode, + align_corners, + dtype_pt, + lazy=lazy_, + transform_info=self.get_transform_info(), ) def inverse(self, data: torch.Tensor) -> torch.Tensor: @@ -233,8 +252,13 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor: class ResampleToMatch(SpatialResample): - """Resample an image to match given metadata. The affine matrix will be aligned, - and the size of the output image will match.""" + """ + Resample an image to match given metadata. The affine matrix will be aligned, + and the size of the output image will match. + + This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic` + for more information. + """ def __call__( # type: ignore self, @@ -244,6 +268,7 @@ def __call__( # type: ignore padding_mode: str | None = None, align_corners: bool | None = None, dtype: DtypeLike = None, + lazy: bool | None = None, ) -> torch.Tensor: """ Args: @@ -267,6 +292,10 @@ def __call__( # type: ignore dtype: data type for resampling computation. Defaults to ``self.dtype`` or ``np.float64`` (for best precision). If ``None``, use the data type of input data. To be compatible with other modules, the output data type is always `float32`. + lazy: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy`` flag set + during initialization for this call. Defaults to None. + Raises: ValueError: When the affine matrix of the source image is not invertible. Returns: @@ -275,6 +304,7 @@ def __call__( # type: ignore if img_dst is None: raise RuntimeError("`img_dst` is missing.") dst_affine = img_dst.peek_pending_affine() if isinstance(img_dst, MetaTensor) else torch.eye(4) + lazy_ = self.lazy if lazy is None else lazy img = super().__call__( img=img, dst_affine=dst_affine, @@ -283,8 +313,9 @@ def __call__( # type: ignore padding_mode=padding_mode, align_corners=align_corners, dtype=dtype, + lazy=lazy_, ) - if not self.lazy_evaluation: + if not lazy_: if isinstance(img, MetaTensor): img.affine = dst_affine if isinstance(img_dst, MetaTensor): @@ -305,6 +336,9 @@ def __call__( # type: ignore class Spacing(InvertibleTransform, LazyTransform): """ Resample input image into the specified `pixdim`. + + This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic` + for more information. """ backend = SpatialResample.backend @@ -321,6 +355,7 @@ def __init__( recompute_affine: bool = False, min_pixdim: Sequence[float] | float | np.ndarray | None = None, max_pixdim: Sequence[float] | float | np.ndarray | None = None, + lazy: bool = False, ) -> None: """ Args: @@ -373,8 +408,10 @@ def __init__( max_pixdim: maximal input spacing to be resampled. If provided, input image with a smaller spacing than this value will be kept in its original spacing (not be resampled to `pixdim`). Set it to `None` to use the value of `pixdim`. Default to `None`. - + lazy: a flag to indicate whether this transform should execute lazily or not. 
+ Defaults to False """ + LazyTransform.__init__(self, lazy=lazy) self.pixdim = np.array(ensure_tuple(pixdim), dtype=np.float64) self.min_pixdim = np.array(ensure_tuple(min_pixdim), dtype=np.float64) self.max_pixdim = np.array(ensure_tuple(max_pixdim), dtype=np.float64) @@ -387,13 +424,13 @@ def __init__( raise ValueError(f"min_pixdim {self.min_pixdim} must be positive, smaller than max {self.max_pixdim}.") self.sp_resample = SpatialResample( - mode=mode, padding_mode=padding_mode, align_corners=align_corners, dtype=dtype + mode=mode, padding_mode=padding_mode, align_corners=align_corners, dtype=dtype, lazy=lazy ) - @LazyTransform.lazy_evaluation.setter # type: ignore - def lazy_evaluation(self, val: bool) -> None: - self._lazy_evaluation = val - self.sp_resample.lazy_evaluation = val + @LazyTransform.lazy.setter # type: ignore + def lazy(self, val: bool) -> None: + self._lazy = val + self.sp_resample.lazy = val @deprecated_arg(name="affine", since="0.9", msg_suffix="Not needed, input should be `MetaTensor`.") def __call__( @@ -406,6 +443,7 @@ def __call__( dtype: DtypeLike = None, scale_extent: bool | None = None, output_spatial_shape: Sequence[int] | np.ndarray | int | None = None, + lazy: bool | None = None, ) -> torch.Tensor: """ Args: @@ -435,6 +473,9 @@ def __call__( output_spatial_shape: specify the shape of the output data_array. This is typically useful for the inverse of `Spacingd` where sometimes we could not compute the exact shape due to the quantization error with the affine. + lazy: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy`` flag set + during initialization for this call. Defaults to None. Raises: ValueError: When ``data_array`` has no spatial dimensions. @@ -485,6 +526,7 @@ def __call__( new_affine[:sr, -1] = offset[:sr] actual_shape = list(output_shape) if output_spatial_shape is None else output_spatial_shape + lazy_ = self.lazy if lazy is None else lazy data_array = self.sp_resample( data_array, dst_affine=torch.as_tensor(new_affine), @@ -493,9 +535,10 @@ def __call__( padding_mode=padding_mode, align_corners=align_corners, dtype=dtype, + lazy=lazy_, ) if self.recompute_affine and isinstance(data_array, MetaTensor): - if self.lazy_evaluation: + if lazy_: raise NotImplementedError("recompute_affine is not supported with lazy evaluation.") a = scale_affine(original_spatial_shape, actual_shape) data_array.affine = convert_to_dst_type(a, affine_)[0] # type: ignore @@ -508,6 +551,9 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor: class Orientation(InvertibleTransform, LazyTransform): """ Change the input image's orientation into the specified based on `axcodes`. + + This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic` + for more information. """ backend = [TransformBackends.NUMPY, TransformBackends.TORCH] @@ -517,6 +563,7 @@ def __init__( axcodes: str | None = None, as_closest_canonical: bool = False, labels: Sequence[tuple[str, str]] | None = (("L", "R"), ("P", "A"), ("I", "S")), + lazy: bool = False, ) -> None: """ Args: @@ -529,6 +576,8 @@ def __init__( labels: optional, None or sequence of (2,) sequences (2,) sequences are labels for (beginning, end) of output axis. Defaults to ``(('L', 'R'), ('P', 'A'), ('I', 'S'))``. + lazy: a flag to indicate whether this transform should execute lazily or not. + Defaults to False Raises: ValueError: When ``axcodes=None`` and ``as_closest_canonical=True``. Incompatible values. 
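With the new flag, ``Orientation`` can queue its reorientation instead of permuting the data immediately. A brief sketch (a ``MetaTensor`` constructed without an affine defaults to identity, i.e. RAS, so requesting ``LPS`` implies a real change)::

    import torch
    from monai.data import MetaTensor
    from monai.transforms import Orientation

    img = MetaTensor(torch.rand(1, 8, 8, 8))  # default affine is identity (RAS)
    out = Orientation(axcodes="LPS", lazy=True)(img)
    print(len(out.pending_operations))  # 1: the reorientation is queued, not applied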
@@ -536,6 +585,7 @@ def __init__( See Also: `nibabel.orientations.ornt2axcodes`. """ + LazyTransform.__init__(self, lazy=lazy) if axcodes is None and not as_closest_canonical: raise ValueError("Incompatible values: axcodes=None and as_closest_canonical=True.") if axcodes is not None and as_closest_canonical: @@ -544,13 +594,16 @@ def __init__( self.as_closest_canonical = as_closest_canonical self.labels = labels - def __call__(self, data_array: torch.Tensor) -> torch.Tensor: + def __call__(self, data_array: torch.Tensor, lazy: bool | None = None) -> torch.Tensor: """ If input type is `MetaTensor`, original affine is extracted with `data_array.affine`. If input type is `torch.Tensor`, original affine is assumed to be identity. Args: data_array: in shape (num_channels, H[, W, ...]). + lazy: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy`` flag set + during initialization for this call. Defaults to None. Raises: ValueError: When ``data_array`` has no spatial dimensions. @@ -595,7 +648,10 @@ def __call__(self, data_array: torch.Tensor) -> torch.Tensor: f"axcodes must match data_array spatially, got axcodes={len(self.axcodes)}D data_array={sr}D" ) spatial_ornt = nib.orientations.ornt_transform(src, dst) - return orientation(data_array, affine_np, spatial_ornt, self.get_transform_info()) # type: ignore + lazy_ = self.lazy if lazy is None else lazy + return orientation( + data_array, affine_np, spatial_ornt, lazy=lazy_, transform_info=self.get_transform_info() + ) # type: ignore[no-any-return] def inverse(self, data: torch.Tensor) -> torch.Tensor: transform = self.pop_transform(data) @@ -616,27 +672,37 @@ class Flip(InvertibleTransform, LazyTransform): See `torch.flip` documentation for additional details: https://pytorch.org/docs/stable/generated/torch.flip.html + This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic` + for more information. + Args: spatial_axis: spatial axes along which to flip over. Default is None. The default `axis=None` will flip over all of the axes of the input array. If axis is negative it counts from the last to the first axis. If axis is a tuple of ints, flipping is performed on all of the axes specified in the tuple. + lazy: a flag to indicate whether this transform should execute lazily or not. + Defaults to False """ backend = [TransformBackends.TORCH] - def __init__(self, spatial_axis: Sequence[int] | int | None = None) -> None: + def __init__(self, spatial_axis: Sequence[int] | int | None = None, lazy: bool = False) -> None: + LazyTransform.__init__(self, lazy=lazy) self.spatial_axis = spatial_axis - def __call__(self, img: torch.Tensor) -> torch.Tensor: + def __call__(self, img: torch.Tensor, lazy: bool | None = None) -> torch.Tensor: """ Args: img: channel first array, must have shape: (num_channels, H[, W, ..., ]) + lazy: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy`` flag set + during initialization for this call. Defaults to None. 
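The per-call ``lazy`` argument documented above follows the same resolution rule everywhere: ``None`` defers to the flag set at construction, while ``True``/``False`` overrides it for that call only. A small sketch using ``Flip``::

    import torch
    from monai.data import MetaTensor
    from monai.transforms import Flip

    img = MetaTensor(torch.zeros(1, 8, 8))
    flip = Flip(spatial_axis=0, lazy=True)  # constructor default

    eager = flip(img, lazy=False)           # the per-call value wins
    print(len(eager.pending_operations))    # 0: executed immediately

    queued = flip(img)                      # lazy=None falls back to self.lazy
    print(len(queued.pending_operations))   # 1: one pending operation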
""" img = convert_to_tensor(img, track_meta=get_track_meta()) - return flip(img, self.spatial_axis, transform_info=self.get_transform_info()) # type: ignore + lazy_ = self.lazy if lazy is None else lazy + return flip(img, self.spatial_axis, lazy=lazy_, transform_info=self.get_transform_info()) # type: ignore def inverse(self, data: torch.Tensor) -> torch.Tensor: self.pop_transform(data) @@ -650,6 +716,9 @@ class Resize(InvertibleTransform, LazyTransform): Resize the input image to given spatial size (with scaling, not cropping/padding). Implemented using :py:class:`torch.nn.functional.interpolate`. + This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic` + for more information. + Args: spatial_size: expected shape of spatial dimensions after resize operation. if some components of the `spatial_size` are non-positive values, the transform will use the @@ -677,6 +746,8 @@ class Resize(InvertibleTransform, LazyTransform): anti-aliasing is performed prior to rescaling. dtype: data type for resampling computation. Defaults to ``float32``. If None, use the data type of input data. + lazy: a flag to indicate whether this transform should execute lazily or not. + Defaults to False """ backend = [TransformBackends.TORCH] @@ -690,7 +761,9 @@ def __init__( anti_aliasing: bool = False, anti_aliasing_sigma: Sequence[float] | float | None = None, dtype: DtypeLike | torch.dtype = torch.float32, + lazy: bool = False, ) -> None: + LazyTransform.__init__(self, lazy=lazy) self.size_mode = look_up_option(size_mode, ["all", "longest"]) self.spatial_size = spatial_size self.mode = mode @@ -707,6 +780,7 @@ def __call__( anti_aliasing: bool | None = None, anti_aliasing_sigma: Sequence[float] | float | None = None, dtype: DtypeLike | torch.dtype = None, + lazy: bool | None = None, ) -> torch.Tensor: """ Args: @@ -729,7 +803,9 @@ def __call__( anti-aliasing is performed prior to rescaling. dtype: data type for resampling computation. Defaults to ``self.dtype``. If None, use the data type of input data. - + lazy: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy`` flag set + during initialization for this call. Defaults to None. Raises: ValueError: When ``self.spatial_size`` length is less than ``img`` spatial dimensions. @@ -760,6 +836,7 @@ def __call__( _mode = self.mode if mode is None else mode _align_corners = self.align_corners if align_corners is None else align_corners _dtype = get_equivalent_dtype(dtype or self.dtype or img.dtype, torch.Tensor) + lazy_ = self.lazy if lazy is None else lazy return resize( # type: ignore img, sp_size, @@ -769,6 +846,7 @@ def __call__( input_ndim, anti_aliasing, anti_aliasing_sigma, + lazy_, self.get_transform_info(), ) @@ -798,6 +876,9 @@ class Rotate(InvertibleTransform, LazyTransform): """ Rotates an input image by given angle using :py:class:`monai.networks.layers.AffineTransform`. + This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic` + for more information. + Args: angle: Rotation angle(s) in radians. should a float for 2D, three floats for 3D. keep_size: If it is True, the output shape is kept the same as the input. @@ -814,6 +895,8 @@ class Rotate(InvertibleTransform, LazyTransform): dtype: data type for resampling computation. Defaults to ``float32``. If None, use the data type of input data. To be compatible with other modules, the output data type is always ``float32``. 
+ lazy: a flag to indicate whether this transform should execute lazily or not. + Defaults to False """ backend = [TransformBackends.TORCH] @@ -826,7 +909,9 @@ def __init__( padding_mode: str = GridSamplePadMode.BORDER, align_corners: bool = False, dtype: DtypeLike | torch.dtype = torch.float32, + lazy: bool = False, ) -> None: + LazyTransform.__init__(self, lazy=lazy) self.angle = angle self.keep_size = keep_size self.mode: str = mode @@ -841,6 +926,7 @@ def __call__( padding_mode: str | None = None, align_corners: bool | None = None, dtype: DtypeLike | torch.dtype = None, + lazy: bool | None = None, ) -> torch.Tensor: """ Args: @@ -858,6 +944,9 @@ def __call__( dtype: data type for resampling computation. Defaults to ``self.dtype``. If None, use the data type of input data. To be compatible with other modules, the output data type is always ``float32``. + lazy: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy`` flag set + during initialization for this call. Defaults to None. Raises: ValueError: When ``img`` spatially is not one of [2D, 3D]. @@ -870,8 +959,17 @@ def __call__( _align_corners = self.align_corners if align_corners is None else align_corners im_shape = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:] output_shape = im_shape if self.keep_size else None + lazy_ = self.lazy if lazy is None else lazy return rotate( # type: ignore - img, self.angle, output_shape, _mode, _padding_mode, _align_corners, _dtype, self.get_transform_info() + img, + self.angle, + output_shape, + _mode, + _padding_mode, + _align_corners, + _dtype, + lazy=lazy_, + transform_info=self.get_transform_info(), ) def inverse(self, data: torch.Tensor) -> torch.Tensor: @@ -914,6 +1012,9 @@ class Zoom(InvertibleTransform, LazyTransform): Different from :py:class:`monai.transforms.resize`, this transform takes scaling factors as input, and provides an option of preserving the input spatial size. + This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic` + for more information. + Args: zoom: The zoom factor along the spatial axes. If a float, zoom is the same for each spatial axis. @@ -934,9 +1035,10 @@ class Zoom(InvertibleTransform, LazyTransform): dtype: data type for resampling computation. Defaults to ``float32``. If None, use the data type of input data. keep_size: Should keep original size (padding/slicing if needed), default is True. + lazy: a flag to indicate whether this transform should execute lazily or not. + Defaults to False kwargs: other arguments for the `np.pad` or `torch.pad` function. note that `np.pad` treats channel dimension as the first dimension. - """ backend = [TransformBackends.TORCH] @@ -949,8 +1051,10 @@ def __init__( align_corners: bool | None = None, dtype: DtypeLike | torch.dtype = torch.float32, keep_size: bool = True, + lazy: bool = False, **kwargs, ) -> None: + LazyTransform.__init__(self, lazy=lazy) self.zoom = zoom self.mode = mode self.padding_mode = padding_mode @@ -966,6 +1070,7 @@ def __call__( padding_mode: str | None = None, align_corners: bool | None = None, dtype: DtypeLike | torch.dtype = None, + lazy: bool | None = None, ) -> torch.Tensor: """ Args: @@ -986,7 +1091,9 @@ def __call__( See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html dtype: data type for resampling computation. Defaults to ``self.dtype``. If None, use the data type of input data. 
- + lazy: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy`` flag set + during initialization for this call. Defaults to None. """ img = convert_to_tensor(img, track_meta=get_track_meta()) _zoom = ensure_tuple_rep(self.zoom, img.ndim - 1) # match the spatial image dim @@ -994,8 +1101,17 @@ def __call__( _padding_mode = padding_mode or self.padding_mode _align_corners = self.align_corners if align_corners is None else align_corners _dtype = get_equivalent_dtype(dtype or self.dtype or img.dtype, torch.Tensor) + lazy_ = self.lazy if lazy is None else lazy return zoom( # type: ignore - img, _zoom, self.keep_size, _mode, _padding_mode, _align_corners, _dtype, self.get_transform_info() + img, + _zoom, + self.keep_size, + _mode, + _padding_mode, + _align_corners, + _dtype, + lazy=lazy_, + transform_info=self.get_transform_info(), ) def inverse(self, data: torch.Tensor) -> torch.Tensor: @@ -1030,32 +1146,41 @@ class Rotate90(InvertibleTransform, LazyTransform): See `torch.rot90` for additional details: https://pytorch.org/docs/stable/generated/torch.rot90.html#torch-rot90. + This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic` + for more information. """ backend = [TransformBackends.TORCH] - def __init__(self, k: int = 1, spatial_axes: tuple[int, int] = (0, 1)) -> None: + def __init__(self, k: int = 1, spatial_axes: tuple[int, int] = (0, 1), lazy: bool = False) -> None: """ Args: k: number of times to rotate by 90 degrees. spatial_axes: 2 int numbers, defines the plane to rotate with 2 spatial axes. Default: (0, 1), this is the first two axis in spatial dimensions. If axis is negative it counts from the last to the first axis. + lazy: a flag to indicate whether this transform should execute lazily or not. + Defaults to False """ + LazyTransform.__init__(self, lazy=lazy) self.k = (4 + (k % 4)) % 4 # 0, 1, 2, 3 spatial_axes_: tuple[int, int] = ensure_tuple(spatial_axes) # type: ignore if len(spatial_axes_) != 2: raise ValueError(f"spatial_axes must be 2 numbers to define the plane to rotate, got {spatial_axes_}.") self.spatial_axes = spatial_axes_ - def __call__(self, img: torch.Tensor) -> torch.Tensor: + def __call__(self, img: torch.Tensor, lazy: bool | None = None) -> torch.Tensor: """ Args: img: channel first array, must have shape: (num_channels, H[, W, ..., ]), + lazy: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy`` flag set + during initialization for this call. Defaults to None. """ img = convert_to_tensor(img, track_meta=get_track_meta()) axes = map_spatial_axes(img.ndim, self.spatial_axes) - return rotate90(img, axes, self.k, self.get_transform_info()) # type: ignore + lazy_ = self.lazy if lazy is None else lazy + return rotate90(img, axes, self.k, lazy=lazy_, transform_info=self.get_transform_info()) # type: ignore def inverse(self, data: torch.Tensor) -> torch.Tensor: transform = self.pop_transform(data) @@ -1074,11 +1199,16 @@ class RandRotate90(RandomizableTransform, InvertibleTransform, LazyTransform): """ With probability `prob`, input arrays are rotated by 90 degrees in the plane specified by `spatial_axes`. + + This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic` + for more information. 
""" backend = Rotate90.backend - def __init__(self, prob: float = 0.1, max_k: int = 3, spatial_axes: tuple[int, int] = (0, 1)) -> None: + def __init__( + self, prob: float = 0.1, max_k: int = 3, spatial_axes: tuple[int, int] = (0, 1), lazy: bool = False + ) -> None: """ Args: prob: probability of rotating. @@ -1086,8 +1216,11 @@ def __init__(self, prob: float = 0.1, max_k: int = 3, spatial_axes: tuple[int, i max_k: number of rotations will be sampled from `np.random.randint(max_k) + 1`, (Default 3). spatial_axes: 2 int numbers, defines the plane to rotate with 2 spatial axes. Default: (0, 1), this is the first two axis in spatial dimensions. + lazy: a flag to indicate whether this transform should execute lazily or not. + Defaults to False """ RandomizableTransform.__init__(self, prob) + LazyTransform.__init__(self, lazy=lazy) self.max_k = max_k self.spatial_axes = spatial_axes @@ -1099,23 +1232,27 @@ def randomize(self, data: Any | None = None) -> None: return None self._rand_k = self.R.randint(self.max_k) + 1 - def __call__(self, img: torch.Tensor, randomize: bool = True) -> torch.Tensor: + def __call__(self, img: torch.Tensor, randomize: bool = True, lazy: bool | None = None) -> torch.Tensor: """ Args: img: channel first array, must have shape: (num_channels, H[, W, ..., ]), randomize: whether to execute `randomize()` function first, default to True. + lazy: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy`` flag set + during initialization for this call. Defaults to None. """ + if randomize: self.randomize() + lazy_ = self.lazy if lazy is None else lazy if self._do_transform: - xform = Rotate90(self._rand_k, self.spatial_axes) - xform.lazy_evaluation = self.lazy_evaluation + xform = Rotate90(self._rand_k, self.spatial_axes, lazy=lazy_) out = xform(img) else: out = convert_to_tensor(img, track_meta=get_track_meta()) - self.push_transform(out, replace=True) + self.push_transform(out, replace=True, lazy=lazy_) return out def inverse(self, data: torch.Tensor) -> torch.Tensor: @@ -1130,6 +1267,9 @@ class RandRotate(RandomizableTransform, InvertibleTransform, LazyTransform): """ Randomly rotate the input arrays. + This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic` + for more information. + Args: range_x: Range of rotation angle in radians in the plane defined by the first and second axes. If single number, angle is uniformly sampled from (-range_x, range_x). @@ -1152,6 +1292,8 @@ class RandRotate(RandomizableTransform, InvertibleTransform, LazyTransform): dtype: data type for resampling computation. Defaults to ``float32``. If None, use the data type of input data. To be compatible with other modules, the output data type is always ``float32``. + lazy: a flag to indicate whether this transform should execute lazily or not. 
+ Defaults to False """ backend = Rotate.backend @@ -1167,8 +1309,10 @@ def __init__( padding_mode: str = GridSamplePadMode.BORDER, align_corners: bool = False, dtype: DtypeLike | torch.dtype = np.float32, + lazy: bool = False, ) -> None: RandomizableTransform.__init__(self, prob) + LazyTransform.__init__(self, lazy=lazy) self.range_x = ensure_tuple(range_x) if len(self.range_x) == 1: self.range_x = tuple(sorted([-self.range_x[0], self.range_x[0]])) @@ -1205,6 +1349,7 @@ def __call__( align_corners: bool | None = None, dtype: DtypeLike | torch.dtype = None, randomize: bool = True, + lazy: bool | None = None, ): """ Args: @@ -1221,10 +1366,14 @@ def __call__( If None, use the data type of input data. To be compatible with other modules, the output data type is always ``float32``. randomize: whether to execute `randomize()` function first, default to True. + lazy: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy`` flag set + during initialization for this call. Defaults to None. """ if randomize: self.randomize() + lazy_ = self.lazy if lazy is None else lazy if self._do_transform: ndim = len(img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:]) rotator = Rotate( @@ -1234,12 +1383,12 @@ def __call__( padding_mode=padding_mode or self.padding_mode, align_corners=self.align_corners if align_corners is None else align_corners, dtype=dtype or self.dtype or img.dtype, + lazy=lazy_, ) - rotator.lazy_evaluation = self.lazy_evaluation out = rotator(img) else: out = convert_to_tensor(img, track_meta=get_track_meta(), dtype=torch.float32) - self.push_transform(out, replace=True) + self.push_transform(out, replace=True, lazy=lazy_) return out def inverse(self, data: torch.Tensor) -> torch.Tensor: @@ -1255,33 +1404,43 @@ class RandFlip(RandomizableTransform, InvertibleTransform, LazyTransform): See numpy.flip for additional details. https://docs.scipy.org/doc/numpy/reference/generated/numpy.flip.html + This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic` + for more information. + Args: prob: Probability of flipping. spatial_axis: Spatial axes along which to flip over. Default is None. + lazy: a flag to indicate whether this transform should execute lazily or not. + Defaults to False """ backend = Flip.backend - def __init__(self, prob: float = 0.1, spatial_axis: Sequence[int] | int | None = None) -> None: + def __init__(self, prob: float = 0.1, spatial_axis: Sequence[int] | int | None = None, lazy: bool = False) -> None: RandomizableTransform.__init__(self, prob) - self.flipper = Flip(spatial_axis=spatial_axis) + LazyTransform.__init__(self, lazy=lazy) + self.flipper = Flip(spatial_axis=spatial_axis, lazy=lazy) - @LazyTransform.lazy_evaluation.setter # type: ignore - def lazy_evaluation(self, val: bool): - self.flipper.lazy_evaluation = val - self._lazy_evaluation = val + @LazyTransform.lazy.setter # type: ignore + def lazy(self, val: bool): + self.flipper.lazy = val + self._lazy = val - def __call__(self, img: torch.Tensor, randomize: bool = True) -> torch.Tensor: + def __call__(self, img: torch.Tensor, randomize: bool = True, lazy: bool | None = None) -> torch.Tensor: """ Args: img: channel first array, must have shape: (num_channels, H[, W, ..., ]), randomize: whether to execute `randomize()` function first, default to True. + lazy: a flag to indicate whether this transform should execute lazily or not + during this call. 
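[Editorial aside: note the ``@LazyTransform.lazy.setter`` overrides (the property renamed from the earlier ``lazy_evaluation``): composite transforms forward the flag to the transforms they wrap, so toggling it after construction stays consistent. A small sketch::

    from monai.transforms import RandFlip

    flip = RandFlip(prob=1.0, spatial_axis=0, lazy=False)
    flip.lazy = True          # the overridden setter forwards the flag...
    print(flip.flipper.lazy)  # ...to the wrapped Flip: True
]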
Setting this to False or True overrides the ``lazy`` flag set + during initialization for this call. Defaults to None. """ if randomize: self.randomize(None) - out = self.flipper(img) if self._do_transform else img + lazy_ = self.lazy if lazy is None else lazy + out = self.flipper(img, lazy=lazy_) if self._do_transform else img out = convert_to_tensor(out, track_meta=get_track_meta()) - self.push_transform(out, replace=True) + self.push_transform(out, replace=True, lazy=lazy_) return out def inverse(self, data: torch.Tensor) -> torch.Tensor: @@ -1298,22 +1457,27 @@ class RandAxisFlip(RandomizableTransform, InvertibleTransform, LazyTransform): See numpy.flip for additional details. https://docs.scipy.org/doc/numpy/reference/generated/numpy.flip.html + This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic` + for more information. + Args: prob: Probability of flipping. - + lazy: a flag to indicate whether this transform should execute lazily or not. + Defaults to False """ backend = Flip.backend - def __init__(self, prob: float = 0.1) -> None: + def __init__(self, prob: float = 0.1, lazy: bool = False) -> None: RandomizableTransform.__init__(self, prob) + LazyTransform.__init__(self, lazy=lazy) self._axis: int | None = None self.flipper = Flip(spatial_axis=self._axis) - @LazyTransform.lazy_evaluation.setter # type: ignore - def lazy_evaluation(self, val: bool): - self.flipper.lazy_evaluation = val - self._lazy_evaluation = val + @LazyTransform.lazy.setter # type: ignore + def lazy(self, val: bool): + self.flipper.lazy = val + self._lazy = val def randomize(self, data: NdarrayOrTensor) -> None: super().randomize(None) @@ -1321,21 +1485,25 @@ def randomize(self, data: NdarrayOrTensor) -> None: return None self._axis = self.R.randint(data.ndim - 1) - def __call__(self, img: torch.Tensor, randomize: bool = True) -> torch.Tensor: + def __call__(self, img: torch.Tensor, randomize: bool = True, lazy: bool | None = None) -> torch.Tensor: """ Args: img: channel first array, must have shape: (num_channels, H[, W, ..., ]) randomize: whether to execute `randomize()` function first, default to True. + lazy: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy`` flag set + during initialization for this call. Defaults to None. """ if randomize: self.randomize(data=img) + lazy_ = self.lazy if lazy is None else lazy if self._do_transform: self.flipper.spatial_axis = self._axis - out = self.flipper(img) + out = self.flipper(img, lazy=lazy_) else: out = convert_to_tensor(img, track_meta=get_track_meta()) - self.push_transform(out, replace=True) + self.push_transform(out, replace=True, lazy=lazy_) return out def inverse(self, data: torch.Tensor) -> torch.Tensor: @@ -1351,6 +1519,9 @@ class RandZoom(RandomizableTransform, InvertibleTransform, LazyTransform): """ Randomly zooms input arrays with given probability within given zoom range. + This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic` + for more information. + Args: prob: Probability of zooming. min_zoom: Min zoom factor. Can be float or sequence same size as image. @@ -1379,6 +1550,8 @@ class RandZoom(RandomizableTransform, InvertibleTransform, LazyTransform): dtype: data type for resampling computation. Defaults to ``float32``. If None, use the data type of input data. keep_size: Should keep original size (pad if needed), default is True. 
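[Editorial aside: every ``__call__`` in this changeset resolves the effective flag with the same one-liner, ``lazy_ = self.lazy if lazy is None else lazy``. The helper below is hypothetical and merely restates that three-state logic::

    def resolve_lazy(init_flag: bool, call_flag: bool | None) -> bool:
        # None defers to the flag set at construction; True/False override per call
        return init_flag if call_flag is None else call_flag

    assert resolve_lazy(True, None) is True    # constructor flag wins
    assert resolve_lazy(True, False) is False  # per-call override wins
]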
+ lazy: a flag to indicate whether this transform should execute lazily or not. + Defaults to False kwargs: other arguments for the `np.pad` or `torch.pad` function. note that `np.pad` treats channel dimension as the first dimension. @@ -1396,9 +1569,11 @@ def __init__( align_corners: bool | None = None, dtype: DtypeLike | torch.dtype = torch.float32, keep_size: bool = True, + lazy: bool = False, **kwargs, ) -> None: RandomizableTransform.__init__(self, prob) + LazyTransform.__init__(self, lazy=lazy) self.min_zoom = ensure_tuple(min_zoom) self.max_zoom = ensure_tuple(max_zoom) if len(self.min_zoom) != len(self.max_zoom): @@ -1434,6 +1609,7 @@ def __call__( align_corners: bool | None = None, dtype: DtypeLike | torch.dtype = None, randomize: bool = True, + lazy: bool | None = None, ) -> torch.Tensor: """ Args: @@ -1454,12 +1630,15 @@ def __call__( dtype: data type for resampling computation. Defaults to ``self.dtype``. If None, use the data type of input data. randomize: whether to execute `randomize()` function first, default to True. - + lazy: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy`` flag set + during initialization for this call. Defaults to None. """ # match the spatial image dim if randomize: self.randomize(img=img) + lazy_ = self.lazy if lazy is None else lazy if not self._do_transform: out = convert_to_tensor(img, track_meta=get_track_meta(), dtype=torch.float32) else: @@ -1470,11 +1649,11 @@ def __call__( padding_mode=padding_mode or self.padding_mode, align_corners=self.align_corners if align_corners is None else align_corners, dtype=dtype or self.dtype, + lazy=lazy_, **self.kwargs, ) - xform.lazy_evaluation = self.lazy_evaluation out = xform(img) - self.push_transform(out, replace=True) + self.push_transform(out, replace=True, lazy=lazy_) return out # type: ignore def inverse(self, data: torch.Tensor) -> torch.Tensor: @@ -1488,6 +1667,9 @@ class AffineGrid(LazyTransform): """ Affine transforms on the coordinates. + This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic` + for more information. + Args: rotate_params: a rotation angle in radians, a scalar for 2D image, a tuple of 3 floats for 3D. Defaults to no rotation. @@ -1513,7 +1695,8 @@ class AffineGrid(LazyTransform): affine: If applied, ignore the params (`rotate_params`, etc.) and use the supplied matrix. Should be square with each side = num of image spatial dimensions + 1. - + lazy: a flag to indicate whether this transform should execute lazily or not. + Defaults to False """ backend = [TransformBackends.TORCH] @@ -1528,7 +1711,9 @@ def __init__( dtype: DtypeLike = np.float32, align_corners: bool = False, affine: NdarrayOrTensor | None = None, + lazy: bool = False, ) -> None: + LazyTransform.__init__(self, lazy=lazy) self.rotate_params = rotate_params self.shear_params = shear_params self.translate_params = translate_params @@ -1540,7 +1725,7 @@ def __init__( self.affine = affine def __call__( - self, spatial_size: Sequence[int] | None = None, grid: torch.Tensor | None = None + self, spatial_size: Sequence[int] | None = None, grid: torch.Tensor | None = None, lazy: bool | None = None ) -> tuple[torch.Tensor | None, torch.Tensor]: """ The grid can be initialized with a `spatial_size` parameter, or provided directly as `grid`. @@ -1550,12 +1735,15 @@ def __call__( Args: spatial_size: output grid size. grid: grid to be transformed. Shape must be (3, H, W) for 2D or (4, H, W, D) for 3D. 
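[Editorial aside: with ``RandZoom`` the randomisation (sampling the zoom factor) still happens at call time; only the interpolation is deferred. A sketch, under the same assumptions as the earlier examples::

    import torch
    from monai.data import MetaTensor
    from monai.transforms import RandZoom

    img = MetaTensor(torch.rand(1, 16, 16))

    zoom = RandZoom(prob=1.0, min_zoom=0.8, max_zoom=1.2, keep_size=True, lazy=True)
    out = zoom(img)                     # factor sampled now, resampling pended
    print(len(out.pending_operations))  # 1
]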
- + lazy: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy`` flag set + during initialization for this call. Defaults to None. Raises: ValueError: When ``grid=None`` and ``spatial_size=None``. Incompatible values. """ - if not self.lazy_evaluation: + lazy_ = self.lazy if lazy is None else lazy + if not lazy_: if grid is None: # create grid from spatial_size if spatial_size is None: raise ValueError("Incompatible values: grid=None and spatial_size=None.") @@ -1584,7 +1772,7 @@ def __call__( else: affine = self.affine # type: ignore affine = to_affine_nd(spatial_dims, affine) - if self.lazy_evaluation: + if lazy_: return None, affine affine = convert_to_tensor(affine, device=grid_.device, dtype=grid_.dtype, track_meta=False) # type: ignore @@ -1603,6 +1791,8 @@ class RandAffineGrid(Randomizable, LazyTransform): """ Generate randomised affine grid. + This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic` + for more information. """ backend = AffineGrid.backend @@ -1615,6 +1805,7 @@ def __init__( scale_range: RandRange = None, device: torch.device | None = None, dtype: DtypeLike = np.float32, + lazy: bool = False, ) -> None: """ Args: @@ -1643,6 +1834,8 @@ def __init__( device: device to store the output grid data. dtype: data type for the grid computation. Defaults to ``np.float32``. If ``None``, use the data type of input data (if `grid` is provided). + lazy: a flag to indicate whether this transform should execute lazily or not. + Defaults to False See also: - :py:meth:`monai.transforms.utils.create_rotate` @@ -1651,6 +1844,7 @@ def __init__( - :py:meth:`monai.transforms.utils.create_scale` """ + LazyTransform.__init__(self, lazy=lazy) self.rotate_range = ensure_tuple(rotate_range) self.shear_range = ensure_tuple(shear_range) self.translate_range = ensure_tuple(translate_range) @@ -1683,19 +1877,27 @@ def randomize(self, data: Any | None = None) -> None: self.scale_params = self._get_rand_param(self.scale_range, 1.0) def __call__( - self, spatial_size: Sequence[int] | None = None, grid: NdarrayOrTensor | None = None, randomize: bool = True + self, + spatial_size: Sequence[int] | None = None, + grid: NdarrayOrTensor | None = None, + randomize: bool = True, + lazy: bool | None = None, ) -> torch.Tensor: """ Args: spatial_size: output grid size. grid: grid to be transformed. Shape must be (3, H, W) for 2D or (4, H, W, D) for 3D. randomize: boolean as to whether the grid parameters governing the grid should be randomized. + lazy: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy`` flag set + during initialization for this call. Defaults to None. Returns: a 2D (3xHxW) or 3D (4xHxWxD) grid. 
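[Editorial aside: the ``lazy_`` branches above are the crux of ``AffineGrid``: in lazy mode it skips building the dense sampling grid and returns only the homogeneous matrix. A sketch; the (3, 3) shape for the 2D case is inferred from the code above rather than confirmed here::

    from monai.transforms import AffineGrid

    grid_gen = AffineGrid(rotate_params=0.5, lazy=True)
    grid, affine = grid_gen(spatial_size=(16, 16))
    print(grid)          # None: no dense grid is materialised in lazy mode
    print(affine.shape)  # (3, 3) homogeneous matrix for a 2D image
]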
""" if randomize: self.randomize() + lazy_ = self.lazy if lazy is None else lazy affine_grid = AffineGrid( rotate_params=self.rotate_params, shear_params=self.shear_params, @@ -1703,9 +1905,9 @@ def __call__( scale_params=self.scale_params, device=self.device, dtype=self.dtype, + lazy=lazy_, ) - affine_grid.lazy_evaluation = self.lazy_evaluation - if self.lazy_evaluation: # return the affine only, don't construct the grid + if lazy_: # return the affine only, don't construct the grid self.affine = affine_grid(spatial_size, grid)[1] # type: ignore return None # type: ignore _grid: torch.Tensor @@ -1924,6 +2126,8 @@ class Affine(InvertibleTransform, LazyTransform): Transform ``img`` given the affine parameters. A tutorial is available: https://github.com/Project-MONAI/tutorials/blob/0.6.0/modules/transforms_demo_2d.ipynb. + This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic` + for more information. """ backend = list(set(AffineGrid.backend) & set(Resample.backend)) @@ -1943,6 +2147,7 @@ def __init__( dtype: DtypeLike = np.float32, align_corners: bool = False, image_only: bool = False, + lazy: bool = False, ) -> None: """ The affine transformations are applied in rotate, shear, translate, scale order. @@ -1997,8 +2202,10 @@ def __init__( align_corners: Defaults to False. See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html image_only: if True return only the image volume, otherwise return (image, affine). - + lazy: a flag to indicate whether this transform should execute lazily or not. + Defaults to False """ + LazyTransform.__init__(self, lazy=lazy) self.affine_grid = AffineGrid( rotate_params=rotate_params, shear_params=shear_params, @@ -2008,6 +2215,7 @@ def __init__( dtype=dtype, align_corners=align_corners, device=device, + lazy=lazy, ) self.image_only = image_only self.norm_coord = not normalized @@ -2016,10 +2224,10 @@ def __init__( self.mode = mode self.padding_mode: str = padding_mode - @LazyTransform.lazy_evaluation.setter # type: ignore - def lazy_evaluation(self, val: bool) -> None: - self.affine_grid.lazy_evaluation = val - self._lazy_evaluation = val + @LazyTransform.lazy.setter # type: ignore + def lazy(self, val: bool) -> None: + self.affine_grid.lazy = val + self._lazy = val def __call__( self, @@ -2027,6 +2235,7 @@ def __call__( spatial_size: Sequence[int] | int | None = None, mode: str | int | None = None, padding_mode: str | None = None, + lazy: bool | None = None, ) -> torch.Tensor | tuple[torch.Tensor, NdarrayOrTensor]: """ Args: @@ -2048,13 +2257,17 @@ def __call__( When `mode` is an integer, using numpy/cupy backends, this argument accepts {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}. See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html + lazy: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy`` flag set + during initialization for this call. Defaults to None. 
""" img = convert_to_tensor(img, track_meta=get_track_meta()) img_size = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:] sp_size = fall_back_tuple(self.spatial_size if spatial_size is None else spatial_size, img_size) + lazy_ = self.lazy if lazy is None else lazy _mode = mode if mode is not None else self.mode _padding_mode = padding_mode if padding_mode is not None else self.padding_mode - grid, affine = self.affine_grid(spatial_size=sp_size) + grid, affine = self.affine_grid(spatial_size=sp_size, lazy=lazy_) return affine_func( # type: ignore img, @@ -2066,7 +2279,8 @@ def __call__( _padding_mode, True, self.image_only, - self.get_transform_info(), + lazy=lazy_, + transform_info=self.get_transform_info(), ) @classmethod @@ -2109,6 +2323,8 @@ class RandAffine(RandomizableTransform, InvertibleTransform, LazyTransform): Random affine transform. A tutorial is available: https://github.com/Project-MONAI/tutorials/blob/0.6.0/modules/transforms_demo_2d.ipynb. + This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic` + for more information. """ backend = Affine.backend @@ -2125,6 +2341,7 @@ def __init__( padding_mode: str = GridSamplePadMode.REFLECTION, cache_grid: bool = False, device: torch.device | None = None, + lazy: bool = False, ) -> None: """ Args: @@ -2174,6 +2391,8 @@ def __init__( If the spatial size is not dynamically defined by input image, enabling this option could accelerate the transform. device: device on which the tensor will be allocated. + lazy: a flag to indicate whether this transform should execute lazily or not. + Defaults to False See also: - :py:class:`RandAffineGrid` for the random affine parameters configurations. @@ -2181,32 +2400,33 @@ def __init__( """ RandomizableTransform.__init__(self, prob) - + LazyTransform.__init__(self, lazy=lazy) self.rand_affine_grid = RandAffineGrid( rotate_range=rotate_range, shear_range=shear_range, translate_range=translate_range, scale_range=scale_range, device=device, + lazy=lazy, ) self.resampler = Resample(device=device) self.spatial_size = spatial_size self.cache_grid = cache_grid - self._cached_grid = self._init_identity_cache() + self._cached_grid = self._init_identity_cache(lazy) self.mode = mode self.padding_mode: str = padding_mode - @LazyTransform.lazy_evaluation.setter # type: ignore - def lazy_evaluation(self, val: bool) -> None: - self._lazy_evaluation = val - self.rand_affine_grid.lazy_evaluation = val + @LazyTransform.lazy.setter # type: ignore + def lazy(self, val: bool) -> None: + self._lazy = val + self.rand_affine_grid.lazy = val - def _init_identity_cache(self): + def _init_identity_cache(self, lazy: bool): """ Create cache of the identity grid if cache_grid=True and spatial_size is known. """ - if self.lazy_evaluation: + if lazy: return None if self.spatial_size is None: if self.cache_grid: @@ -2226,14 +2446,14 @@ def _init_identity_cache(self): return None return create_grid(spatial_size=_sp_size, device=self.rand_affine_grid.device, backend="torch") - def get_identity_grid(self, spatial_size: Sequence[int]): + def get_identity_grid(self, spatial_size: Sequence[int], lazy: bool): """ Return a cached or new identity grid depends on the availability. 
Args: spatial_size: non-dynamic spatial size """ - if self.lazy_evaluation: + if lazy: return None ndim = len(spatial_size) if spatial_size != fall_back_tuple(spatial_size, [1] * ndim) or spatial_size != fall_back_tuple( @@ -2265,6 +2485,7 @@ def __call__( padding_mode: str | None = None, randomize: bool = True, grid=None, + lazy: bool | None = None, ) -> torch.Tensor: """ Args: @@ -2288,7 +2509,9 @@ def __call__( See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html randomize: whether to execute `randomize()` function first, default to True. grid: precomputed grid to be used (mainly to accelerate `RandAffined`). - + lazy: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy`` flag set + during initialization for this call. Defaults to None. """ if randomize: self.randomize() @@ -2299,17 +2522,18 @@ def __call__( do_resampling = self._do_transform or (sp_size != ensure_tuple(ori_size)) _mode = mode if mode is not None else self.mode _padding_mode = padding_mode if padding_mode is not None else self.padding_mode + lazy_ = self.lazy if lazy is None else lazy img = convert_to_tensor(img, track_meta=get_track_meta()) - if self.lazy_evaluation: + if lazy_: if self._do_transform: affine = self.rand_affine_grid.get_transformation_matrix() else: affine = convert_to_dst_type(torch.eye(len(sp_size) + 1), img, dtype=self.rand_affine_grid.dtype)[0] else: if grid is None: - grid = self.get_identity_grid(sp_size) + grid = self.get_identity_grid(sp_size, lazy_) if self._do_transform: - grid = self.rand_affine_grid(grid=grid, randomize=randomize) + grid = self.rand_affine_grid(grid=grid, randomize=randomize, lazy=lazy_) affine = self.rand_affine_grid.get_transformation_matrix() return affine_func( # type: ignore img, @@ -2321,7 +2545,8 @@ def __call__( _padding_mode, do_resampling, True, - self.get_transform_info(), + lazy=lazy_, + transform_info=self.get_transform_info(), ) def inverse(self, data: torch.Tensor) -> torch.Tensor: @@ -2436,6 +2661,7 @@ def __init__( translate_range=translate_range, scale_range=scale_range, device=device, + lazy=False, ) self.resampler = Resample(device=device) @@ -2603,6 +2829,7 @@ def __init__( translate_range=translate_range, scale_range=scale_range, device=device, + lazy=False, ) self.resampler = Resample(device=device) diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 2f34f57ca2..4ba5849c46 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -154,6 +154,9 @@ class SpatialResampled(MapTransform, InvertibleTransform, LazyTransform): changes) in the dictionary so that ``src_affine`` always refers to the current status of affine. + This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic` + for more information. + See also: :py:class:`monai.transforms.SpatialResample` """ @@ -169,6 +172,7 @@ def __init__( dtype: Sequence[DtypeLike] | DtypeLike = np.float64, dst_keys: KeysCollection | None = "dst_affine", allow_missing_keys: bool = False, + lazy: bool = False, ) -> None: """ Args: @@ -196,21 +200,37 @@ def __init__( It also can be a sequence of dtypes, each element corresponds to a key in ``keys``. dst_keys: the key of the corresponding ``dst_affine`` in the metadata dictionary. allow_missing_keys: don't raise exception if key is missing. + lazy: a flag to indicate whether this transform should execute lazily or not. 
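[Editorial aside: in practice these flags are usually driven from the pipeline level rather than per transform. ``Compose`` in 1.2 accepts a ``lazy`` argument (an assumption here, since this changeset only shows the transform side) and fuses consecutive pending operations into a single resample::

    import torch
    from monai.data import MetaTensor
    from monai.transforms import Compose, RandAffine, RandFlip

    img = MetaTensor(torch.rand(1, 16, 16))

    pipe = Compose(
        [RandFlip(prob=1.0, spatial_axis=0), RandAffine(prob=1.0, rotate_range=0.3)],
        lazy=True,  # both ops are pended, then applied in one resample at the end
    )
    out = pipe(img)
    print(len(out.applied_operations))  # 2 recorded ops, one actual resample
]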
+ Defaults to False. """ - super().__init__(keys, allow_missing_keys) - self.sp_transform = SpatialResample() + MapTransform.__init__(self, keys, allow_missing_keys) + LazyTransform.__init__(self, lazy=lazy) + self.sp_transform = SpatialResample(lazy=lazy) self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys)) self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) self.dtype = ensure_tuple_rep(dtype, len(self.keys)) self.dst_keys = ensure_tuple_rep(dst_keys, len(self.keys)) - @LazyTransform.lazy_evaluation.setter # type: ignore - def lazy_evaluation(self, val: bool) -> None: - self._lazy_evaluation = val - self.sp_transform.lazy_evaluation = val + @LazyTransform.lazy.setter # type: ignore + def lazy(self, val: bool) -> None: + self._lazy = val + self.sp_transform.lazy = val - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: + def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None) -> dict[Hashable, torch.Tensor]: + """ + Args: + data: a dictionary containing the tensor-like data to be processed. The ``keys`` specified + in this dictionary must be tensor like arrays that are channel first and have at most + three spatial dimensions + lazy: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy`` flag set + during initialization for this call. Defaults to None. + + Returns: + a dictionary containing the transformed data, as well as any other data present in the dictionary + """ + lazy_ = self.lazy if lazy is None else lazy d: dict = dict(data) for key, mode, padding_mode, align_corners, dtype, dst_key in self.key_iterator( d, self.mode, self.padding_mode, self.align_corners, self.dtype, self.dst_keys @@ -223,6 +243,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torc padding_mode=padding_mode, align_corners=align_corners, dtype=dtype, + lazy=lazy_, ) return d @@ -234,7 +255,13 @@ def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch class ResampleToMatchd(MapTransform, InvertibleTransform, LazyTransform): - """Dictionary-based wrapper of :py:class:`monai.transforms.ResampleToMatch`.""" + """ + Dictionary-based wrapper of :py:class:`monai.transforms.ResampleToMatch`. + + This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic` + for more information. + + """ backend = ResampleToMatch.backend @@ -247,6 +274,7 @@ def __init__( align_corners: Sequence[bool] | bool = False, dtype: Sequence[DtypeLike] | DtypeLike = np.float64, allow_missing_keys: bool = False, + lazy: bool = False, ): """ Args: @@ -274,21 +302,37 @@ def __init__( the output data type is always ``float32``. It also can be a sequence of dtypes, each element corresponds to a key in ``keys``. allow_missing_keys: don't raise exception if key is missing. + lazy: a flag to indicate whether this transform should execute lazily or not. 
+ Defaults to False """ - super().__init__(keys, allow_missing_keys) + MapTransform.__init__(self, keys, allow_missing_keys) + LazyTransform.__init__(self, lazy=lazy) self.key_dst = key_dst self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys)) self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) self.dtype = ensure_tuple_rep(dtype, len(self.keys)) - self.resampler = ResampleToMatch() + self.resampler = ResampleToMatch(lazy=lazy) - @LazyTransform.lazy_evaluation.setter # type: ignore - def lazy_evaluation(self, val: bool) -> None: - self._lazy_evaluation = val - self.resampler.lazy_evaluation = val + @LazyTransform.lazy.setter # type: ignore + def lazy(self, val: bool) -> None: + self._lazy = val + self.resampler.lazy = val - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: + def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None) -> dict[Hashable, torch.Tensor]: + """ + Args: + data: a dictionary containing the tensor-like data to be processed. The ``keys`` specified + in this dictionary must be tensor like arrays that are channel first and have at most + three spatial dimensions + lazy: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy`` flag set + during initialization for this call. Defaults to None. + + Returns: + a dictionary containing the transformed data, as well as any other data present in the dictionary + """ + lazy_ = self.lazy if lazy is None else lazy d = dict(data) for key, mode, padding_mode, align_corners, dtype in self.key_iterator( d, self.mode, self.padding_mode, self.align_corners, self.dtype @@ -300,6 +344,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torc padding_mode=padding_mode, align_corners=align_corners, dtype=dtype, + lazy=lazy_, ) return d @@ -320,6 +365,9 @@ class Spacingd(MapTransform, InvertibleTransform, LazyTransform): After resampling the input array, this transform will write the new affine to the `affine` field of metadata which is formed by ``key_{meta_key_postfix}``. + This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic` + for more information. + see also: :py:class:`monai.transforms.Spacing` """ @@ -341,6 +389,7 @@ def __init__( max_pixdim: Sequence[float] | float | None = None, ensure_same_shape: bool = True, allow_missing_keys: bool = False, + lazy: bool = False, ) -> None: """ Args: @@ -400,11 +449,18 @@ def __init__( ensure_same_shape: when the inputs have the same spatial shape, and almost the same pixdim, whether to ensure exactly the same output spatial shape. Default to True. allow_missing_keys: don't raise exception if key is missing. - + lazy: a flag to indicate whether this transform should execute lazily or not. 
+ Defaults to False """ - super().__init__(keys, allow_missing_keys) + MapTransform.__init__(self, keys, allow_missing_keys) + LazyTransform.__init__(self, lazy=lazy) self.spacing_transform = Spacing( - pixdim, diagonal=diagonal, recompute_affine=recompute_affine, min_pixdim=min_pixdim, max_pixdim=max_pixdim + pixdim, + diagonal=diagonal, + recompute_affine=recompute_affine, + min_pixdim=min_pixdim, + max_pixdim=max_pixdim, + lazy=lazy, ) self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys)) @@ -413,16 +469,29 @@ def __init__( self.scale_extent = ensure_tuple_rep(scale_extent, len(self.keys)) self.ensure_same_shape = ensure_same_shape - @LazyTransform.lazy_evaluation.setter # type: ignore - def lazy_evaluation(self, val: bool) -> None: - self._lazy_evaluation = val - self.spacing_transform.lazy_evaluation = val + @LazyTransform.lazy.setter # type: ignore + def lazy(self, val: bool) -> None: + self._lazy = val + self.spacing_transform.lazy = val - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: + def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None) -> dict[Hashable, torch.Tensor]: + """ + Args: + data: a dictionary containing the tensor-like data to be processed. The ``keys`` specified + in this dictionary must be tensor like arrays that are channel first and have at most + three spatial dimensions + lazy: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy`` flag set + during initialization for this call. Defaults to None. + + Returns: + a dictionary containing the transformed data, as well as any other data present in the dictionary + """ d: dict = dict(data) _init_shape, _pixdim, should_match = None, None, False output_shape_k = None # tracking output shape + lazy_ = self.lazy if lazy is None else lazy for key, mode, padding_mode, align_corners, dtype, scale_extent in self.key_iterator( d, self.mode, self.padding_mode, self.align_corners, self.dtype, self.scale_extent @@ -442,6 +511,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torc dtype=dtype, scale_extent=scale_extent, output_spatial_shape=output_shape_k if should_match else None, + lazy=lazy_, ) output_shape_k = d[key].peek_pending_shape() if isinstance(d[key], MetaTensor) else d[key].shape[1:] return d @@ -460,6 +530,9 @@ class Orientationd(MapTransform, InvertibleTransform, LazyTransform): This transform assumes the channel-first input format. In the case of using this transform for normalizing the orientations of images, it should be used before any anisotropic spatial transforms. + + This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic` + for more information. """ backend = Orientation.backend @@ -471,6 +544,7 @@ def __init__( as_closest_canonical: bool = False, labels: Sequence[tuple[str, str]] | None = (("L", "R"), ("P", "A"), ("I", "S")), allow_missing_keys: bool = False, + lazy: bool = False, ) -> None: """ Args: @@ -484,23 +558,41 @@ def __init__( (2,) sequences are labels for (beginning, end) of output axis. Defaults to ``(('L', 'R'), ('P', 'A'), ('I', 'S'))``. allow_missing_keys: don't raise exception if key is missing. + lazy: a flag to indicate whether this transform should execute lazily or not. + Defaults to False See Also: `nibabel.orientations.ornt2axcodes`. 
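[Editorial aside: the dictionary wrappers resolve the flag once per call and forward it per key, as ``Spacingd`` does above. Sketch; the (8, 8, 8) pending shape is an expected value for a 16-cubed volume resampled from 1 mm to 2 mm spacing::

    import torch
    from monai.data import MetaTensor
    from monai.transforms import Spacingd

    data = {"img": MetaTensor(torch.rand(1, 16, 16, 16), affine=torch.eye(4))}

    spacing = Spacingd(keys="img", pixdim=(2.0, 2.0, 2.0), lazy=True)
    out = spacing(data)
    print(len(out["img"].pending_operations))  # 1: resampling deferred for "img"
    print(out["img"].peek_pending_shape())     # e.g. (8, 8, 8) once applied
]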
""" - super().__init__(keys, allow_missing_keys) - self.ornt_transform = Orientation(axcodes=axcodes, as_closest_canonical=as_closest_canonical, labels=labels) + MapTransform.__init__(self, keys, allow_missing_keys) + LazyTransform.__init__(self, lazy=lazy) + self.ornt_transform = Orientation( + axcodes=axcodes, as_closest_canonical=as_closest_canonical, labels=labels, lazy=lazy + ) - @LazyTransform.lazy_evaluation.setter # type: ignore - def lazy_evaluation(self, val: bool) -> None: - self._lazy_evaluation = val - self.ornt_transform.lazy_evaluation = val + @LazyTransform.lazy.setter # type: ignore + def lazy(self, val: bool) -> None: + self._lazy = val + self.ornt_transform.lazy = val - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: + def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None) -> dict[Hashable, torch.Tensor]: + """ + Args: + data: a dictionary containing the tensor-like data to be processed. The ``keys`` specified + in this dictionary must be tensor like arrays that are channel first and have at most + three spatial dimensions + lazy: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy`` flag set + during initialization for this call. Defaults to None. + + Returns: + a dictionary containing the transformed data, as well as any other data present in the dictionary + """ d: dict = dict(data) + lazy_ = self.lazy if lazy is None else lazy for key in self.key_iterator(d): - d[key] = self.ornt_transform(d[key]) + d[key] = self.ornt_transform(d[key], lazy=lazy_) return d def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: @@ -513,12 +605,20 @@ def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch class Rotate90d(MapTransform, InvertibleTransform, LazyTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Rotate90`. + + This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic` + for more information. """ backend = Rotate90.backend def __init__( - self, keys: KeysCollection, k: int = 1, spatial_axes: tuple[int, int] = (0, 1), allow_missing_keys: bool = False + self, + keys: KeysCollection, + k: int = 1, + spatial_axes: tuple[int, int] = (0, 1), + allow_missing_keys: bool = False, + lazy: bool = False, ) -> None: """ Args: @@ -526,19 +626,35 @@ def __init__( spatial_axes: 2 int numbers, defines the plane to rotate with 2 spatial axes. Default: (0, 1), this is the first two axis in spatial dimensions. allow_missing_keys: don't raise exception if key is missing. + lazy: a flag to indicate whether this transform should execute lazily or not. 
+ Defaults to False """ - super().__init__(keys, allow_missing_keys) - self.rotator = Rotate90(k, spatial_axes) + MapTransform.__init__(self, keys, allow_missing_keys) + LazyTransform.__init__(self, lazy=lazy) + self.rotator = Rotate90(k, spatial_axes, lazy=lazy) - @LazyTransform.lazy_evaluation.setter # type: ignore - def lazy_evaluation(self, val: bool) -> None: - self._lazy_evaluation = val - self.rotator.lazy_evaluation = val + @LazyTransform.lazy.setter # type: ignore + def lazy(self, val: bool) -> None: + self._lazy = val + self.rotator.lazy = val - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: + def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None) -> dict[Hashable, torch.Tensor]: + """ + Args: + data: a dictionary containing the tensor-like data to be processed. The ``keys`` specified + in this dictionary must be tensor like arrays that are channel first and have at most + three spatial dimensions + lazy: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy`` flag set + during initialization for this call. Defaults to None. + + Returns: + a dictionary containing the transformed data, as well as any other data present in the dictionary + """ d = dict(data) + lazy_ = self.lazy if lazy is None else lazy for key in self.key_iterator(d): - d[key] = self.rotator(d[key]) + d[key] = self.rotator(d[key], lazy=lazy_) return d def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: @@ -553,6 +669,9 @@ class RandRotate90d(RandomizableTransform, MapTransform, InvertibleTransform, La Dictionary-based version :py:class:`monai.transforms.RandRotate90`. With probability `prob`, input arrays are rotated by 90 degrees in the plane specified by `spatial_axes`. + + This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic` + for more information. """ backend = Rotate90.backend @@ -564,6 +683,7 @@ def __init__( max_k: int = 3, spatial_axes: tuple[int, int] = (0, 1), allow_missing_keys: bool = False, + lazy: bool = False, ) -> None: """ Args: @@ -576,9 +696,12 @@ def __init__( spatial_axes: 2 int numbers, defines the plane to rotate with 2 spatial axes. Default: (0, 1), this is the first two axis in spatial dimensions. allow_missing_keys: don't raise exception if key is missing. + lazy: a flag to indicate whether this transform should execute lazily or not. + Defaults to False """ MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) + LazyTransform.__init__(self, lazy=lazy) self.max_k = max_k self.spatial_axes = spatial_axes @@ -589,17 +712,31 @@ def randomize(self, data: Any | None = None) -> None: self._rand_k = self.R.randint(self.max_k) + 1 super().randomize(None) - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Mapping[Hashable, torch.Tensor]: + def __call__( + self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None + ) -> Mapping[Hashable, torch.Tensor]: + """ + Args: + data: a dictionary containing the tensor-like data to be processed. The ``keys`` specified + in this dictionary must be tensor like arrays that are channel first and have at most + three spatial dimensions + lazy: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy`` flag set + during initialization for this call. Defaults to None. 
+ + Returns: + a dictionary containing the transformed data, as well as any other data present in the dictionary + """ self.randomize() d = dict(data) # FIXME: here we didn't use array version `RandRotate90` transform as others, because we need # to be compatible with the random status of some previous integration tests - rotator = Rotate90(self._rand_k, self.spatial_axes) - rotator.lazy_evaluation = self.lazy_evaluation + lazy_ = self.lazy if lazy is None else lazy + rotator = Rotate90(self._rand_k, self.spatial_axes, lazy=lazy_) for key in self.key_iterator(d): d[key] = rotator(d[key]) if self._do_transform else convert_to_tensor(d[key], track_meta=get_track_meta()) - self.push_transform(d[key], replace=True) + self.push_transform(d[key], replace=True, lazy=lazy_) return d def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: @@ -617,6 +754,9 @@ class Resized(MapTransform, InvertibleTransform, LazyTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Resize`. + This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic` + for more information. + Args: keys: keys of the corresponding items to be transformed. See also: :py:class:`monai.transforms.compose.MapTransform` @@ -649,6 +789,8 @@ class Resized(MapTransform, InvertibleTransform, LazyTransform): dtype: data type for resampling computation. Defaults to ``float32``. If None, use the data type of input data. allow_missing_keys: don't raise exception if key is missing. + lazy: a flag to indicate whether this transform should execute lazily or not. + Defaults to False """ backend = Resize.backend @@ -664,22 +806,37 @@ def __init__( anti_aliasing_sigma: Sequence[Sequence[float] | float | None] | Sequence[float] | float | None = None, dtype: Sequence[DtypeLike | torch.dtype] | DtypeLike | torch.dtype = np.float32, allow_missing_keys: bool = False, + lazy: bool = False, ) -> None: - super().__init__(keys, allow_missing_keys) + MapTransform.__init__(self, keys, allow_missing_keys) + LazyTransform.__init__(self, lazy=lazy) self.mode = ensure_tuple_rep(mode, len(self.keys)) self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) self.dtype = ensure_tuple_rep(dtype, len(self.keys)) self.anti_aliasing = ensure_tuple_rep(anti_aliasing, len(self.keys)) self.anti_aliasing_sigma = ensure_tuple_rep(anti_aliasing_sigma, len(self.keys)) - self.resizer = Resize(spatial_size=spatial_size, size_mode=size_mode) + self.resizer = Resize(spatial_size=spatial_size, size_mode=size_mode, lazy=lazy) - @LazyTransform.lazy_evaluation.setter # type: ignore - def lazy_evaluation(self, val: bool) -> None: - self._lazy_evaluation = val - self.resizer.lazy_evaluation = val + @LazyTransform.lazy.setter # type: ignore + def lazy(self, val: bool) -> None: + self._lazy = val + self.resizer.lazy = val - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: + def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None) -> dict[Hashable, torch.Tensor]: + """ + Args: + data: a dictionary containing the tensor-like data to be processed. The ``keys`` specified + in this dictionary must be tensor like arrays that are channel first and have at most + three spatial dimensions + lazy: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy`` flag set + during initialization for this call. Defaults to None. 
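[Editorial aside: note how ``RandRotate90d`` above builds a single ``Rotate90`` per call, so every key in the dictionary pends the same rotation. Sketch::

    import torch
    from monai.data import MetaTensor
    from monai.transforms import RandRotate90d

    data = {
        "img": MetaTensor(torch.rand(1, 16, 16)),
        "seg": MetaTensor(torch.rand(1, 16, 16)),
    }

    rot = RandRotate90d(keys=["img", "seg"], prob=1.0, lazy=True)
    out = rot(data)  # one sampled k, pended identically for img and seg
]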
+ + Returns: + a dictionary containing the transformed data, as well as any other data present in the dictionary + """ d = dict(data) + lazy_ = self.lazy if lazy is None else lazy for key, mode, align_corners, anti_aliasing, anti_aliasing_sigma, dtype in self.key_iterator( d, self.mode, self.align_corners, self.anti_aliasing, self.anti_aliasing_sigma, self.dtype ): @@ -690,6 +847,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torc anti_aliasing=anti_aliasing, anti_aliasing_sigma=anti_aliasing_sigma, dtype=dtype, + lazy=lazy_, ) return d @@ -703,6 +861,9 @@ def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch class Affined(MapTransform, InvertibleTransform, LazyTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Affine`. + + This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic` + for more information. """ backend = Affine.backend @@ -722,6 +883,7 @@ def __init__( dtype: DtypeLike | torch.dtype = np.float32, align_corners: bool = False, allow_missing_keys: bool = False, + lazy: bool = False, ) -> None: """ Args: @@ -772,6 +934,8 @@ def __init__( align_corners: Defaults to False. See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html allow_missing_keys: don't raise exception if key is missing. + lazy: a flag to indicate whether this transform should execute lazily or not. + Defaults to False See also: - :py:class:`monai.transforms.compose.MapTransform` @@ -779,6 +943,7 @@ def __init__( """ MapTransform.__init__(self, keys, allow_missing_keys) + LazyTransform.__init__(self, lazy=lazy) self.affine = Affine( rotate_params=rotate_params, shear_params=shear_params, @@ -789,19 +954,33 @@ def __init__( device=device, dtype=dtype, # type: ignore align_corners=align_corners, + lazy=lazy, ) self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys)) - @LazyTransform.lazy_evaluation.setter # type: ignore - def lazy_evaluation(self, val: bool) -> None: - self._lazy_evaluation = val - self.affine.lazy_evaluation = val + @LazyTransform.lazy.setter # type: ignore + def lazy(self, val: bool) -> None: + self._lazy = val + self.affine.lazy = val - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: + def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None) -> dict[Hashable, torch.Tensor]: + """ + Args: + data: a dictionary containing the tensor-like data to be processed. The ``keys`` specified + in this dictionary must be tensor like arrays that are channel first and have at most + three spatial dimensions + lazy: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy`` flag set + during initialization for this call. Defaults to None. 
+ + Returns: + a dictionary containing the transformed data, as well as any other data present in the dictionary + """ + lazy_ = self.lazy if lazy is None else lazy d = dict(data) for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): - d[key], _ = self.affine(d[key], mode=mode, padding_mode=padding_mode) + d[key], _ = self.affine(d[key], mode=mode, padding_mode=padding_mode, lazy=lazy_) return d def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: @@ -814,6 +993,9 @@ def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch class RandAffined(RandomizableTransform, MapTransform, InvertibleTransform, LazyTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.RandAffine`. + + This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic` + for more information. """ backend = RandAffine.backend @@ -832,6 +1014,7 @@ def __init__( cache_grid: bool = False, device: torch.device | None = None, allow_missing_keys: bool = False, + lazy: bool = False, ) -> None: """ Args: @@ -885,6 +1068,8 @@ def __init__( accelerate the transform. device: device on which the tensor will be allocated. allow_missing_keys: don't raise exception if key is missing. + lazy: a flag to indicate whether this transform should execute lazily or not. + Defaults to False See also: - :py:class:`monai.transforms.compose.MapTransform` @@ -893,6 +1078,7 @@ def __init__( """ MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) + LazyTransform.__init__(self, lazy=lazy) self.rand_affine = RandAffine( prob=1.0, # because probability handled in this class rotate_range=rotate_range, @@ -902,21 +1088,36 @@ def __init__( spatial_size=spatial_size, cache_grid=cache_grid, device=device, + lazy=lazy, ) self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys)) - @LazyTransform.lazy_evaluation.setter # type: ignore - def lazy_evaluation(self, val: bool) -> None: - self._lazy_evaluation = val - self.rand_affine.lazy_evaluation = val + @LazyTransform.lazy.setter # type: ignore + def lazy(self, val: bool) -> None: + self._lazy = val + self.rand_affine.lazy = val def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> RandAffined: self.rand_affine.set_random_state(seed, state) super().set_random_state(seed, state) return self - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: + def __call__( + self, data: Mapping[Hashable, NdarrayOrTensor], lazy: bool | None = None + ) -> dict[Hashable, NdarrayOrTensor]: + """ + Args: + data: a dictionary containing the tensor-like data to be processed. The ``keys`` specified + in this dictionary must be tensor like arrays that are channel first and have at most + three spatial dimensions + lazy: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy`` flag set + during initialization for this call. Defaults to None. 
+ + Returns: + a dictionary containing the transformed data, as well as any other data present in the dictionary + """ d = dict(data) first_key: Hashable = self.first_key(d) if first_key == (): @@ -929,6 +1130,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N item = d[first_key] spatial_size = item.peek_pending_shape() if isinstance(item, MetaTensor) else item.shape[1:] + lazy_ = self.lazy if lazy is None else lazy sp_size = fall_back_tuple(self.rand_affine.spatial_size, spatial_size) # change image size or do random transform @@ -936,18 +1138,18 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N # converting affine to tensor because the resampler currently only support torch backend grid = None if do_resampling: # need to prepare grid - grid = self.rand_affine.get_identity_grid(sp_size) + grid = self.rand_affine.get_identity_grid(sp_size, lazy=lazy_) if self._do_transform: # add some random factors - grid = self.rand_affine.rand_affine_grid(sp_size, grid=grid) + grid = self.rand_affine.rand_affine_grid(sp_size, grid=grid, lazy=lazy_) for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): # do the transform if do_resampling: - d[key] = self.rand_affine(d[key], None, mode, padding_mode, True, grid) # type: ignore + d[key] = self.rand_affine(d[key], None, mode, padding_mode, True, grid, lazy=lazy_) # type: ignore else: d[key] = convert_to_tensor(d[key], track_meta=get_track_meta(), dtype=torch.float32) self._do_transform = do_resampling # TODO: unify self._do_transform and do_resampling - self.push_transform(d[key], replace=True) + self.push_transform(d[key], replace=True, lazy=lazy_) return d def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: @@ -1066,6 +1268,15 @@ def set_random_state(self, seed: int | None = None, state: np.random.RandomState return self def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: + """ + Args: + data: a dictionary containing the tensor-like data to be processed. The ``keys`` specified + in this dictionary must be tensor like arrays that are channel first and have at most + three spatial dimensions + + Returns: + a dictionary containing the transformed data, as well as any other data present in the dictionary + """ d = dict(data) first_key: Hashable = self.first_key(d) @@ -1208,6 +1419,15 @@ def set_random_state(self, seed: int | None = None, state: np.random.RandomState return self def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: + """ + Args: + data: a dictionary containing the tensor-like data to be processed. The ``keys`` specified + in this dictionary must be tensor like arrays that are channel first and have at most + three spatial dimensions + + Returns: + a dictionary containing the transformed data, as well as any other data present in the dictionary + """ d = dict(data) first_key: Hashable = self.first_key(d) @@ -1246,29 +1466,52 @@ class Flipd(MapTransform, InvertibleTransform, LazyTransform): See `numpy.flip` for additional details. https://docs.scipy.org/doc/numpy/reference/generated/numpy.flip.html + This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic` + for more information. + Args: keys: Keys to pick data for transformation. spatial_axis: Spatial axes along which to flip over. Default is None. allow_missing_keys: don't raise exception if key is missing. 
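[Editorial aside: ``RandAffined`` above likewise shares the sampled parameters across keys; in lazy mode the identity grid is skipped and each key pends the same affine, to be resampled in one pass later. Sketch::

    import torch
    from monai.data import MetaTensor
    from monai.transforms import RandAffined

    data = {
        "img": MetaTensor(torch.rand(1, 16, 16)),
        "seg": MetaTensor(torch.rand(1, 16, 16)),
    }

    aff = RandAffined(keys=["img", "seg"], prob=1.0, rotate_range=0.3, lazy=True)
    out = aff(data)  # both keys carry the same pending affine, fused later
]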
+ lazy: a flag to indicate whether this transform should execute lazily or not. + Defaults to False """ backend = Flip.backend def __init__( - self, keys: KeysCollection, spatial_axis: Sequence[int] | int | None = None, allow_missing_keys: bool = False + self, + keys: KeysCollection, + spatial_axis: Sequence[int] | int | None = None, + allow_missing_keys: bool = False, + lazy: bool = False, ) -> None: - super().__init__(keys, allow_missing_keys) + MapTransform.__init__(self, keys, allow_missing_keys) + LazyTransform.__init__(self, lazy=lazy) self.flipper = Flip(spatial_axis=spatial_axis) - @LazyTransform.lazy_evaluation.setter # type: ignore - def lazy_evaluation(self, val: bool): - self.flipper.lazy_evaluation = val - self._lazy_evaluation = val + @LazyTransform.lazy.setter # type: ignore + def lazy(self, val: bool): + self.flipper.lazy = val + self._lazy = val - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: + def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None) -> dict[Hashable, torch.Tensor]: + """ + Args: + data: a dictionary containing the tensor-like data to be processed. The ``keys`` specified + in this dictionary must be tensor like arrays that are channel first and have at most + three spatial dimensions + lazy: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy`` flag set + during initialization for this call. Defaults to None. + + Returns: + a dictionary containing the transformed data, as well as any other data present in the dictionary + """ d = dict(data) + lazy_ = self.lazy if lazy is None else lazy for key in self.key_iterator(d): - d[key] = self.flipper(d[key]) + d[key] = self.flipper(d[key], lazy=lazy_) return d def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: @@ -1285,11 +1528,16 @@ class RandFlipd(RandomizableTransform, MapTransform, InvertibleTransform, LazyTr See `numpy.flip` for additional details. https://docs.scipy.org/doc/numpy/reference/generated/numpy.flip.html + This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic` + for more information. + Args: keys: Keys to pick data for transformation. prob: Probability of flipping. spatial_axis: Spatial axes along which to flip over. Default is None. allow_missing_keys: don't raise exception if key is missing. + lazy: a flag to indicate whether this transform should execute lazily or not. 
+ Defaults to False """ backend = Flip.backend @@ -1300,30 +1548,45 @@ def __init__( prob: float = 0.1, spatial_axis: Sequence[int] | int | None = None, allow_missing_keys: bool = False, + lazy: bool = False, ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) - self.flipper = Flip(spatial_axis=spatial_axis) + LazyTransform.__init__(self, lazy=lazy) + self.flipper = Flip(spatial_axis=spatial_axis, lazy=lazy) - @LazyTransform.lazy_evaluation.setter # type: ignore - def lazy_evaluation(self, val: bool): - self.flipper.lazy_evaluation = val - self._lazy_evaluation = val + @LazyTransform.lazy.setter # type: ignore + def lazy(self, val: bool): + self.flipper.lazy = val + self._lazy = val def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> RandFlipd: super().set_random_state(seed, state) return self - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: + def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None) -> dict[Hashable, torch.Tensor]: + """ + Args: + data: a dictionary containing the tensor-like data to be processed. The ``keys`` specified + in this dictionary must be tensor like arrays that are channel first and have at most + three spatial dimensions + lazy: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy`` flag set + during initialization for this call. Defaults to None. + + Returns: + a dictionary containing the transformed data, as well as any other data present in the dictionary + """ d = dict(data) self.randomize(None) + lazy_ = self.lazy if lazy is None else lazy for key in self.key_iterator(d): if self._do_transform: - d[key] = self.flipper(d[key]) + d[key] = self.flipper(d[key], lazy=lazy_) else: d[key] = convert_to_tensor(d[key], track_meta=get_track_meta()) - self.push_transform(d[key], replace=True) + self.push_transform(d[key], replace=True, lazy=lazy_) return d def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: @@ -1344,31 +1607,50 @@ class RandAxisFlipd(RandomizableTransform, MapTransform, InvertibleTransform, La See `numpy.flip` for additional details. https://docs.scipy.org/doc/numpy/reference/generated/numpy.flip.html + This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic` + for more information. + Args: keys: Keys to pick data for transformation. prob: Probability of flipping. allow_missing_keys: don't raise exception if key is missing. - + lazy: a flag to indicate whether this transform should execute lazily or not. 
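[Editorial aside: the per-call override carries through the dictionary API unchanged. Sketch::

    import torch
    from monai.data import MetaTensor
    from monai.transforms import RandFlipd

    data = {"img": MetaTensor(torch.rand(1, 16, 16))}

    flip = RandFlipd(keys="img", prob=1.0, spatial_axis=0, lazy=False)
    out = flip(data, lazy=True)  # overridden to lazy for this call only
    print(len(out["img"].pending_operations))  # 1
]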
+ Defaults to False """ backend = RandAxisFlip.backend - def __init__(self, keys: KeysCollection, prob: float = 0.1, allow_missing_keys: bool = False) -> None: + def __init__( + self, keys: KeysCollection, prob: float = 0.1, allow_missing_keys: bool = False, lazy: bool = False + ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) - self.flipper = RandAxisFlip(prob=1.0) + LazyTransform.__init__(self, lazy=lazy) + self.flipper = RandAxisFlip(prob=1.0, lazy=lazy) - @LazyTransform.lazy_evaluation.setter # type: ignore - def lazy_evaluation(self, val: bool): - self.flipper.lazy_evaluation = val - self._lazy_evaluation = val + @LazyTransform.lazy.setter # type: ignore + def lazy(self, val: bool): + self.flipper.lazy = val + self._lazy = val def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> RandAxisFlipd: super().set_random_state(seed, state) self.flipper.set_random_state(seed, state) return self - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: + def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None) -> dict[Hashable, torch.Tensor]: + """ + Args: + data: a dictionary containing the tensor-like data to be processed. The ``keys`` specified + in this dictionary must be tensor like arrays that are channel first and have at most + three spatial dimensions + lazy: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy`` flag set + during initialization for this call. Defaults to None. + + Returns: + a dictionary containing the transformed data, as well as any other data present in the dictionary + """ d = dict(data) first_key: Hashable = self.first_key(d) if first_key == (): @@ -1379,12 +1661,13 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torc # all the keys share the same random selected axis self.flipper.randomize(d[first_key]) + lazy_ = self.lazy if lazy is None else lazy for key in self.key_iterator(d): if self._do_transform: - d[key] = self.flipper(d[key], randomize=False) + d[key] = self.flipper(d[key], randomize=False, lazy=lazy_) else: d[key] = convert_to_tensor(d[key], track_meta=get_track_meta()) - self.push_transform(d[key], replace=True) + self.push_transform(d[key], replace=True, lazy=lazy_) return d def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: @@ -1401,6 +1684,9 @@ class Rotated(MapTransform, InvertibleTransform, LazyTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Rotate`. + This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic` + for more information. + Args: keys: Keys to pick data for transformation. angle: Rotation angle(s) in radians. @@ -1423,6 +1709,8 @@ class Rotated(MapTransform, InvertibleTransform, LazyTransform): the output data type is always ``float32``. It also can be a sequence of dtype or None, each element corresponds to a key in ``keys``. allow_missing_keys: don't raise exception if key is missing. + lazy: a flag to indicate whether this transform should execute lazily or not. 
+ Defaults to False """ backend = Rotate.backend @@ -1437,27 +1725,42 @@ def __init__( align_corners: Sequence[bool] | bool = False, dtype: Sequence[DtypeLike | torch.dtype] | DtypeLike | torch.dtype = np.float32, allow_missing_keys: bool = False, + lazy: bool = False, ) -> None: - super().__init__(keys, allow_missing_keys) - self.rotator = Rotate(angle=angle, keep_size=keep_size) + MapTransform.__init__(self, keys, allow_missing_keys) + LazyTransform.__init__(self, lazy=lazy) + self.rotator = Rotate(angle=angle, keep_size=keep_size, lazy=lazy) self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys)) self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) self.dtype = ensure_tuple_rep(dtype, len(self.keys)) - @LazyTransform.lazy_evaluation.setter # type: ignore - def lazy_evaluation(self, val: bool): - self.rotator.lazy_evaluation = val - self._lazy_evaluation = val + @LazyTransform.lazy.setter # type: ignore + def lazy(self, val: bool): + self.rotator.lazy = val + self._lazy = val - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: + def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None) -> dict[Hashable, torch.Tensor]: + """ + Args: + data: a dictionary containing the tensor-like data to be processed. The ``keys`` specified + in this dictionary must be tensor like arrays that are channel first and have at most + three spatial dimensions + lazy: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy`` flag set + during initialization for this call. Defaults to None. + + Returns: + a dictionary containing the transformed data, as well as any other data present in the dictionary + """ d = dict(data) + lazy_ = self.lazy if lazy is None else lazy for key, mode, padding_mode, align_corners, dtype in self.key_iterator( d, self.mode, self.padding_mode, self.align_corners, self.dtype ): d[key] = self.rotator( - d[key], mode=mode, padding_mode=padding_mode, align_corners=align_corners, dtype=dtype + d[key], mode=mode, padding_mode=padding_mode, align_corners=align_corners, dtype=dtype, lazy=lazy_ ) return d @@ -1473,6 +1776,9 @@ class RandRotated(RandomizableTransform, MapTransform, InvertibleTransform, Lazy Dictionary-based version :py:class:`monai.transforms.RandRotate` Randomly rotates the input arrays. + This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic` + for more information. + Args: keys: Keys to pick data for transformation. range_x: Range of rotation angle in radians in the plane defined by the first and second axes. @@ -1501,6 +1807,8 @@ class RandRotated(RandomizableTransform, MapTransform, InvertibleTransform, Lazy the output data type is always ``float32``. It also can be a sequence of dtype or None, each element corresponds to a key in ``keys``. allow_missing_keys: don't raise exception if key is missing. + lazy: a flag to indicate whether this transform should execute lazily or not. 
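# ---------------------------------------------------------------------------
# Illustrative sketch: the ``lazy`` flag on Rotated (documented above) pays
# off when a whole pipeline runs lazily, because consecutive spatial
# transforms fuse into a single resample. This assumes ``Compose`` accepts
# the ``lazy`` and ``overrides`` arguments referenced elsewhere in this
# changeset.
import numpy as np
import torch
from monai.data import MetaTensor
from monai.transforms import Compose, Rotated, Zoomd

pipeline = Compose(
    [Rotated(keys="img", angle=np.pi / 8), Zoomd(keys="img", zoom=0.9)],
    lazy=True,
    overrides={"mode": "bilinear", "padding_mode": "border"},
)
out = pipeline({"img": MetaTensor(torch.rand(1, 32, 32))})
# both transforms are tracked, but the data was interpolated only once
print(len(out["img"].applied_operations))  # -> 2
# ---------------------------------------------------------------------------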
+ Defaults to False """ backend = RandRotate.backend @@ -1518,31 +1826,49 @@ def __init__( align_corners: Sequence[bool] | bool = False, dtype: Sequence[DtypeLike | torch.dtype] | DtypeLike | torch.dtype = np.float32, allow_missing_keys: bool = False, + lazy: bool = False, ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) - self.rand_rotate = RandRotate(range_x=range_x, range_y=range_y, range_z=range_z, prob=1.0, keep_size=keep_size) + LazyTransform.__init__(self, lazy=lazy) + self.rand_rotate = RandRotate( + range_x=range_x, range_y=range_y, range_z=range_z, prob=1.0, keep_size=keep_size, lazy=lazy + ) self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys)) self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) self.dtype = ensure_tuple_rep(dtype, len(self.keys)) - @LazyTransform.lazy_evaluation.setter # type: ignore - def lazy_evaluation(self, val: bool): - self.rand_rotate.lazy_evaluation = val - self._lazy_evaluation = val + @LazyTransform.lazy.setter # type: ignore + def lazy(self, val: bool): + self.rand_rotate.lazy = val + self._lazy = val def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> RandRotated: super().set_random_state(seed, state) self.rand_rotate.set_random_state(seed, state) return self - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: + def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None) -> dict[Hashable, torch.Tensor]: + """ + Args: + data: a dictionary containing the tensor-like data to be processed. The ``keys`` specified + in this dictionary must be tensor like arrays that are channel first and have at most + three spatial dimensions + lazy: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy`` flag set + during initialization for this call. Defaults to None. + + Returns: + a dictionary containing the transformed data, as well as any other data present in the dictionary + """ d = dict(data) self.randomize(None) # all the keys share the same random rotate angle self.rand_rotate.randomize() + lazy_ = self.lazy if lazy is None else lazy + for key, mode, padding_mode, align_corners, dtype in self.key_iterator( d, self.mode, self.padding_mode, self.align_corners, self.dtype ): @@ -1554,10 +1880,11 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torc align_corners=align_corners, dtype=dtype, randomize=False, + lazy=lazy_, ) else: d[key] = convert_to_tensor(d[key], track_meta=get_track_meta(), dtype=torch.float32) - self.push_transform(d[key], replace=True) + self.push_transform(d[key], replace=True, lazy=lazy_) return d def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: @@ -1574,6 +1901,9 @@ class Zoomd(MapTransform, InvertibleTransform, LazyTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Zoom`. + This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic` + for more information. + Args: keys: Keys to pick data for transformation. zoom: The zoom factor along the spatial axes. @@ -1598,6 +1928,8 @@ class Zoomd(MapTransform, InvertibleTransform, LazyTransform): If None, use the data type of input data. keep_size: Should keep original size (pad if needed), default is True. 
allow_missing_keys: don't raise exception if key is missing. + lazy: a flag to indicate whether this transform should execute lazily or not. + Defaults to False kwargs: other arguments for the `np.pad` or `torch.pad` function. note that `np.pad` treats channel dimension as the first dimension. @@ -1615,26 +1947,44 @@ def __init__( dtype: Sequence[DtypeLike | torch.dtype] | DtypeLike | torch.dtype = np.float32, keep_size: bool = True, allow_missing_keys: bool = False, + lazy: bool = False, **kwargs, ) -> None: - super().__init__(keys, allow_missing_keys) + MapTransform.__init__(self, keys, allow_missing_keys) + LazyTransform.__init__(self, lazy=lazy) + self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys)) self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) self.dtype = ensure_tuple_rep(dtype, len(self.keys)) - self.zoomer = Zoom(zoom=zoom, keep_size=keep_size, **kwargs) + self.zoomer = Zoom(zoom=zoom, keep_size=keep_size, lazy=lazy, **kwargs) - @LazyTransform.lazy_evaluation.setter # type: ignore - def lazy_evaluation(self, val: bool): - self.zoomer.lazy_evaluation = val - self._lazy_evaluation = val + @LazyTransform.lazy.setter # type: ignore + def lazy(self, val: bool): + self.zoomer.lazy = val + self._lazy = val - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: + def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None) -> dict[Hashable, torch.Tensor]: + """ + Args: + data: a dictionary containing the tensor-like data to be processed. The ``keys`` specified + in this dictionary must be tensor like arrays that are channel first and have at most + three spatial dimensions + lazy: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy`` flag set + during initialization for this call. Defaults to None. + + Returns: + a dictionary containing the transformed data, as well as any other data present in the dictionary + """ d = dict(data) + lazy_ = self.lazy if lazy is None else lazy for key, mode, padding_mode, align_corners, dtype in self.key_iterator( d, self.mode, self.padding_mode, self.align_corners, self.dtype ): - d[key] = self.zoomer(d[key], mode=mode, padding_mode=padding_mode, align_corners=align_corners, dtype=dtype) + d[key] = self.zoomer( + d[key], mode=mode, padding_mode=padding_mode, align_corners=align_corners, dtype=dtype, lazy=lazy_ + ) return d def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: @@ -1648,6 +1998,9 @@ class RandZoomd(RandomizableTransform, MapTransform, InvertibleTransform, LazyTr """ Dict-based version :py:class:`monai.transforms.RandZoom`. + This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic` + for more information. + Args: keys: Keys to pick data for transformation. prob: Probability of zooming. @@ -1680,9 +2033,10 @@ class RandZoomd(RandomizableTransform, MapTransform, InvertibleTransform, LazyTr If None, use the data type of input data. keep_size: Should keep original size (pad if needed), default is True. allow_missing_keys: don't raise exception if key is missing. + lazy: a flag to indicate whether this transform should execute lazily or not. + Defaults to False kwargs: other args for `np.pad` API, note that `np.pad` treats channel dimension as the first dimension. 
more details: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html - """ backend = RandZoom.backend @@ -1699,27 +2053,43 @@ def __init__( dtype: Sequence[DtypeLike | torch.dtype] | DtypeLike | torch.dtype = np.float32, keep_size: bool = True, allow_missing_keys: bool = False, + lazy: bool = False, **kwargs, ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) - self.rand_zoom = RandZoom(prob=1.0, min_zoom=min_zoom, max_zoom=max_zoom, keep_size=keep_size, **kwargs) + LazyTransform.__init__(self, lazy=lazy) + self.rand_zoom = RandZoom( + prob=1.0, min_zoom=min_zoom, max_zoom=max_zoom, keep_size=keep_size, lazy=lazy, **kwargs + ) self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys)) self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) self.dtype = ensure_tuple_rep(dtype, len(self.keys)) - @LazyTransform.lazy_evaluation.setter # type: ignore - def lazy_evaluation(self, val: bool): - self.rand_zoom.lazy_evaluation = val - self._lazy_evaluation = val + @LazyTransform.lazy.setter # type: ignore + def lazy(self, val: bool): + self.rand_zoom.lazy = val + self._lazy = val def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> RandZoomd: super().set_random_state(seed, state) self.rand_zoom.set_random_state(seed, state) return self - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: + def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None) -> dict[Hashable, torch.Tensor]: + """ + Args: + data: a dictionary containing the tensor-like data to be processed. The ``keys`` specified + in this dictionary must be tensor like arrays that are channel first and have at most + three spatial dimensions + lazy: a flag to indicate whether this transform should execute lazily or not + during this call. Setting this to False or True overrides the ``lazy`` flag set + during initialization for this call. Defaults to None. + + Returns: + a dictionary containing the transformed data, as well as any other data present in the dictionary + """ d = dict(data) first_key: Hashable = self.first_key(d) if first_key == (): @@ -1730,6 +2100,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torc # all the keys share the same random zoom factor self.rand_zoom.randomize(d[first_key]) + lazy_ = self.lazy if lazy is None else lazy for key, mode, padding_mode, align_corners, dtype in self.key_iterator( d, self.mode, self.padding_mode, self.align_corners, self.dtype @@ -1742,10 +2113,11 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torc align_corners=align_corners, dtype=dtype, randomize=False, + lazy=lazy_, ) else: d[key] = convert_to_tensor(d[key], track_meta=get_track_meta(), dtype=torch.float32) - self.push_transform(d[key], replace=True) + self.push_transform(d[key], replace=True, lazy=lazy_) return d def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: @@ -1798,7 +2170,6 @@ def __init__( It also can be a sequence, each element corresponds to a key in ``keys``. device: device on which the tensor will be allocated. allow_missing_keys: don't raise exception if key is missing. 
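# ---------------------------------------------------------------------------
# Illustrative sketch: as the RandZoomd code above notes, a single random
# zoom factor is drawn and shared by every key, so image and segmentation
# stay geometrically aligned whether the call is lazy or eager. Key names
# and shapes are arbitrary.
import torch
from monai.data import MetaTensor
from monai.transforms import RandZoomd

zoom = RandZoomd(keys=["img", "seg"], prob=1.0, min_zoom=0.8, max_zoom=1.2, lazy=True)
zoom.set_random_state(seed=42)
data = {
    "img": MetaTensor(torch.rand(1, 16, 16)),
    "seg": MetaTensor(torch.randint(0, 2, (1, 16, 16)).float()),
}
out = zoom(data)
# one pending operation per key, both carrying the same zoom factor
assert len(out["img"].pending_operations) == len(out["seg"].pending_operations) == 1
# ---------------------------------------------------------------------------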
- """ super().__init__(keys, allow_missing_keys) self.grid_distortion = GridDistortion(num_cells=num_cells, distort_steps=distort_steps, device=device) @@ -1806,6 +2177,15 @@ def __init__( self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys)) def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: + """ + Args: + data: a dictionary containing the tensor-like data to be processed. The ``keys`` specified + in this dictionary must be tensor like arrays that are channel first and have at most + three spatial dimensions + + Returns: + a dictionary containing the transformed data, as well as any other data present in the dictionary + """ d = dict(data) for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): d[key] = self.grid_distortion(d[key], mode=mode, padding_mode=padding_mode) @@ -1872,6 +2252,15 @@ def set_random_state( return self def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: + """ + Args: + data: a dictionary containing the tensor-like data to be processed. The ``keys`` specified + in this dictionary must be tensor like arrays that are channel first and have at most + three spatial dimensions + + Returns: + a dictionary containing the transformed data, as well as any other data present in the dictionary + """ d = dict(data) self.randomize(None) if not self._do_transform: @@ -1922,6 +2311,15 @@ def __init__( self.splitter = GridSplit(grid=grid) def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> list[dict[Hashable, NdarrayOrTensor]]: + """ + Args: + data: a dictionary containing the tensor-like data to be processed. The ``keys`` specified + in this dictionary must be tensor like arrays that are channel first and have at most + three spatial dimensions + + Returns: + a dictionary containing the transformed data, as well as any other data present in the dictionary + """ d = dict(data) n_outputs = np.prod(self.grid) output: list[dict[Hashable, NdarrayOrTensor]] = [dict(d) for _ in range(n_outputs)] @@ -2003,6 +2401,15 @@ def __init__( ) def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: + """ + Args: + data: a dictionary containing the tensor-like data to be processed. The ``keys`` specified + in this dictionary must be tensor like arrays that are channel first and have at most + three spatial dimensions + + Returns: + a dictionary containing the transformed data, as well as any other data present in the dictionary + """ d = dict(data) for key in self.key_iterator(d): d[key] = self.patcher(d[key]) @@ -2091,6 +2498,15 @@ def set_random_state(self, seed: int | None = None, state: np.random.RandomState return self def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: + """ + Args: + data: a dictionary containing the tensor-like data to be processed. 
The ``keys`` specified + in this dictionary must be tensor like arrays that are channel first and have at most + three spatial dimensions + + Returns: + a dictionary containing the transformed data, as well as any other data present in the dictionary + """ d = dict(data) # All the keys share the same random noise for key in self.key_iterator(d): diff --git a/monai/transforms/spatial/functional.py b/monai/transforms/spatial/functional.py index 591ebbb489..9d77a83389 100644 --- a/monai/transforms/spatial/functional.py +++ b/monai/transforms/spatial/functional.py @@ -66,12 +66,12 @@ def _maybe_new_metatensor(img, dtype=None, device=None): def spatial_resample( - img, dst_affine, spatial_size, mode, padding_mode, align_corners, dtype_pt, transform_info + img, dst_affine, spatial_size, mode, padding_mode, align_corners, dtype_pt, lazy, transform_info ) -> torch.Tensor: """ Functional implementation of resampling the input image to the specified ``dst_affine`` matrix and ``spatial_size``. This function operates eagerly or lazily according to - ``transform_info[TraceKeys.LAZY_EVALUATION]`` (default ``False``). + ``lazy`` (default ``False``). Args: img: data to be resampled, assuming `img` is channel-first. @@ -92,6 +92,7 @@ def spatial_resample( align_corners: Geometrically, we consider the pixels of the input as squares rather than points. See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html dtype_pt: data `dtype` for resampling computation. + lazy: a flag that indicates whether the operation should be performed lazily or not transform_info: a dictionary with the relevant information pertaining to an applied transform. """ original_spatial_shape = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:] @@ -132,17 +133,16 @@ def spatial_resample( affine_unchanged = ( allclose(src_affine, dst_affine, atol=AFFINE_TOL) and allclose(spatial_size, in_spatial_size) ) or (allclose(xform, np.eye(len(xform)), atol=AFFINE_TOL) and allclose(spatial_size, in_spatial_size)) - lazy_evaluation = transform_info.get(TraceKeys.LAZY_EVALUATION, False) meta_info = TraceableTransform.track_transform_meta( img, sp_size=spatial_size, - affine=None if affine_unchanged and not lazy_evaluation else xform, + affine=None if affine_unchanged and not lazy else xform, extra_info=extra_info, orig_size=original_spatial_shape, transform_info=transform_info, - lazy_evaluation=lazy_evaluation, + lazy=lazy, ) - if lazy_evaluation: + if lazy: out = _maybe_new_metatensor(img) return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info # type: ignore if affine_unchanged: @@ -184,17 +184,18 @@ def spatial_resample( return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out # type: ignore -def orientation(img, original_affine, spatial_ornt, transform_info): +def orientation(img, original_affine, spatial_ornt, lazy, transform_info) -> torch.Tensor: """ Functional implementation of changing the input image's orientation into the specified based on `spatial_ornt`. This function operates eagerly or lazily according to - ``transform_info[TraceKeys.LAZY_EVALUATION]`` (default ``False``). + ``lazy`` (default ``False``). Args: img: data to be changed, assuming `img` is channel-first. original_affine: original affine of the input image. 
spatial_ornt: orientations of the spatial axes, see also https://nipy.org/nibabel/reference/nibabel.orientations.html + lazy: a flag that indicates whether the operation should be performed lazily or not transform_info: a dictionary with the relevant information pertaining to an applied transform. """ spatial_shape = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:] @@ -217,23 +218,23 @@ def orientation(img, original_affine, spatial_ornt, transform_info): extra_info=extra_info, orig_size=spatial_shape, transform_info=transform_info, - lazy_evaluation=transform_info.get(TraceKeys.LAZY_EVALUATION, False), + lazy=lazy, ) out = _maybe_new_metatensor(img) - if transform_info.get(TraceKeys.LAZY_EVALUATION, False): - return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info + if lazy: + return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info # type: ignore if axes: out = torch.flip(out, dims=axes) if not np.all(full_transpose == np.arange(len(out.shape))): out = out.permute(full_transpose.tolist()) - return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out + return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out # type: ignore -def flip(img, sp_axes, transform_info): +def flip(img, sp_axes, lazy, transform_info): """ Functional implementation of flip. This function operates eagerly or lazily according to - ``transform_info[TraceKeys.LAZY_EVALUATION]`` (default ``False``). + ``lazy`` (default ``False``). Args: img: data to be changed, assuming `img` is channel-first. @@ -242,6 +243,7 @@ def flip(img, sp_axes, transform_info): If axis is negative it counts from the last to the first axis. If axis is a tuple of ints, flipping is performed on all of the axes specified in the tuple. + lazy: a flag that indicates whether the operation should be performed lazily or not transform_info: a dictionary with the relevant information pertaining to an applied transform. """ sp_size = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:] @@ -255,25 +257,22 @@ def flip(img, sp_axes, transform_info): sp = axis - 1 xform[sp, sp], xform[sp, -1] = xform[sp, sp] * -1, sp_size[sp] - 1 meta_info = TraceableTransform.track_transform_meta( - img, - sp_size=sp_size, - affine=xform, - extra_info=extra_info, - transform_info=transform_info, - lazy_evaluation=transform_info.get(TraceKeys.LAZY_EVALUATION, False), + img, sp_size=sp_size, affine=xform, extra_info=extra_info, transform_info=transform_info, lazy=lazy ) out = _maybe_new_metatensor(img) - if transform_info.get(TraceKeys.LAZY_EVALUATION, False): + if lazy: return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info out = torch.flip(out, axes) return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out -def resize(img, out_size, mode, align_corners, dtype, input_ndim, anti_aliasing, anti_aliasing_sigma, transform_info): +def resize( + img, out_size, mode, align_corners, dtype, input_ndim, anti_aliasing, anti_aliasing_sigma, lazy, transform_info +): """ Functional implementation of resize. This function operates eagerly or lazily according to - ``transform_info[TraceKeys.LAZY_EVALUATION]`` (default ``False``). + ``lazy`` (default ``False``). Args: img: data to be changed, assuming `img` is channel-first. @@ -291,6 +290,7 @@ def resize(img, out_size, mode, align_corners, dtype, input_ndim, anti_aliasing, the image to avoid aliasing artifacts. 
See also ``skimage.transform.resize`` anti_aliasing_sigma: {float, tuple of floats}, optional Standard deviation for Gaussian filtering used when anti-aliasing. + lazy: a flag that indicates whether the operation should be performed lazily or not transform_info: a dictionary with the relevant information pertaining to an applied transform. """ img = convert_to_tensor(img, track_meta=get_track_meta()) @@ -308,10 +308,10 @@ def resize(img, out_size, mode, align_corners, dtype, input_ndim, anti_aliasing, extra_info=extra_info, orig_size=orig_size, transform_info=transform_info, - lazy_evaluation=transform_info.get(TraceKeys.LAZY_EVALUATION, False), + lazy=lazy, ) - if transform_info.get(TraceKeys.LAZY_EVALUATION, False): - if anti_aliasing and transform_info.get(TraceKeys.LAZY_EVALUATION, False): + if lazy: + if anti_aliasing and lazy: warnings.warn("anti-aliasing is not compatible with lazy evaluation.") out = _maybe_new_metatensor(img) return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info @@ -340,11 +340,11 @@ def resize(img, out_size, mode, align_corners, dtype, input_ndim, anti_aliasing, return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out -def rotate(img, angle, output_shape, mode, padding_mode, align_corners, dtype, transform_info): +def rotate(img, angle, output_shape, mode, padding_mode, align_corners, dtype, lazy, transform_info): """ Functional implementation of rotate. This function operates eagerly or lazily according to - ``transform_info[TraceKeys.LAZY_EVALUATION]`` (default ``False``). + ``lazy`` (default ``False``). Args: img: data to be changed, assuming `img` is channel-first. @@ -360,6 +360,7 @@ def rotate(img, angle, output_shape, mode, padding_mode, align_corners, dtype, t dtype: data type for resampling computation. If None, use the data type of input data. To be compatible with other modules, the output data type is always ``float32``. + lazy: a flag that indicates whether the operation should be performed lazily or not transform_info: a dictionary with the relevant information pertaining to an applied transform. """ @@ -393,10 +394,10 @@ def rotate(img, angle, output_shape, mode, padding_mode, align_corners, dtype, t extra_info=extra_info, orig_size=im_shape, transform_info=transform_info, - lazy_evaluation=transform_info.get(TraceKeys.LAZY_EVALUATION, False), + lazy=lazy, ) out = _maybe_new_metatensor(img) - if transform_info.get(TraceKeys.LAZY_EVALUATION, False): + if lazy: return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info _, _m, _p, _ = resolves_modes(mode, padding_mode) xform = AffineTransform( @@ -410,11 +411,11 @@ def rotate(img, angle, output_shape, mode, padding_mode, align_corners, dtype, t return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out -def zoom(img, scale_factor, keep_size, mode, padding_mode, align_corners, dtype, transform_info): +def zoom(img, scale_factor, keep_size, mode, padding_mode, align_corners, dtype, lazy, transform_info): """ Functional implementation of zoom. This function operates eagerly or lazily according to - ``transform_info[TraceKeys.LAZY_EVALUATION]`` (default ``False``). + ``lazy`` (default ``False``). Args: img: data to be changed, assuming `img` is channel-first. @@ -432,6 +433,7 @@ def zoom(img, scale_factor, keep_size, mode, padding_mode, align_corners, dtype, dtype: data type for resampling computation. If None, use the data type of input data. 
To be compatible with other modules, the output data type is always ``float32``. + lazy: a flag that indicates whether the operation should be performed lazily or not transform_info: a dictionary with the relevant information pertaining to an applied transform. """ @@ -447,9 +449,9 @@ def zoom(img, scale_factor, keep_size, mode, padding_mode, align_corners, dtype, } if keep_size: do_pad_crop = not np.allclose(output_size, im_shape) - if do_pad_crop and transform_info.get(TraceKeys.LAZY_EVALUATION, False): # update for lazy evaluation + if do_pad_crop and lazy: # update for lazy evaluation _pad_crop = ResizeWithPadOrCrop(spatial_size=im_shape, mode=padding_mode) - _pad_crop.lazy_evaluation = True + _pad_crop.lazy = True _tmp_img = MetaTensor([], affine=torch.eye(len(output_size) + 1)) _tmp_img.push_pending_operation({LazyAttr.SHAPE: list(output_size), LazyAttr.AFFINE: xform}) lazy_cropped = _pad_crop(_tmp_img) @@ -465,10 +467,10 @@ def zoom(img, scale_factor, keep_size, mode, padding_mode, align_corners, dtype, extra_info=extra_info, orig_size=im_shape, transform_info=transform_info, - lazy_evaluation=transform_info.get(TraceKeys.LAZY_EVALUATION, False), + lazy=lazy, ) out = _maybe_new_metatensor(img) - if transform_info.get(TraceKeys.LAZY_EVALUATION, False): + if lazy: return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info img_t = out.to(dtype) _, _m, _, _ = resolves_modes(mode, torch_interpolate_spatial_nd=len(img_t.shape) - 1) @@ -493,17 +495,18 @@ def zoom(img, scale_factor, keep_size, mode, padding_mode, align_corners, dtype, return out -def rotate90(img, axes, k, transform_info): +def rotate90(img, axes, k, lazy, transform_info): """ Functional implementation of rotate90. This function operates eagerly or lazily according to - ``transform_info[TraceKeys.LAZY_EVALUATION]`` (default ``False``). + ``lazy`` (default ``False``). Args: img: data to be changed, assuming `img` is channel-first. axes: 2 int numbers, defines the plane to rotate with 2 spatial axes. If axis is negative it counts from the last to the first axis. k: number of times to rotate by 90 degrees. + lazy: a flag that indicates whether the operation should be performed lazily or not transform_info: a dictionary with the relevant information pertaining to an applied transform. """ extra_info = {"axes": [d - 1 for d in axes], "k": k} @@ -533,20 +536,22 @@ def rotate90(img, axes, k, transform_info): extra_info=extra_info, orig_size=ori_shape, transform_info=transform_info, - lazy_evaluation=transform_info.get(TraceKeys.LAZY_EVALUATION, False), + lazy=lazy, ) out = _maybe_new_metatensor(img) - if transform_info.get(TraceKeys.LAZY_EVALUATION, False): + if lazy: return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info out = torch.rot90(out, k, axes) return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out -def affine_func(img, affine, grid, resampler, sp_size, mode, padding_mode, do_resampling, image_only, transform_info): +def affine_func( + img, affine, grid, resampler, sp_size, mode, padding_mode, do_resampling, image_only, lazy, transform_info +): """ Functional implementation of affine. This function operates eagerly or lazily according to - ``transform_info[TraceKeys.LAZY_EVALUATION]`` (default ``False``). + ``lazy`` (default ``False``). Args: img: data to be changed, assuming `img` is channel-first. 
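# ---------------------------------------------------------------------------
# Illustrative sketch: the functional implementations above now take an
# explicit ``lazy`` argument rather than reading it from ``transform_info``.
# At the transform level this surfaces as pending operations that are
# resolved later with ``apply_pending`` (the rename of ``apply_transforms``
# made by this changeset; see the test updates further below).
import torch
from monai.data import MetaTensor
from monai.transforms import Flip
from monai.transforms.lazy.functional import apply_pending

flip = Flip(spatial_axis=0, lazy=True)
img = MetaTensor(torch.rand(1, 8, 8))

pending = flip(img)  # the flip is queued, not executed
assert len(pending.pending_operations) == 1

resolved = apply_pending(pending, overrides={"mode": "nearest"})[0]
assert len(resolved.pending_operations) == 0
# ---------------------------------------------------------------------------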
@@ -570,6 +575,7 @@ def affine_func(img, affine, grid, resampler, sp_size, mode, padding_mode, do_re do_resampling: whether to do the resampling, this is a flag for the use case of updating metadata but skipping the actual (potentially heavy) resampling operation. image_only: if True return only the image volume, otherwise return (image, affine). + lazy: a flag that indicates whether the operation should be performed lazily or not transform_info: a dictionary with the relevant information pertaining to an applied transform. """ @@ -592,9 +598,9 @@ def affine_func(img, affine, grid, resampler, sp_size, mode, padding_mode, do_re extra_info=extra_info, orig_size=img_size, transform_info=transform_info, - lazy_evaluation=transform_info.get(TraceKeys.LAZY_EVALUATION, False), + lazy=lazy, ) - if transform_info.get(TraceKeys.LAZY_EVALUATION, False): + if lazy: out = _maybe_new_metatensor(img) out = out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info return out if image_only else (out, affine) diff --git a/monai/transforms/traits.py b/monai/transforms/traits.py index 0193065562..016effc59d 100644 --- a/monai/transforms/traits.py +++ b/monai/transforms/traits.py @@ -14,7 +14,9 @@ from __future__ import annotations -__all__ = ["LazyTrait", "RandomizableTrait", "MultiSampleTrait", "ThreadUnsafe"] +__all__ = ["LazyTrait", "InvertibleTrait", "RandomizableTrait", "MultiSampleTrait", "ThreadUnsafe"] + +from typing import Any class LazyTrait: @@ -27,23 +29,42 @@ class LazyTrait: """ @property - def lazy_evaluation(self): + def lazy(self): """ - Get whether lazy_evaluation is enabled for this transform instance. + Get whether lazy evaluation is enabled for this transform instance. Returns: True if the transform is operating in a lazy fashion, False if not. """ raise NotImplementedError() - @lazy_evaluation.setter - def lazy_evaluation(self, enabled: bool): + @lazy.setter + def lazy(self, enabled: bool): """ - Set whether lazy_evaluation is enabled for this transform instance. + Set whether lazy evaluation is enabled for this transform instance. Args: enabled: True if the transform should operate in a lazy fashion, False if not. """ raise NotImplementedError() + @property + def requires_current_data(self): + """ + Get whether the transform requires the input data to be up to date before the transform executes. + Such transforms can still execute lazily by adding pending operations to the output tensors. + Returns: + True if the transform requires its inputs to be up to date and False if it does not + """ + + +class InvertibleTrait: + """ + An interface to indicate that the transform can be inverted, i.e. undone by performing + the inverse of the operation performed during `__call__`. + """ + + def inverse(self, data: Any) -> Any: + raise NotImplementedError() + class RandomizableTrait: """ diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index 3e66431bbc..716970d9be 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -44,27 +44,58 @@ def _apply_transform( - transform: Callable[..., ReturnType], parameters: Any, unpack_parameters: bool = False + transform: Callable[..., ReturnType], + data: Any, + unpack_parameters: bool = False, + lazy: bool | None = False, + overrides: dict | None = None, + logger_name: bool | str = False, ) -> ReturnType: """ - Perform transformation `transform` with the provided parameters `parameters`. + Perform a transform 'transform' on 'data', according to the other parameters specified. 
+
+    If `data` is a tuple and `unpack_parameters` is True, each parameter of `data` is unpacked
+    as arguments to `transform`. Otherwise `data` is considered a single argument to `transform`.
-    If `parameters` is a tuple and `unpack_items` is True, each parameter of `parameters` is unpacked
-    as arguments to `transform`.
-    Otherwise `parameters` is considered as single argument to `transform`.
+    If `lazy` is True, this function first checks whether the transform can be executed lazily. If it
+    can't, it ensures that all pending lazy transforms on `data` are applied before
+    this `transform` is applied to it. If `lazy` is True and `overrides` are provided, those overrides
+    are applied to the pending operations on `data`. See ``Compose`` for more details on lazy
+    resampling, which is an experimental feature for 1.2.
+
+    Please note, this function is designed to be called by ``apply_transform``.
+    In general, you should not need to make specific use of it unless you are implementing
+    pipeline execution mechanisms.

    Args:
        transform: a callable to be used to transform `data`.
-        parameters: parameters for the `transform`.
+        data: the tensor-like object, or dictionary of tensor-like objects, to be transformed
        unpack_parameters: whether to unpack parameters for `transform`. Defaults to False.
+        lazy: whether to enable lazy evaluation for lazy transforms. If False, transforms will be
+            carried out on a transform by transform basis. If True, all lazy transforms will
+            be executed by accumulating changes and resampling as few times as possible.
+            See the :ref:`Lazy Resampling topic<lazy_resampling>` for more information about lazy resampling.
+        overrides: this optional parameter allows you to specify a dictionary of parameters that should be overridden
+            when executing a pipeline. Each parameter in the dictionary that is compatible with a given transform
+            is applied to that transform before it is executed. Note that overrides are currently only applied when
+            :ref:`Lazy Resampling<lazy_resampling>` is enabled for the pipeline or a given transform. If lazy is False
+            they are ignored. Currently supported args are:
+            {``"mode"``, ``"padding_mode"``, ``"dtype"``, ``"align_corners"``, ``"resample_mode"``, ``"device"``}.
+        logger_name: this optional parameter allows you to specify a logger by name for logging of pipeline execution.
+            Setting this to False disables logging. Setting it to True enables logging to the default logger name.
+            Setting it to a string logs to the logger with that name.

    Returns:
        ReturnType: The return type of `transform`.
    """
-    if isinstance(parameters, tuple) and unpack_parameters:
-        return transform(*parameters)
+    from monai.transforms.lazy.functional import apply_pending_transforms_in_order
+
+    data = apply_pending_transforms_in_order(transform, data, lazy, overrides, logger_name)

-    return transform(parameters)
+    if isinstance(data, tuple) and unpack_parameters:
+        return transform(*data, lazy=lazy) if isinstance(transform, LazyTrait) else transform(*data)
+
+    return transform(data, lazy=lazy) if isinstance(transform, LazyTrait) else transform(data)


def apply_transform(
@@ -72,7 +103,9 @@
    data: Any,
    map_items: bool = True,
    unpack_items: bool = False,
-    log_stats: bool = False,
+    log_stats: bool | str = False,
+    lazy: bool | None = False,
+    overrides: dict | None = None,
) -> list[ReturnType] | ReturnType:
    """
    Transform `data` with `transform`.
@@ -87,9 +120,14 @@
        map_items: whether to apply transform to each item in `data`, if `data` is a list or tuple. Defaults to True.
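# ---------------------------------------------------------------------------
# Illustrative sketch: calling ``apply_transform`` directly with the new
# ``lazy`` argument; its remaining parameters are documented in the
# docstring that continues below. Most callers reach this code path via
# ``Compose`` rather than invoking it themselves.
import torch
from monai.data import MetaTensor
from monai.transforms import Flip
from monai.transforms.transform import apply_transform

img = MetaTensor(torch.rand(1, 8, 8))
out = apply_transform(Flip(spatial_axis=0), img, lazy=True)
assert len(out.pending_operations) == 1  # the flip is pending, not applied
# ---------------------------------------------------------------------------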
        unpack_items: whether to unpack parameters using `*`. Defaults to False.
-        log_stats: whether to log the detailed information of data and applied transform when error happened,
-            for NumPy array and PyTorch Tensor, log the data shape and value range,
-            for other metadata, log the values directly. default to `False`.
+        log_stats: log errors when they occur in the processing pipeline. By default, this is set to False, which
+            disables the logger for processing pipeline errors. Setting it to None or True will enable logging to the
+            default logger name. Setting it to a string specifies the logger to which errors should be logged.
+        lazy: whether to execute in lazy mode or not. See the :ref:`Lazy Resampling topic<lazy_resampling>` for more
+            information about lazy resampling.
+        overrides: optional overrides to apply to transform parameters. This parameter is ignored unless transforms
+            are being executed lazily. See the :ref:`Lazy Resampling topic<lazy_resampling>` for more details and
+            examples of its usage.

    Raises:
        Exception: When ``transform`` raises an exception.
@@ -99,18 +137,21 @@
    """
    try:
        if isinstance(data, (list, tuple)) and map_items:
-            return [_apply_transform(transform, item, unpack_items) for item in data]
-        return _apply_transform(transform, data, unpack_items)
+            return [_apply_transform(transform, item, unpack_items, lazy, overrides, log_stats) for item in data]
+        return _apply_transform(transform, data, unpack_items, lazy, overrides, log_stats)
    except Exception as e:
        # if in debug mode, don't swallow exception so that the breakpoint
        # appears where the exception was raised.
        if MONAIEnvVars.debug():
            raise
-        if log_stats and not isinstance(transform, transforms.compose.Compose):
+        if log_stats is not False and not isinstance(transform, transforms.compose.Compose):
            # log the input data information of exact transform in the transform chain
-            datastats = transforms.utility.array.DataStats(data_shape=False, value_range=False)
+            if isinstance(log_stats, str):
+                datastats = transforms.utility.array.DataStats(data_shape=False, value_range=False, name=log_stats)
+            else:
+                datastats = transforms.utility.array.DataStats(data_shape=False, value_range=False)
            logger = logging.getLogger(datastats._logger_name)
-            logger.info(f"\n=== Transform input info -- {type(transform).__name__} ===")
+            logger.error(f"\n=== Transform input info -- {type(transform).__name__} ===")
            if isinstance(data, (list, tuple)):
                data = data[0]
@@ -254,17 +295,26 @@ class LazyTransform(Transform, LazyTrait):
    dictionary transforms to simplify implementation of new lazy transforms.
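# ---------------------------------------------------------------------------
# Illustrative sketch: with the ``LazyTransform`` base class diffed just
# below, a minimal lazy-capable transform only needs to accept the
# constructor flag and honour a per-call override; the validated ``lazy``
# property is inherited. This is a schematic no-op, not a transform from
# the codebase.
from __future__ import annotations

from monai.transforms.transform import LazyTransform


class IdentityLazy(LazyTransform):
    def __init__(self, lazy: bool = False):
        LazyTransform.__init__(self, lazy=lazy)

    def __call__(self, img, lazy: bool | None = None):
        lazy_ = self.lazy if lazy is None else lazy
        # a real transform would push a pending operation when lazy_ is
        # True and resample immediately when it is False
        return img
# ---------------------------------------------------------------------------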
""" - _lazy_evaluation: bool = False + def __init__(self, lazy: bool | None = False): + if lazy is not None: + if not isinstance(lazy, bool): + raise TypeError(f"lazy must be a bool but is of type {type(lazy)}") + self._lazy = lazy @property - def lazy_evaluation(self): - return self._lazy_evaluation + def lazy(self): + return self._lazy - @lazy_evaluation.setter - def lazy_evaluation(self, lazy_evaluation: bool): - if not isinstance(lazy_evaluation, bool): - raise TypeError(f"lazy_evaluation must be a bool but is of type {type(lazy_evaluation)}") - self._lazy_evaluation = lazy_evaluation + @lazy.setter + def lazy(self, lazy: bool | None): + if lazy is not None: + if not isinstance(lazy, bool): + raise TypeError(f"lazy must be a bool but is of type {type(lazy)}") + self._lazy = lazy + + @property + def requires_current_data(self): + return False class RandomizableTransform(Randomizable, Transform): @@ -347,6 +397,7 @@ def __new__(cls, *args, **kwargs): return Transform.__new__(cls) def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None: + super().__init__() self.keys: tuple[Hashable, ...] = ensure_tuple(keys) self.allow_missing_keys = allow_missing_keys if not self.keys: diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 53193f4cb6..09a13945bb 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -28,7 +28,7 @@ from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor from monai.networks.layers import GaussianFilter from monai.networks.utils import meshgrid_ij -from monai.transforms.compose import Compose, OneOf +from monai.transforms.compose import Compose from monai.transforms.transform import MapTransform, Transform, apply_transform from monai.transforms.utils_pytorch_numpy_unification import ( any_np_pt, @@ -52,6 +52,7 @@ PytorchPadMode, SplineMode, TraceKeys, + TraceStatusKeys, ensure_tuple, ensure_tuple_rep, ensure_tuple_size, @@ -121,6 +122,7 @@ "sync_meta_info", "reset_ops_id", "resolves_modes", + "has_status_keys", ] @@ -1549,6 +1551,7 @@ def get_number_image_type_conversions(transform: Compose, test_data: Any, key: H test_data: data to be used to count the number of conversions key: if using dictionary transforms, this key will be used to check the number of conversions. """ + from monai.transforms.compose import OneOf def _get_data(obj, key): return obj if key is None else obj[key] @@ -1970,5 +1973,80 @@ def resolves_modes( return backend, _interp_mode, _padding_mode, _kwargs +def check_applied_operations(entry: list | dict, status_key: str, default_message: str = "No message provided"): + """ + Check the operations of a MetaTensor to determine whether there are any statuses + Args: + entry: a dictionary that may contain TraceKey.STATUS entries, or a list of such dictionaries + status_key: the status key to search for. 
This must be an entry in ``TraceStatusKeys``
+        default_message: the message to provide if no messages are provided for the given status key entry
+
+    Returns:
+        A list of status messages matching the provided status key
+
+    """
+    if isinstance(entry, list):
+        results = list()
+        for sub_entry in entry:
+            results.extend(check_applied_operations(sub_entry, status_key, default_message))
+        return results
+    else:
+        status_key_ = TraceStatusKeys(status_key)
+        if TraceKeys.STATUSES in entry:
+            if status_key_ in entry[TraceKeys.STATUSES]:
+                reason = entry[TraceKeys.STATUSES][status_key_]
+                if reason is None:
+                    return [default_message]
+                return reason if isinstance(reason, list) else [reason]
+        return []
+
+
+def has_status_keys(data: torch.Tensor, status_key: Any, default_msg: str):
+    """
+    Checks whether a given tensor has a particular status key message on any of its
+    applied operations. If it doesn't, it returns the tuple `(True, None)`. If it does,
+    it returns a tuple of `False` and a list of status messages for that status key.
+
+    Status keys are defined in :class:`TraceStatusKeys`.
+
+    This function also accepts:
+
+    * dictionaries of tensors
+    * lists or tuples of tensors
+    * lists or tuples of dictionaries of tensors
+
+    In any of the above scenarios, it iterates through the collections and executes itself recursively until it is
+    operating on tensors.
+
+    Args:
+        data: a `torch.Tensor` or `MetaTensor` or collections of torch.Tensor or MetaTensor, as described above
+        status_key: the status key to look for, from `TraceStatusKeys`
+        default_msg: a default message to use if the status key entry doesn't have a message set
+
+    Returns:
+        A tuple. The first entry is `True` if no matching status key entries were found, in which case the second
+        entry is `None`; otherwise the first entry is `False` and the second entry is a list of status messages
+        that can be used to help users debug their pipelines.
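# ---------------------------------------------------------------------------
# Illustrative sketch: a typical consumer of ``has_status_keys`` (whose body
# continues below) checks pipeline outputs for operations that were still
# pending when a transform was applied. ``batch`` here is a stand-in for a
# real pipeline output; ``TraceStatusKeys`` comes from the enum added by
# this changeset.
import torch
from monai.data import MetaTensor
from monai.transforms.utils import has_status_keys
from monai.utils import TraceStatusKeys

batch = [MetaTensor(torch.rand(1, 8, 8))]  # stand-in for collated outputs
applied_cleanly, reasons = has_status_keys(
    batch, TraceStatusKeys.PENDING_DURING_APPLY, "no message provided"
)
if not applied_cleanly:
    print("\n".join(reasons))
# ---------------------------------------------------------------------------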
+ + """ + status_key_occurrences = list() + if isinstance(data, (list, tuple)): + for d in data: + _, reasons = has_status_keys(d, status_key, default_msg) + if reasons is not None: + status_key_occurrences.extend(reasons) + elif isinstance(data, monai.data.MetaTensor): + for op in data.applied_operations: + status_key_occurrences.extend(check_applied_operations(op, status_key, default_msg)) + elif isinstance(data, dict): + for d in data.values(): + _, reasons = has_status_keys(d, status_key, default_msg) + if reasons is not None: + status_key_occurrences.extend(reasons) + + if len(status_key_occurrences) > 0: + return False, status_key_occurrences + return True, None + + if __name__ == "__main__": print_transform_backends() diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index 834e4866d7..4a8e439f0a 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -55,6 +55,7 @@ SplineMode, StrEnum, TraceKeys, + TraceStatusKeys, TransformBackends, UpsampleMode, Weight, diff --git a/monai/utils/enums.py b/monai/utils/enums.py index a7ea9e29a8..25c747ed90 100644 --- a/monai/utils/enums.py +++ b/monai/utils/enums.py @@ -36,6 +36,7 @@ "SkipMode", "Method", "TraceKeys", + "TraceStatusKeys", "CommonKeys", "GanKeys", "PostFix", @@ -316,7 +317,14 @@ class TraceKeys(StrEnum): KEY_SUFFIX: str = "_transforms" NONE: str = "none" TRACING: str = "tracing" - LAZY_EVALUATION: str = "lazy_evaluation" + STATUSES: str = "statuses" + LAZY: str = "lazy" + + +class TraceStatusKeys(StrEnum): + """Enumerable status keys for the TraceKeys.STATUS flag""" + + PENDING_DURING_APPLY = "pending_during_apply" class CommonKeys(StrEnum): diff --git a/tests/croppers.py b/tests/croppers.py index 156600f202..8c9b43bf0a 100644 --- a/tests/croppers.py +++ b/tests/croppers.py @@ -18,7 +18,7 @@ from monai.data.meta_tensor import MetaTensor from monai.transforms import Randomizable -from monai.transforms.lazy.functional import apply_transforms +from monai.transforms.lazy.functional import apply_pending from monai.transforms.transform import MapTransform from tests.utils import TEST_NDARRAYS_ALL, assert_allclose @@ -117,18 +117,19 @@ def crop_test_pending_ops(self, input_param, input_shape, align_corners=False): expected = result_non_lazy["img"] if is_map else result_non_lazy self.assertIsInstance(expected, MetaTensor) # lazy - crop_fn.lazy_evaluation = True + crop_fn.lazy = True pending_result = crop_fn(input_data) pending_result = pending_result["img"] if is_map else pending_result self.assertIsInstance(pending_result, MetaTensor) assert_allclose(pending_result.peek_pending_affine(), expected.affine) assert_allclose(pending_result.peek_pending_shape(), expected.shape[1:]) # only support nearest - result = apply_transforms(pending_result, mode="nearest", align_corners=align_corners)[0] + overrides = {"mode": "nearest", "align_corners": align_corners} + result = apply_pending(pending_result, overrides=overrides)[0] # compare assert_allclose(result, expected, rtol=1e-5) if isinstance(result, MetaTensor) and not isinstance(crop_fn, MapTransform): - crop_fn.lazy_evaluation = False + crop_fn.lazy = False inverted = crop_fn.inverse(result) self.assertTrue((not inverted.applied_operations) and (not inverted.pending_operations)) self.assertEqual(inverted.shape, im.shape) @@ -155,7 +156,7 @@ def crop_test_combine_ops(self, funcs, input_shape): # lazy pending_result = input_data for _func in _funcs: - _func.lazy_evaluation = True + _func.lazy = True if isinstance(_func, Randomizable): _func.set_random_state(seed=123) 
pending_result = _func(pending_result) @@ -164,7 +165,8 @@ def crop_test_combine_ops(self, funcs, input_shape): assert_allclose(pending_result.peek_pending_affine(), expected.affine) assert_allclose(pending_result.peek_pending_shape(), expected.shape[1:]) # TODO: mode="bilinear" may report error - result = apply_transforms(pending_result, mode="nearest", align_corners=False)[0] + overrides = {"mode": "nearest", "align_corners": False} + result = apply_pending(pending_result, overrides=overrides)[0] # compare assert_allclose(result, expected, rtol=1e-5) diff --git a/tests/lazy_transforms_utils.py b/tests/lazy_transforms_utils.py index f1f8708285..1681e26037 100644 --- a/tests/lazy_transforms_utils.py +++ b/tests/lazy_transforms_utils.py @@ -15,7 +15,7 @@ from monai.data import MetaTensor, set_track_meta from monai.transforms import InvertibleTransform, MapTransform, Randomizable -from monai.transforms.lazy.functional import apply_transforms +from monai.transforms.lazy.functional import apply_pending from tests.utils import assert_allclose apply_transforms_kwargs = ("pending", "mode", "padding_mode", "dtype", "align_corners") @@ -61,7 +61,7 @@ def test_resampler_lazy( if isinstance(resampler, Randomizable): resampler.set_random_state(seed=seed) set_track_meta(True) - resampler.lazy_evaluation = True + resampler.lazy = True pending_output = resampler(**deepcopy(call_param)) if output_idx is not None: expected_output, pending_output = expected_output[output_idx], pending_output[output_idx] @@ -73,7 +73,7 @@ def test_resampler_lazy( if not skip_shape_check: assert_allclose(lazy_out.peek_pending_shape(), non_lazy_out.shape[1:4]) apply_param = get_apply_param(init_param, call_param) - lazy_out = apply_transforms(lazy_out, **apply_param)[0] + lazy_out = apply_pending(lazy_out, overrides=apply_param)[0] assert_allclose(lazy_out, non_lazy_out, rtol=rtol, atol=atol) if ( isinstance(resampler, InvertibleTransform) @@ -82,10 +82,10 @@ def test_resampler_lazy( and isinstance(non_lazy_out, MetaTensor) and non_lazy_out.applied_operations ): - resampler.lazy_evaluation = False + resampler.lazy = False out = resampler.inverse(lazy_out.clone()) ref = resampler.inverse(non_lazy_out.clone()) assert_allclose(out.applied_operations, []) assert_allclose(out.pending_operations, []) assert_allclose(ref, out, type_test=False, rtol=1e-3, atol=1e-3) - resampler.lazy_evaluation = True + resampler.lazy = True diff --git a/tests/padders.py b/tests/padders.py index e21faabc10..02d7b40af6 100644 --- a/tests/padders.py +++ b/tests/padders.py @@ -18,7 +18,7 @@ from monai.data.meta_tensor import MetaTensor from monai.transforms import Compose -from monai.transforms.lazy.functional import apply_transforms +from monai.transforms.lazy.functional import apply_pending from monai.transforms.transform import MapTransform from monai.utils.enums import NumpyPadMode, PytorchPadMode from tests.utils import TEST_NDARRAYS_ALL, assert_allclose @@ -127,18 +127,19 @@ def pad_test_pending_ops(self, input_param, input_shape): expected = result_non_lazy["img"] if is_map else result_non_lazy self.assertIsInstance(expected, MetaTensor) # lazy - pad_fn.lazy_evaluation = True + pad_fn.lazy = True pending_result = pad_fn(input_data) pending_result = pending_result["img"] if is_map else pending_result self.assertIsInstance(pending_result, MetaTensor) assert_allclose(pending_result.peek_pending_affine(), expected.affine) assert_allclose(pending_result.peek_pending_shape(), expected.shape[1:]) # TODO: mode="bilinear" may report error - result = 
apply_transforms(pending_result, mode="nearest", padding_mode=mode[1], align_corners=False)[0] + overrides = {"mode": "nearest", "padding_mode": mode[1], "align_corners": False} + result = apply_pending(pending_result, overrides=overrides)[0] # compare assert_allclose(result, expected, rtol=1e-5) if isinstance(result, MetaTensor) and not isinstance(pad_fn, MapTransform): - pad_fn.lazy_evaluation = False + pad_fn.lazy = False inverted = pad_fn.inverse(result) self.assertTrue((not inverted.pending_operations) and (not inverted.applied_operations)) self.assertEqual(inverted.shape, im.shape) @@ -161,13 +162,14 @@ def pad_test_combine_ops(self, funcs, input_shape, expected_shape): # lazy pending_result = input_data for _func in _funcs: - _func.lazy_evaluation = True + _func.lazy = True pending_result = _func(pending_result) pending_result = pending_result["img"] if is_map else pending_result self.assertIsInstance(pending_result, MetaTensor) assert_allclose(pending_result.peek_pending_affine(), expected.affine) assert_allclose(pending_result.peek_pending_shape(), expected.shape[1:]) # TODO: mode="bilinear" may report error - result = apply_transforms(pending_result, mode="nearest", padding_mode=mode[1], align_corners=False)[0] + overrides = {"mode": "nearest", "padding_mode": mode[1], "align_corners": False} + result = apply_pending(pending_result, overrides=overrides)[0] # compare assert_allclose(result, expected, rtol=1e-5) diff --git a/tests/test_affine.py b/tests/test_affine.py index e8f7f33b17..9c2f4197a6 100644 --- a/tests/test_affine.py +++ b/tests/test_affine.py @@ -20,7 +20,7 @@ from monai.data import MetaTensor, set_track_meta from monai.transforms import Affine, Resize -from monai.transforms.lazy.functional import apply_transforms +from monai.transforms.lazy.functional import apply_pending from monai.utils import optional_import from tests.lazy_transforms_utils import test_resampler_lazy from tests.utils import TEST_NDARRAYS_ALL, assert_allclose, test_local_inversion @@ -208,16 +208,18 @@ def test_affine_resize(self, s): def method_0(im, ac): xform = Affine(align_corners=ac, affine=mat, image_only=True, spatial_size=sp_size) - xform.lazy_evaluation = True + xform.lazy = True out = xform(im) - out = apply_transforms(out, padding_mode="border", align_corners=ac)[0] + overrides = {"padding_mode": "border", "align_corners": ac} + out = apply_pending(out, overrides=overrides)[0] return out def method_1(im, ac): xform = Affine(align_corners=ac, affine=mat, image_only=True, spatial_size=sp_size) - xform.lazy_evaluation = True + xform.lazy = True out = xform(im) - out = apply_transforms(out, mode=1, padding_mode="nearest", align_corners=ac)[0] + overrides = {"mode": 1, "padding_mode": "nearest", "align_corners": ac} + out = apply_pending(out, overrides=overrides)[0] return out def method_2(im, ac): diff --git a/tests/test_apply.py b/tests/test_apply.py index cf74721267..4784d46413 100644 --- a/tests/test_apply.py +++ b/tests/test_apply.py @@ -16,7 +16,7 @@ import numpy as np import torch -from monai.transforms.lazy.functional import apply_transforms +from monai.transforms.lazy.functional import apply_pending from monai.transforms.utils import create_rotate from monai.utils import LazyAttr, convert_to_tensor from tests.utils import get_arange_img @@ -40,20 +40,20 @@ def single_2d_transform_cases(): class TestApply(unittest.TestCase): def _test_apply_impl(self, tensor, pending_transforms, expected_shape): - result = apply_transforms(tensor, pending_transforms) + result = apply_pending(tensor, 
pending_transforms) self.assertListEqual(result[1], pending_transforms) self.assertEqual(result[0].shape, expected_shape) def _test_apply_metatensor_impl(self, tensor, pending_transforms, expected_shape, pending_as_parameter): tensor_ = convert_to_tensor(tensor, track_meta=True) if pending_as_parameter: - result, transforms = apply_transforms(tensor_, pending_transforms) + result, transforms = apply_pending(tensor_, pending_transforms) else: for p in pending_transforms: tensor_.push_pending_operation(p) if not isinstance(p, dict): return - result, transforms = apply_transforms(tensor_) + result, transforms = apply_pending(tensor_) self.assertEqual(result.shape, expected_shape) SINGLE_TRANSFORM_CASES = single_2d_transform_cases() diff --git a/tests/test_compose.py b/tests/test_compose.py index 65b9d8fbfb..892e9a23e7 100644 --- a/tests/test_compose.py +++ b/tests/test_compose.py @@ -11,21 +11,33 @@ from __future__ import annotations +import logging import sys import unittest from copy import deepcopy +from io import StringIO import numpy as np import torch from parameterized import parameterized +import monai.transforms as mt from monai.data import DataLoader, Dataset -from monai.transforms import AddChannel, Compose, Flip, NormalizeIntensity, Rotate, Rotate90, Rotated, Zoom from monai.transforms.compose import execute_compose from monai.transforms.transform import Randomizable from monai.utils import set_determinism +def data_from_keys(keys, h, w): + if keys is None: + data = torch.arange(h * w).reshape(1, h, w) + else: + data = {} + for i_k, k in enumerate(keys): + data[k] = torch.arange(h * w).reshape(1, h, w).mul_(i_k * h * w) + return data + + class _RandXform(Randomizable): def randomize(self): self.val = self.R.random_sample() @@ -37,7 +49,7 @@ def __call__(self, __unused): class TestCompose(unittest.TestCase): def test_empty_compose(self): - c = Compose() + c = mt.Compose() i = 1 self.assertEqual(c(i), 1) @@ -48,7 +60,7 @@ def a(i): def b(i): return i + "b" - c = Compose([a, b, a, b]) + c = mt.Compose([a, b, a, b]) self.assertEqual(c(""), "abab") def test_dict_compose(self): @@ -66,7 +78,7 @@ def b(d): data = {"a": 0, "b": 0} expected = {"a": 3, "b": 2} - self.assertDictEqual(Compose(transforms)(data), expected) + self.assertDictEqual(mt.Compose(transforms)(data), expected) self.assertDictEqual(execute_compose(data, transforms), expected) def test_list_dict_compose(self): @@ -89,7 +101,7 @@ def c(d): # transform to handle dict data transforms = [a, a, b, c, c] data = {"a": 0, "b": 0, "c": 0} expected = {"a": 2, "b": 1, "c": 2} - value = Compose(transforms)(data) + value = mt.Compose(transforms)(data) for item in value: self.assertDictEqual(item, expected) value = execute_compose(data, transforms) @@ -106,7 +118,7 @@ def b(i, i2): transforms = [a, b, a, b] data = ("", "") expected = ("abab", "a2b2a2b2") - self.assertEqual(Compose(transforms, map_items=False, unpack_items=True)(data), expected) + self.assertEqual(mt.Compose(transforms, map_items=False, unpack_items=True)(data), expected) self.assertEqual(execute_compose(data, transforms, map_items=False, unpack_items=True), expected) def test_list_non_dict_compose_with_unpack(self): @@ -119,7 +131,7 @@ def b(i, i2): transforms = [a, b, a, b] data = [("", ""), ("t", "t")] expected = [("abab", "a2b2a2b2"), ("tabab", "ta2b2a2b2")] - self.assertEqual(Compose(transforms, unpack_items=True)(data), expected) + self.assertEqual(mt.Compose(transforms, unpack_items=True)(data), expected) self.assertEqual(execute_compose(data, transforms, 
unpack_items=True), expected) def test_list_dict_compose_no_map(self): @@ -143,7 +155,7 @@ def c(d): # transform to handle dict data transforms = [a, a, b, c, c] data = {"a": 0, "b": 0, "c": 0} expected = {"a": 2, "b": 1, "c": 2} - value = Compose(transforms, map_items=False)(data) + value = mt.Compose(transforms, map_items=False)(data) for item in value: self.assertDictEqual(item, expected) value = execute_compose(data, transforms, map_items=False) @@ -161,7 +173,7 @@ def __call__(self, data): self.randomize() return self.rand + data - c = Compose([_Acc(), _Acc()]) + c = mt.Compose([_Acc(), _Acc()]) self.assertNotAlmostEqual(c(0), c(0)) c.set_random_state(123) self.assertAlmostEqual(c(1), 1.61381597) @@ -177,17 +189,17 @@ def randomize(self, foo1, foo2): def __call__(self, data): pass - c = Compose([_RandomClass(), _RandomClass()]) + c = mt.Compose([_RandomClass(), _RandomClass()]) with self.assertWarns(Warning): c.randomize() def test_err_msg(self): - transforms = Compose([abs, AddChannel(), round], log_stats=False) + transforms = mt.Compose([abs, mt.AddChannel(), round]) with self.assertRaisesRegex(Exception, "AddChannel"): transforms(42.1) def test_data_loader(self): - xform_1 = Compose([_RandXform()]) + xform_1 = mt.Compose([_RandXform()]) train_ds = Dataset([1], transform=xform_1) xform_1.set_random_state(123) @@ -211,7 +223,7 @@ def test_data_loader(self): def test_data_loader_2(self): set_determinism(seed=123) - xform_2 = Compose([_RandXform(), _RandXform()]) + xform_2 = mt.Compose([_RandXform(), _RandXform()]) train_ds = Dataset([1], transform=xform_2) out_2 = train_ds[0] @@ -232,42 +244,37 @@ def test_data_loader_2(self): set_determinism(None) def test_flatten_and_len(self): - x = AddChannel() - t1 = Compose([x, x, x, x, Compose([Compose([x, x]), x, x])]) + x = mt.AddChannel() + t1 = mt.Compose([x, x, x, x, mt.Compose([mt.Compose([x, x]), x, x])]) t2 = t1.flatten() for t in t2.transforms: - self.assertNotIsInstance(t, Compose) + self.assertNotIsInstance(t, mt.Compose) # test len self.assertEqual(len(t1), 8) def test_backwards_compatible_imports(self): - from monai.transforms.compose import MapTransform, RandomizableTransform, Transform # noqa: F401 + from monai.transforms.transform import MapTransform, RandomizableTransform, Transform # noqa: F401 TEST_COMPOSE_EXECUTE_TEST_CASES = [ [None, tuple()], - [None, (Rotate(np.pi / 8),)], - [None, (Flip(0), Flip(1), Rotate90(1), Zoom(0.8), NormalizeIntensity())], - [("a",), (Rotated(("a",), np.pi / 8),)], + [None, (mt.Rotate(np.pi / 8),)], + [None, (mt.Flip(0), mt.Flip(1), mt.Rotate90(1), mt.Zoom(0.8), mt.NormalizeIntensity())], + [("a",), (mt.Rotated(("a",), np.pi / 8),)], ] class TestComposeExecute(unittest.TestCase): @parameterized.expand(TEST_COMPOSE_EXECUTE_TEST_CASES) def test_compose_execute_equivalence(self, keys, pipeline): - if keys is None: - data = torch.unsqueeze(torch.tensor(np.arange(24 * 32).reshape(24, 32)), axis=0) - else: - data = {} - for i_k, k in enumerate(keys): - data[k] = torch.unsqueeze(torch.tensor(np.arange(24 * 32)).reshape(24, 32) + i_k * 768, axis=0) + data = data_from_keys(keys, 12, 16) - expected = Compose(deepcopy(pipeline))(data) + expected = mt.Compose(deepcopy(pipeline))(data) for cutoff in range(len(pipeline)): - c = Compose(deepcopy(pipeline)) + c = mt.Compose(deepcopy(pipeline)) actual = c(c(data, end=cutoff), start=cutoff) if isinstance(actual, dict): for k in actual.keys(): @@ -283,6 +290,302 @@ def test_compose_execute_equivalence(self, keys, pipeline): else: 
self.assertTrue(torch.allclose(expected, actual)) + @parameterized.expand(TEST_COMPOSE_EXECUTE_TEST_CASES) + def test_compose_execute_bad_start_param(self, keys, pipeline): + data = data_from_keys(keys, 12, 16) + + c = mt.Compose(deepcopy(pipeline)) + with self.assertRaises(ValueError): + c(data, start=None) + with self.assertRaises(ValueError): + c(data, start=None) + + with self.assertRaises(ValueError): + execute_compose(data, deepcopy(pipeline), start=None) + + c = mt.Compose(deepcopy(pipeline)) + with self.assertRaises(ValueError): + c(data, start=-1) + with self.assertRaises(ValueError): + c(data, start=-1) + + with self.assertRaises(ValueError): + execute_compose(data, deepcopy(pipeline), start=-1) + + @parameterized.expand(TEST_COMPOSE_EXECUTE_TEST_CASES) + def test_compose_execute_negative_range(self, keys, pipeline): + data = data_from_keys(keys, 12, 16) + + with self.assertRaises(ValueError): + c = mt.Compose(deepcopy(pipeline)) + c(data, start=2, end=1) + + with self.assertRaises(ValueError): + execute_compose(data, deepcopy(pipeline), start=2, end=1) + + @parameterized.expand(TEST_COMPOSE_EXECUTE_TEST_CASES) + def test_compose_execute_bad_end_param(self, keys, pipeline): + data = data_from_keys(keys, 12, 16) + + with self.assertRaises(ValueError): + c = mt.Compose(deepcopy(pipeline)) + c(data, end=len(pipeline) + 1) + + with self.assertRaises(ValueError): + execute_compose(data, deepcopy(pipeline), end=len(pipeline) + 1) + + @parameterized.expand(TEST_COMPOSE_EXECUTE_TEST_CASES) + def test_compose_execute_empty_range(self, keys, pipeline): + data = data_from_keys(keys, 12, 16) + + c = mt.Compose(deepcopy(pipeline)) + for i in range(len(pipeline)): + result = c(data, start=i, end=i) + self.assertIs(data, result) + + @parameterized.expand(TEST_COMPOSE_EXECUTE_TEST_CASES) + def test_compose_with_logger(self, keys, pipeline): + data = data_from_keys(keys, 12, 16) + + c = mt.Compose(deepcopy(pipeline), log_stats="a_logger_name") + c(data) + + +TEST_COMPOSE_EXECUTE_LOGGING_TEST_CASES = [ + [ + None, + (mt.Flip(0), mt.Spacing((1.2, 1.2)), mt.Flip(1), mt.Rotate90(1), mt.Zoom(0.8), mt.NormalizeIntensity()), + False, + ( + "INFO - Apply pending transforms - lazy: False, pending: 0, " + "upcoming 'Flip', transform.lazy: False\n" + "INFO - Apply pending transforms - lazy: False, pending: 0, " + "upcoming 'Spacing', transform.lazy: False\n" + "INFO - Apply pending transforms - lazy: False, pending: 0, " + "upcoming 'Flip', transform.lazy: False\n" + "INFO - Apply pending transforms - lazy: False, pending: 0, " + "upcoming 'Rotate90', transform.lazy: False\n" + "INFO - Apply pending transforms - lazy: False, pending: 0, " + "upcoming 'Zoom', transform.lazy: False\n" + "INFO - Apply pending transforms - lazy: False, pending: 0, " + "upcoming 'NormalizeIntensity', transform is not lazy\n" + ), + ], + [ + None, + ( + mt.Flip(0, lazy=True), + mt.Spacing((1.2, 1.2), lazy=True), + mt.Flip(1, lazy=True), + mt.Rotate90(1), + mt.Zoom(0.8, lazy=True), + mt.NormalizeIntensity(), + ), + None, + ( + "INFO - Accumulate pending transforms - lazy: None, pending: 0, " + "upcoming 'Flip', transform.lazy: True\n" + "INFO - Accumulate pending transforms - lazy: None, pending: 1, " + "upcoming 'Spacing', transform.lazy: True\n" + "INFO - Accumulate pending transforms - lazy: None, pending: 2, " + "upcoming 'Flip', transform.lazy: True\n" + "INFO - Apply pending transforms - lazy: None, pending: 3, " + "upcoming 'Rotate90', transform.lazy: False\n" + "INFO - Pending transforms applied: applied_operations: 3\n" + 
"INFO - Accumulate pending transforms - lazy: None, pending: 0, " + "upcoming 'Zoom', transform.lazy: True\n" + "INFO - Apply pending transforms - lazy: None, pending: 1, " + "upcoming 'NormalizeIntensity', transform is not lazy\n" + "INFO - Pending transforms applied: applied_operations: 5\n" + ), + ], + [ + None, + (mt.Flip(0), mt.Spacing((1.2, 1.2)), mt.Flip(1), mt.Rotate90(1), mt.Zoom(0.8), mt.NormalizeIntensity()), + True, + ( + "INFO - Accumulate pending transforms - lazy: True, pending: 0, " + "upcoming 'Flip', transform.lazy: False (overridden)\n" + "INFO - Accumulate pending transforms - lazy: True, pending: 1, " + "upcoming 'Spacing', transform.lazy: False (overridden)\n" + "INFO - Accumulate pending transforms - lazy: True, pending: 2, " + "upcoming 'Flip', transform.lazy: False (overridden)\n" + "INFO - Accumulate pending transforms - lazy: True, pending: 3, " + "upcoming 'Rotate90', transform.lazy: False (overridden)\n" + "INFO - Accumulate pending transforms - lazy: True, pending: 4, " + "upcoming 'Zoom', transform.lazy: False (overridden)\n" + "INFO - Apply pending transforms - lazy: True, pending: 5, " + "upcoming 'NormalizeIntensity', transform is not lazy\n" + "INFO - Pending transforms applied: applied_operations: 5\n" + ), + ], + [ + ("a", "b"), + ( + mt.Flipd(("a", "b"), 0), + mt.Spacingd(("a", "b"), 1.2), + mt.Rotate90d(("a", "b"), 1), + mt.NormalizeIntensityd(("a",)), + ), + True, + ( + "INFO - Accumulate pending transforms - lazy: True, key: 'a', pending: 0, " + "upcoming 'Flipd', transform.lazy: False (overridden)\n" + "INFO - Accumulate pending transforms - lazy: True, key: 'b', pending: 0, " + "upcoming 'Flipd', transform.lazy: False (overridden)\n" + "INFO - Accumulate pending transforms - lazy: True, key: 'a', pending: 1, " + "upcoming 'Spacingd', transform.lazy: False (overridden)\n" + "INFO - Accumulate pending transforms - lazy: True, key: 'b', pending: 1, " + "upcoming 'Spacingd', transform.lazy: False (overridden)\n" + "INFO - Accumulate pending transforms - lazy: True, key: 'a', pending: 2, " + "upcoming 'Rotate90d', transform.lazy: False (overridden)\n" + "INFO - Accumulate pending transforms - lazy: True, key: 'b', pending: 2, " + "upcoming 'Rotate90d', transform.lazy: False (overridden)\n" + "INFO - Apply pending transforms - lazy: True, key: 'a', pending: 3, " + "upcoming 'NormalizeIntensityd', transform is not lazy\n" + "INFO - Pending transforms applied: key: 'a', applied_operations: 3\n" + "INFO - Pending transforms applied: key: 'b', applied_operations: 3\n" + ), + ], + [ + ("a", "b"), + ( + mt.Flipd(keys="a", spatial_axis=0), + mt.Rotate90d(keys="b", k=1, allow_missing_keys=True), + mt.Zoomd(keys=("a", "b"), zoom=0.8, allow_missing_keys=True), + mt.Spacingd(keys="a", pixdim=1.2), + ), + True, + ( + "INFO - Accumulate pending transforms - lazy: True, key: 'a', pending: 0, " + "upcoming 'Flipd', transform.lazy: False (overridden)\n" + "INFO - Accumulate pending transforms - lazy: True, key: 'b', pending: 0, " + "upcoming 'Rotate90d', transform.lazy: False (overridden)\n" + "INFO - Accumulate pending transforms - lazy: True, key: 'a', pending: 1, " + "upcoming 'Zoomd', transform.lazy: False (overridden)\n" + "INFO - Accumulate pending transforms - lazy: True, key: 'b', pending: 1, " + "upcoming 'Zoomd', transform.lazy: False (overridden)\n" + "INFO - Accumulate pending transforms - lazy: True, key: 'a', pending: 2, " + "upcoming 'Spacingd', transform.lazy: False (overridden)\n" + "INFO - Pending transforms applied: key: 'a', applied_operations: 3\n" 
+ "INFO - Pending transforms applied: key: 'b', applied_operations: 2\n" + ), + ], + [ + None, + ( + mt.Flip(0), + mt.Spacing((1.2, 1.2)), + mt.Flip(1), + mt.ApplyPending(), + mt.Rotate90(1), + mt.Zoom(0.8), + mt.NormalizeIntensity(), + ), + False, + ( + "INFO - Apply pending transforms - lazy: False, pending: 0, " + "upcoming 'Flip', transform.lazy: False\n" + "INFO - Apply pending transforms - lazy: False, pending: 0, " + "upcoming 'Spacing', transform.lazy: False\n" + "INFO - Apply pending transforms - lazy: False, pending: 0, " + "upcoming 'Flip', transform.lazy: False\n" + "INFO - Apply pending transforms - lazy: False, pending: 0, " + "upcoming 'ApplyPending', transform is not lazy\n" + "INFO - Apply pending transforms - lazy: False, pending: 0, " + "upcoming 'Rotate90', transform.lazy: False\n" + "INFO - Apply pending transforms - lazy: False, pending: 0, " + "upcoming 'Zoom', transform.lazy: False\n" + "INFO - Apply pending transforms - lazy: False, pending: 0, " + "upcoming 'NormalizeIntensity', transform is not lazy\n" + ), + ], + [ + None, + ( + mt.Flip(0), + mt.Spacing((1.2, 1.2)), + mt.Flip(1), + mt.ApplyPending(), + mt.Rotate90(1), + mt.Zoom(0.8), + mt.NormalizeIntensity(), + ), + True, + ( + "INFO - Accumulate pending transforms - lazy: True, pending: 0, " + "upcoming 'Flip', transform.lazy: False (overridden)\n" + "INFO - Accumulate pending transforms - lazy: True, pending: 1, " + "upcoming 'Spacing', transform.lazy: False (overridden)\n" + "INFO - Accumulate pending transforms - lazy: True, pending: 2, " + "upcoming 'Flip', transform.lazy: False (overridden)\n" + "INFO - Apply pending transforms - lazy: True, pending: 3, " + "upcoming 'ApplyPending', transform is not lazy\n" + "INFO - Pending transforms applied: applied_operations: 3\n" + "INFO - Accumulate pending transforms - lazy: True, pending: 0, " + "upcoming 'Rotate90', transform.lazy: False (overridden)\n" + "INFO - Accumulate pending transforms - lazy: True, pending: 1, " + "upcoming 'Zoom', transform.lazy: False (overridden)\n" + "INFO - Apply pending transforms - lazy: True, pending: 2, " + "upcoming 'NormalizeIntensity', transform is not lazy\n" + "INFO - Pending transforms applied: applied_operations: 5\n" + ), + ], + [ + ("a", "b"), + ( + mt.Flipd(keys="a", spatial_axis=0), + mt.Rotate90d(keys="b", k=1, allow_missing_keys=True), + mt.ApplyPendingd(keys=("a", "b")), + mt.Zoomd(keys=("a", "b"), zoom=0.8, allow_missing_keys=True), + mt.Spacingd(keys="a", pixdim=1.2), + ), + True, + ( + "INFO - Accumulate pending transforms - lazy: True, key: 'a', pending: 0, " + "upcoming 'Flipd', transform.lazy: False (overridden)\n" + "INFO - Accumulate pending transforms - lazy: True, key: 'b', pending: 0, " + "upcoming 'Rotate90d', transform.lazy: False (overridden)\n" + "INFO - Apply pending transforms - lazy: True, key: 'a', pending: 1, " + "upcoming 'ApplyPendingd', transform is not lazy\n" + "INFO - Apply pending transforms - lazy: True, key: 'b', pending: 1, " + "upcoming 'ApplyPendingd', transform is not lazy\n" + "INFO - Pending transforms applied: key: 'a', applied_operations: 1\n" + "INFO - Pending transforms applied: key: 'b', applied_operations: 1\n" + "INFO - Accumulate pending transforms - lazy: True, key: 'a', pending: 0, " + "upcoming 'Zoomd', transform.lazy: False (overridden)\n" + "INFO - Accumulate pending transforms - lazy: True, key: 'b', pending: 0, " + "upcoming 'Zoomd', transform.lazy: False (overridden)\n" + "INFO - Accumulate pending transforms - lazy: True, key: 'a', pending: 1, " + "upcoming 
'Spacingd', transform.lazy: False (overridden)\n" + "INFO - Pending transforms applied: key: 'a', applied_operations: 3\n" + "INFO - Pending transforms applied: key: 'b', applied_operations: 2\n" + ), + ], +] + + +class TestComposeExecuteWithLogging(unittest.TestCase): + @parameterized.expand(TEST_COMPOSE_EXECUTE_LOGGING_TEST_CASES) + def test_compose_with_logging(self, keys, pipeline, lazy, expected): + stream = StringIO() + handler = logging.StreamHandler(stream) + formatter = logging.Formatter("%(levelname)s - %(message)s") + handler.setFormatter(formatter) + logger = logging.getLogger("a_logger_name") + logger.setLevel(logging.INFO) + while len(logger.handlers) > 0: + logger.removeHandler(logger.handlers[-1]) + logger.addHandler(handler) + + data = data_from_keys(keys, 12, 16) + c = mt.Compose(deepcopy(pipeline), lazy=lazy, log_stats="a_logger_name") + c(data) + + handler.flush() + actual = stream.getvalue() + self.assertEqual(actual, expected) + class TestOps: @staticmethod @@ -318,10 +621,10 @@ def _inner(data1, data2): class TestComposeExecuteWithFlags(unittest.TestCase): @parameterized.expand(TEST_COMPOSE_EXECUTE_FLAG_TEST_CASES) def test_compose_execute_equivalence_with_flags(self, flags, data, pipeline): - expected = Compose(pipeline, **flags)(data) + expected = mt.Compose(pipeline, **flags)(data) for cutoff in range(len(pipeline)): - c = Compose(deepcopy(pipeline), **flags) + c = mt.Compose(deepcopy(pipeline), **flags) actual = c(c(data, end=cutoff), start=cutoff) if isinstance(actual, dict): for k in actual.keys(): @@ -338,16 +641,5 @@ def test_compose_execute_equivalence_with_flags(self, flags, data, pipeline): self.assertTrue(expected, actual) -TEST_LAZY_COMPOSE_PIPELINE_FIX_CASES = [[(Flip(0), Flip(1), Rotate90(1), Zoom(0.8), NormalizeIntensity())]] - - -class TestLazyComposePipelineFixes(unittest.TestCase): - @parameterized.expand(TEST_LAZY_COMPOSE_PIPELINE_FIX_CASES) - def test_lazy_compose_pipeline_fixes(self, pipeline): - data = torch.unsqueeze(torch.tensor(np.arange(12 * 16).reshape(12, 16)), dim=0) - c = Compose(deepcopy(pipeline), lazy_evaluation=True) - _ = c(data) - - if __name__ == "__main__": unittest.main() diff --git a/tests/test_crop_foreground.py b/tests/test_crop_foreground.py index 1ffdc9983e..4435b128ba 100644 --- a/tests/test_crop_foreground.py +++ b/tests/test_crop_foreground.py @@ -19,7 +19,7 @@ from monai.config import USE_COMPILED from monai.data.meta_tensor import MetaTensor from monai.transforms import CropForeground -from monai.transforms.lazy.functional import apply_transforms +from monai.transforms.lazy.functional import apply_pending from tests.utils import TEST_NDARRAYS_ALL, assert_allclose TEST_COORDS, TESTS, TEST_LAZY_ERROR = [], [], [] @@ -126,13 +126,14 @@ def test_pending_ops(self, input_param, image, _expected_data, align_corners): expected = crop_fn(image) self.assertIsInstance(expected, MetaTensor) # lazy - crop_fn.lazy_evaluation = True + crop_fn.lazy = True pending_result = crop_fn(image) self.assertIsInstance(pending_result, MetaTensor) assert_allclose(pending_result.peek_pending_affine(), expected.affine) assert_allclose(pending_result.peek_pending_shape(), expected.shape[1:]) # only support nearest - result = apply_transforms(pending_result, mode="nearest", align_corners=align_corners)[0] + overrides = {"mode": "nearest", "align_corners": align_corners} + result = apply_pending(pending_result, overrides=overrides)[0] # compare assert_allclose(result, expected, rtol=1e-5) @@ -142,17 +143,18 @@ def test_lazy_error(self, input_param, 
image, _expected_data, align_corners): with self.assertRaises(ValueError): crop_fn = CropForeground(**input_param) # lazy - crop_fn.lazy_evaluation = True + crop_fn.lazy = True pending_result = crop_fn(image) - return apply_transforms(pending_result, mode="nearest", align_corners=align_corners)[0] + overrides = {"mode": "nearest", "align_corners": align_corners} + return apply_pending(pending_result, overrides=overrides)[0] @parameterized.expand(TEST_COORDS + TESTS) def test_inverse_pending_ops(self, input_param, image, _expected_data, align_corners): crop_fn = CropForeground(**input_param) - crop_fn.lazy_evaluation = True + crop_fn.lazy = True pending_result = crop_fn(image) self.assertIsInstance(pending_result, MetaTensor) - result = apply_transforms(pending_result, mode="nearest", align_corners=align_corners)[0] + result = apply_pending(pending_result, overrides={"mode": "nearest", "align_corners": align_corners})[0] inverted = crop_fn.inverse(result) self.assertEqual(image.shape, inverted.shape) self.assertTrue((not inverted.applied_operations) and (not inverted.pending_operations)) diff --git a/tests/test_crop_foregroundd.py b/tests/test_crop_foregroundd.py index d2604ef9cf..776776f6c5 100644 --- a/tests/test_crop_foregroundd.py +++ b/tests/test_crop_foregroundd.py @@ -18,7 +18,7 @@ from monai.data.meta_tensor import MetaTensor from monai.transforms import CropForegroundd -from monai.transforms.lazy.functional import apply_transforms +from monai.transforms.lazy.functional import apply_pending from tests.utils import TEST_NDARRAYS_ALL, assert_allclose TEST_POSITION, TESTS = [], [] @@ -189,13 +189,14 @@ def test_pending_ops(self, input_param, image, _expected_data, align_corners): expected = crop_fn(image)["img"] self.assertIsInstance(expected, MetaTensor) # lazy - crop_fn.lazy_evaluation = True + crop_fn.lazy = True pending_result = crop_fn(image)["img"] self.assertIsInstance(pending_result, MetaTensor) assert_allclose(pending_result.peek_pending_affine(), expected.affine) assert_allclose(pending_result.peek_pending_shape(), expected.shape[1:]) # only support nearest - result = apply_transforms(pending_result, mode="nearest", align_corners=align_corners)[0] + overrides = {"mode": "nearest", "align_corners": align_corners} + result = apply_pending(pending_result, overrides=overrides)[0] # compare assert_allclose(result, expected, rtol=1e-5) diff --git a/tests/test_flip.py b/tests/test_flip.py index 287852c2c1..d7df55fde0 100644 --- a/tests/test_flip.py +++ b/tests/test_flip.py @@ -61,7 +61,7 @@ def test_torch(self, spatial_axis, img: torch.Tensor, track_meta: bool, device): init_param = {"spatial_axis": spatial_axis} xform = Flip(**init_param) call_param = {"img": img} - res = xform(**call_param) + res = xform(**call_param) # type: ignore[arg-type] self.assertEqual(img.shape, res.shape) if track_meta: test_resampler_lazy(xform, res, init_param, call_param) diff --git a/tests/test_integration_lazy_samples.py b/tests/test_integration_lazy_samples.py index 88e256f2dc..c365616bc8 100644 --- a/tests/test_integration_lazy_samples.py +++ b/tests/test_integration_lazy_samples.py @@ -24,7 +24,8 @@ import monai import monai.transforms as mt from monai.data import create_test_image_3d, decollate_batch -from monai.utils import set_determinism +from monai.transforms.utils import has_status_keys +from monai.utils import TraceStatusKeys, set_determinism from tests.utils import HAS_CUPY, DistTestCase, SkipIfBeforePyTorchVersion, skip_if_quick @@ -41,9 +42,10 @@ def run_training_test(root_dir, 
device="cuda:0", cachedataset=0, readers=(None, num_workers = 0 if torch.cuda.is_available() else num_workers # define transforms for image and segmentation - lazy_kwargs = dict( - mode=("bilinear", 0), device=device, padding_mode=("border", "nearest"), dtype=(torch.float32, torch.uint8) - ) + lazy_kwargs = { + "img": {"mode": "bilinear", "device": device, "padding_mode": "border", "dtype": torch.float32}, + "seg": {"mode": 0, "device": device, "padding_mode": "nearest", "dtype": torch.uint8}, + } train_transforms = mt.Compose( [ mt.LoadImaged(keys=["img", "seg"], reader=readers[0], image_only=True), @@ -58,7 +60,7 @@ def run_training_test(root_dir, device="cuda:0", cachedataset=0, readers=(None, mt.Orientationd(keys=["img", "seg"], axcodes="ARS"), mt.RandRotate90d(keys=["img", "seg"], prob=1.0, spatial_axes=(1, 2)), mt.ScaleIntensityd(keys="img"), - mt.IdentityD(keys=["seg"]), + mt.ApplyPendingd(keys=["seg"]), mt.RandCropByPosNegLabeld( keys=["img", "seg"], label_key="seg", spatial_size=[76, 82, 80], pos=1, neg=1, num_samples=4 ), @@ -70,10 +72,9 @@ def run_training_test(root_dir, device="cuda:0", cachedataset=0, readers=(None, mt.Rotated(keys=["img", "seg"], angle=[np.pi / 2, np.pi / 2, 0], mode="nearest", keep_size=False), mt.Lambdad(keys=["img"], func=_no_op), ], - lazy_evaluation=lazy, + lazy=lazy, overrides=lazy_kwargs, - override_keys=("img", "seg"), - verbose=num_workers > 0, # testing both flags + log_stats=num_workers > 0, ) # create a training data loader @@ -115,6 +116,7 @@ def run_training_test(root_dir, device="cuda:0", cachedataset=0, readers=(None, train_ds, batch_size=1, shuffle=True, num_workers=num_workers, generator=_g, persistent_workers=num_workers > 0 ) all_coords = set() + batch_data = None for epoch in range(5): print("-" * 10) print(f"Epoch {epoch + 1}/5") @@ -149,16 +151,9 @@ def run_training_test(root_dir, device="cuda:0", cachedataset=0, readers=(None, saver(item) # just testing the saving saver(in_img) saver(in_seg) - if lazy: - inverted = 0 - try: - inverted = [inverter(b_data) for b_data in decollate_batch(batch_data)] - except RuntimeError as e: - if "Lambda" in str(e): - inverted = None - assert inverted is None, "invert LambdaD + lazy is not supported" - else: - [inverter(b_data) for b_data in decollate_batch(batch_data)] # expecting no error + invertible, reasons = has_status_keys(batch_data, TraceStatusKeys.PENDING_DURING_APPLY) + inverted = [inverter(b_data) for b_data in decollate_batch(batch_data)] # expecting no error + return ops @@ -193,12 +188,13 @@ def train_and_infer(self, idx=0): elif idx == 2: _readers = ("itkreader", "nibabelreader") _w = 0 - results = run_training_test( - self.data_dir, device=self.device, cachedataset=idx, readers=_readers, num_workers=_w, lazy=True - ) + results_expected = run_training_test( self.data_dir, device=self.device, cachedataset=0, readers=_readers, num_workers=_w, lazy=False ) + results = run_training_test( + self.data_dir, device=self.device, cachedataset=idx, readers=_readers, num_workers=_w, lazy=True + ) self.assertFalse(np.allclose(results, [0])) self.assertFalse(np.allclose(results_expected, [0])) np.testing.assert_allclose(results, results_expected) @@ -213,7 +209,6 @@ def train_and_infer(self, idx=0): diffs.append(diff_rate) np.testing.assert_allclose(diff_rate, 0.0, atol=0.03) print("volume diff:", diffs) - return results def test_training(self): for i in range(4): diff --git a/tests/test_invert.py b/tests/test_invert.py index 0d53b4bf61..b7c11362ce 100644 --- a/tests/test_invert.py +++ 
b/tests/test_invert.py @@ -90,16 +90,15 @@ def test_invert(self): set_determinism(seed=None) def test_invert_warn_pending(self): + # this test shouldn't raise a warning or error any more as that issue was fixed + # by https://github.com/Project-MONAI/MONAI/pull/6257 set_determinism(seed=0) im_fname = make_nifti_image(create_test_image_3d(101, 100, 107, noise_max=100)[1]) # label image, discrete transform = Compose( - [LoadImage(image_only=True), EnsureChannelFirst(), Orientation("RPS"), Lambda(func=lambda x: x)], - lazy_evaluation=True, + [LoadImage(image_only=True), EnsureChannelFirst(), Orientation("RPS"), Lambda(func=lambda x: x)], lazy=True ) output = transform([im_fname for _ in range(2)]) - with self.assertRaises(RuntimeError): # transform id mismatch because of lambda - with self.assertWarns(Warning): # warning of wrong ordering lazy + nonlazy_invertible - transform.inverse(output) + transform.inverse(output) if __name__ == "__main__": diff --git a/tests/test_nvtx_decorator.py b/tests/test_nvtx_decorator.py index 1e7bea17d4..574fd49592 100644 --- a/tests/test_nvtx_decorator.py +++ b/tests/test_nvtx_decorator.py @@ -62,14 +62,7 @@ ] TEST_CASE_RECURSIVE_2 = [ torch.randn(3, 3), - Compose( - [ - ToNumpy(), - Flip(), - OneOf([RandAdjustContrast(prob=0.0), RandFlip(prob=1.0)], weights=[0, 1], log_stats=True), - ToTensor(), - ] - ), + Compose([ToNumpy(), Flip(), OneOf([RandAdjustContrast(prob=0.0), RandFlip(prob=1.0)], weights=[0, 1]), ToTensor()]), ] TEST_CASE_RECURSIVE_LIST = [ torch.randn(3, 3), @@ -167,7 +160,6 @@ def test_recursive_tranforms(self, input, transforms): # Check the outputs self.assertEqual(transforms.map_items, transforms_range.map_items) self.assertEqual(transforms.unpack_items, transforms_range.unpack_items) - self.assertEqual(transforms.log_stats, transforms_range.log_stats) np.testing.assert_equal(output.numpy(), output_r.numpy()) @parameterized.expand([TEST_CASE_RECURSIVE_LIST]) diff --git a/tests/test_one_of.py b/tests/test_one_of.py index 36980c23a7..2909597507 100644 --- a/tests/test_one_of.py +++ b/tests/test_one_of.py @@ -15,8 +15,12 @@ from copy import deepcopy import numpy as np +import torch from parameterized import parameterized +import monai.transforms.intensity.array as ia +import monai.transforms.spatial.array as sa +import monai.transforms.spatial.dictionary as sd from monai.data import MetaTensor from monai.transforms import ( InvertibleTransform, @@ -227,5 +231,41 @@ def test_one_of(self): self.assertAlmostEqual(counts[2] / 10000, 0.25, delta=1.0) +TEST_ONEOF_EXTENDED_TEST_CASES = [ + [None, tuple()], + [None, (sa.Rotate(np.pi / 8),)], + [None, (sa.Flip(0), sa.Flip(1), sa.Rotate90(1), sa.Zoom(0.8), ia.NormalizeIntensity())], + [("a",), (sd.Rotated(("a",), np.pi / 8),)], +] + + +class TestOneOfAPITests(unittest.TestCase): + @staticmethod + def data_from_keys(keys): + if keys is None: + data = torch.unsqueeze(torch.tensor(np.arange(12 * 16).reshape(12, 16)), dim=0) + else: + data = {} + for i_k, k in enumerate(keys): + data[k] = torch.unsqueeze(torch.tensor(np.arange(12 * 16)).reshape(12, 16) + i_k * 192, dim=0) + return data + + @parameterized.expand(TEST_ONEOF_EXTENDED_TEST_CASES) + def test_execute_change_start_end(self, keys, pipeline): + data = self.data_from_keys(keys) + + c = OneOf(deepcopy(pipeline)) + with self.assertRaises(ValueError): + c(data, start=1) + with self.assertRaises(ValueError): + c(data, start=1) + + c = OneOf(deepcopy(pipeline)) + with self.assertRaises(ValueError): + c(data, end=1) + with self.assertRaises(ValueError): + 
c(data, end=1) + + if __name__ == "__main__": unittest.main() diff --git a/tests/test_orientation.py b/tests/test_orientation.py index 6e89d085d2..aa1c326bdf 100644 --- a/tests/test_orientation.py +++ b/tests/test_orientation.py @@ -190,7 +190,7 @@ def test_ornt_meta( img = MetaTensor(img, affine=affine).to(device) ornt = Orientation(**init_param) call_param = {"data_array": img} - res = ornt(**call_param) + res = ornt(**call_param) # type: ignore[arg-type] if img.ndim in (3, 4): test_resampler_lazy(ornt, res, init_param, call_param) diff --git a/tests/test_orientationd.py b/tests/test_orientationd.py index ddb5dc3e98..cf4eb23d42 100644 --- a/tests/test_orientationd.py +++ b/tests/test_orientationd.py @@ -74,7 +74,7 @@ def test_orntd( img = MetaTensor(img, affine=affine) img = img.to(device) call_param = {"data": {k: img.clone() for k in ornt.keys}} - res = ornt(**call_param) + res = ornt(**call_param) # type: ignore[arg-type] for k in ornt.keys: if img.ndim in (3, 4): test_resampler_lazy(ornt, res, init_param, call_param, output_key=k) @@ -92,7 +92,7 @@ def test_orntd_torch(self, init_param, img: torch.Tensor, track_meta: bool, devi expected_shape = img.shape expected_code = ornt.ornt_transform.axcodes call_param = {"data": {k: img.clone() for k in ornt.keys}} - res = ornt(**call_param) + res = ornt(**call_param) # type: ignore[arg-type] for k in ornt.keys: _im = res[k] np.testing.assert_allclose(_im.shape, expected_shape) diff --git a/tests/test_rand_affined.py b/tests/test_rand_affined.py index 5c1e2359e8..a607029c1a 100644 --- a/tests/test_rand_affined.py +++ b/tests/test_rand_affined.py @@ -234,7 +234,7 @@ def test_rand_affined(self, input_param, input_data, expected_val, track_meta): resampler = RandAffined(**lazy_init_param).set_random_state(123) expected_output = resampler(**call_param) test_resampler_lazy(resampler, expected_output, lazy_init_param, call_param, seed=123, output_key=key) - resampler.lazy_evaluation = False + resampler.lazy = False if input_param.get("cache_grid", False): self.assertTrue(g.rand_affine._cached_grid is not None) diff --git a/tests/test_rand_axis_flip.py b/tests/test_rand_axis_flip.py index 457617fc19..81e42372db 100644 --- a/tests/test_rand_axis_flip.py +++ b/tests/test_rand_axis_flip.py @@ -33,7 +33,7 @@ def test_correct_results(self): # test lazy test_resampler_lazy(flip, result, call_param=call_param, seed=321) - flip.lazy_evaluation = False + flip.lazy = False expected = [np.flip(channel, flip._axis) for channel in self.imt[0]] assert_allclose(result, p(np.stack(expected)), type_test="tensor") diff --git a/tests/test_rand_axis_flipd.py b/tests/test_rand_axis_flipd.py index e6fac5637f..75357b23e1 100644 --- a/tests/test_rand_axis_flipd.py +++ b/tests/test_rand_axis_flipd.py @@ -33,7 +33,7 @@ def test_correct_results(self): # test lazy test_resampler_lazy(flip, result, call_param=call_param, output_key="img", seed=1234) - flip.lazy_evaluation = False + flip.lazy = False test_local_inversion(flip, result, {"img": im}, "img") expected = [np.flip(channel, flip.flipper._axis) for channel in self.imt[0]] diff --git a/tests/test_rand_crop_by_label_classes.py b/tests/test_rand_crop_by_label_classes.py index 6723dfc4c6..88d2631ca5 100644 --- a/tests/test_rand_crop_by_label_classes.py +++ b/tests/test_rand_crop_by_label_classes.py @@ -18,7 +18,7 @@ from monai.data.meta_tensor import MetaTensor from monai.transforms import ClassesToIndices, RandCropByLabelClasses -from monai.transforms.lazy.functional import apply_transforms +from 
monai.transforms.lazy.functional import apply_pending from tests.utils import TEST_NDARRAYS_ALL, assert_allclose TESTS_INDICES, TESTS_SHAPE = [], [] @@ -154,14 +154,14 @@ def test_pending_ops(self, input_param, input_data, _expected_type, _expected_sh self.assertIsInstance(expected[0], MetaTensor) # lazy cropper.set_random_state(0) - cropper.lazy_evaluation = True + cropper.lazy = True pending_result = cropper(**input_data) for i, _pending_result in enumerate(pending_result): self.assertIsInstance(_pending_result, MetaTensor) assert_allclose(_pending_result.peek_pending_affine(), expected[i].affine) assert_allclose(_pending_result.peek_pending_shape(), expected[i].shape[1:]) # only support nearest - result = apply_transforms(_pending_result, mode="nearest", align_corners=False)[0] + result = apply_pending(_pending_result, overrides={"mode": "nearest", "align_corners": False})[0] # compare assert_allclose(result, expected[i], rtol=1e-5) diff --git a/tests/test_rand_crop_by_label_classesd.py b/tests/test_rand_crop_by_label_classesd.py index 8af1df5c42..748f26f1ff 100644 --- a/tests/test_rand_crop_by_label_classesd.py +++ b/tests/test_rand_crop_by_label_classesd.py @@ -18,7 +18,7 @@ from monai.data.meta_tensor import MetaTensor from monai.transforms import ClassesToIndicesd, RandCropByLabelClassesd -from monai.transforms.lazy.functional import apply_transforms +from monai.transforms.lazy.functional import apply_pending from tests.utils import TEST_NDARRAYS_ALL, assert_allclose TESTS = [] @@ -143,14 +143,14 @@ def test_pending_ops(self, input_param, input_data, _expected_type, _expected_sh self.assertIsInstance(expected[0]["img"], MetaTensor) # lazy cropper.set_random_state(0) - cropper.lazy_evaluation = True + cropper.lazy = True pending_result = cropper(input_data) for i, _pending_result in enumerate(pending_result): self.assertIsInstance(_pending_result["img"], MetaTensor) assert_allclose(_pending_result["img"].peek_pending_affine(), expected[i]["img"].affine) assert_allclose(_pending_result["img"].peek_pending_shape(), expected[i]["img"].shape[1:]) # only support nearest - result = apply_transforms(_pending_result["img"], mode="nearest", align_corners=False)[0] + result = apply_pending(_pending_result["img"], overrides={"mode": "nearest", "align_corners": False})[0] # compare assert_allclose(result, expected[i]["img"], rtol=1e-5) diff --git a/tests/test_rand_crop_by_pos_neg_label.py b/tests/test_rand_crop_by_pos_neg_label.py index e1c4cdff58..98af6b0b5e 100644 --- a/tests/test_rand_crop_by_pos_neg_label.py +++ b/tests/test_rand_crop_by_pos_neg_label.py @@ -19,7 +19,7 @@ from monai.data.meta_tensor import MetaTensor from monai.transforms import RandCropByPosNegLabel -from monai.transforms.lazy.functional import apply_transforms +from monai.transforms.lazy.functional import apply_pending from tests.utils import TEST_NDARRAYS_ALL, assert_allclose TESTS = [ @@ -136,14 +136,14 @@ def test_pending_ops(self, input_param, input_data, _expected_shape): self.assertIsInstance(expected[0], MetaTensor) # lazy cropper.set_random_state(0) - cropper.lazy_evaluation = True + cropper.lazy = True pending_result = cropper(**input_data_mod) for i, _pending_result in enumerate(pending_result): self.assertIsInstance(_pending_result, MetaTensor) assert_allclose(_pending_result.peek_pending_affine(), expected[i].affine) assert_allclose(_pending_result.peek_pending_shape(), expected[i].shape[1:]) # only support nearest - result = apply_transforms(_pending_result, mode="nearest", align_corners=False)[0] + result = 
apply_pending(_pending_result, overrides={"mode": "nearest", "align_corners": False})[0] # compare assert_allclose(result, expected[i], rtol=1e-5) diff --git a/tests/test_rand_crop_by_pos_neg_labeld.py b/tests/test_rand_crop_by_pos_neg_labeld.py index 11b7960617..1b57548d12 100644 --- a/tests/test_rand_crop_by_pos_neg_labeld.py +++ b/tests/test_rand_crop_by_pos_neg_labeld.py @@ -19,7 +19,7 @@ from monai.data.meta_tensor import MetaTensor from monai.transforms import RandCropByPosNegLabeld -from monai.transforms.lazy.functional import apply_transforms +from monai.transforms.lazy.functional import apply_pending from tests.utils import TEST_NDARRAYS_ALL, assert_allclose TESTS = [ @@ -153,15 +153,16 @@ def test_pending_ops(self, input_param, input_data, _expected_shape): self.assertIsInstance(expected[0]["image"], MetaTensor) # lazy cropper.set_random_state(0) - cropper.lazy_evaluation = True + cropper.lazy = True pending_result = cropper(input_data_mod) for i, _pending_result in enumerate(pending_result): self.assertIsInstance(_pending_result["image"], MetaTensor) assert_allclose(_pending_result["image"].peek_pending_affine(), expected[i]["image"].affine) assert_allclose(_pending_result["image"].peek_pending_shape(), expected[i]["image"].shape[1:]) # only support nearest - result_image = apply_transforms(_pending_result["image"], mode="nearest", align_corners=False)[0] - result_extra = apply_transforms(_pending_result["extra"], mode="nearest", align_corners=False)[0] + overrides = {"mode": "nearest", "align_corners": False} + result_image = apply_pending(_pending_result["image"], overrides=overrides)[0] + result_extra = apply_pending(_pending_result["extra"], overrides=overrides)[0] # compare assert_allclose(result_image, expected[i]["image"], rtol=1e-5) assert_allclose(result_extra, expected[i]["extra"], rtol=1e-5) diff --git a/tests/test_rand_flipd.py b/tests/test_rand_flipd.py index d67b4ca31b..be5394c172 100644 --- a/tests/test_rand_flipd.py +++ b/tests/test_rand_flipd.py @@ -37,7 +37,7 @@ def test_correct_results(self, _, spatial_axis): # test lazy test_resampler_lazy(flip, result, init_param, call_param, output_key="img") - flip.lazy_evaluation = False + flip.lazy = False expected = [np.flip(channel, spatial_axis) for channel in self.imt[0]] expected = np.stack(expected) diff --git a/tests/test_rand_rotate.py b/tests/test_rand_rotate.py index 8bd697efe5..ca3eda3b12 100644 --- a/tests/test_rand_rotate.py +++ b/tests/test_rand_rotate.py @@ -91,7 +91,7 @@ def test_correct_results(self, im_type, degrees, keep_size, mode, padding_mode, # test lazy test_resampler_lazy(rotate_fn, rotated, init_param=init_param, call_param=call_param, seed=243) - rotate_fn.lazy_evaluation = False + rotate_fn.lazy = False _order = 0 if mode == "nearest" else 1 if mode == "border": @@ -133,7 +133,7 @@ def test_correct_results(self, im_type, x, y, z, keep_size, mode, padding_mode, # test lazy test_resampler_lazy(rotate_fn, rotated, init_param=init_param, call_param=call_param, seed=243) - rotate_fn.lazy_evaluation = False + rotate_fn.lazy = False assert_allclose(rotated.shape, expected, rtol=1e-7, atol=0) test_local_inversion(rotate_fn, rotated, im) diff --git a/tests/test_rand_rotate90.py b/tests/test_rand_rotate90.py index 2504c0f01b..88f88bf422 100644 --- a/tests/test_rand_rotate90.py +++ b/tests/test_rand_rotate90.py @@ -37,7 +37,7 @@ def test_default(self): # test lazy test_resampler_lazy(rotate, rotated, call_param=call_param, seed=123) - rotate.lazy_evaluation = False + rotate.lazy = False def test_k(self): 
init_param = {"max_k": 2} @@ -60,7 +60,7 @@ def test_k(self): # test lazy test_resampler_lazy(rotate, rotated, call_param=call_param, seed=123) - rotate.lazy_evaluation = False + rotate.lazy = False def test_spatial_axes(self): rotate = RandRotate90(spatial_axes=(0, 1), prob=1.0) @@ -71,7 +71,7 @@ def test_spatial_axes(self): rotated = rotate(**call_param) # test lazy test_resampler_lazy(rotate, rotated, call_param=call_param, seed=1234) - rotate.lazy_evaluation = False + rotate.lazy = False self.assertEqual(len(rotated.applied_operations), 1) expected = [np.rot90(channel, rotate._rand_k, (0, 1)) for channel in self.imt[0]] @@ -88,7 +88,7 @@ def test_prob_k_spatial_axes(self): rotated = rotate(**call_param) # test lazy test_resampler_lazy(rotate, rotated, call_param=call_param, seed=234) - rotate.lazy_evaluation = False + rotate.lazy = False test_local_inversion(rotate, rotated, im) expected = [np.rot90(channel, 1, (0, 1)) for channel in self.imt[0]] diff --git a/tests/test_rand_rotate90d.py b/tests/test_rand_rotate90d.py index f811f1a6a6..23e9025c08 100644 --- a/tests/test_rand_rotate90d.py +++ b/tests/test_rand_rotate90d.py @@ -34,7 +34,7 @@ def test_default(self): # test lazy test_resampler_lazy(rotate, rotated, call_param=call_param, seed=1323, output_key=key) - rotate.lazy_evaluation = False + rotate.lazy = False test_local_inversion(rotate, rotated, im, key) expected = [np.rot90(channel, 0, (0, 1)) for channel in self.imt[0]] @@ -58,7 +58,7 @@ def test_k(self): # test lazy test_resampler_lazy(rotate, rotated, call_param=call_param, seed=234, output_key=key) - rotate.lazy_evaluation = False + rotate.lazy = False test_local_inversion(rotate, rotated, im, key) expected = [np.rot90(channel, 0, (0, 1)) for channel in self.imt[0]] @@ -76,7 +76,7 @@ def test_spatial_axes(self): # test lazy test_resampler_lazy(rotate, rotated, call_param=call_param, seed=234, output_key=key) - rotate.lazy_evaluation = False + rotate.lazy = False test_local_inversion(rotate, rotated, im, key) expected = [np.rot90(channel, 0, (0, 1)) for channel in self.imt[0]] @@ -94,7 +94,7 @@ def test_prob_k_spatial_axes(self): # test lazy test_resampler_lazy(rotate, rotated, call_param=call_param, seed=234, output_key=key) - rotate.lazy_evaluation = False + rotate.lazy = False expected = [np.rot90(channel, 1, (0, 1)) for channel in self.imt[0]] expected = np.stack(expected) diff --git a/tests/test_rand_spatial_crop.py b/tests/test_rand_spatial_crop.py index a0d56bcaf3..df121e2220 100644 --- a/tests/test_rand_spatial_crop.py +++ b/tests/test_rand_spatial_crop.py @@ -18,7 +18,7 @@ from monai.data.meta_tensor import MetaTensor from monai.transforms import RandScaleCrop, RandSpatialCrop -from monai.transforms.lazy.functional import apply_transforms +from monai.transforms.lazy.functional import apply_pending from tests.croppers import CropTest from tests.utils import TEST_NDARRAYS_ALL, assert_allclose @@ -84,13 +84,13 @@ def test_random_shape(self, input_param, input_shape, expected_shape): # lazy # reset random seed to ensure the same results cropper.set_random_state(seed=123) - cropper.lazy_evaluation = True + cropper.lazy = True pending_result = cropper(input_data) self.assertIsInstance(pending_result, MetaTensor) assert_allclose(pending_result.peek_pending_affine(), expected.affine) assert_allclose(pending_result.peek_pending_shape(), expected.shape[1:]) # only support nearest - result = apply_transforms(pending_result, mode="nearest", align_corners=False)[0] + result = apply_pending(pending_result, overrides={"mode": 
"nearest", "align_corners": False})[0] # compare assert_allclose(result, expected, rtol=1e-5) diff --git a/tests/test_rand_spatial_crop_samples.py b/tests/test_rand_spatial_crop_samples.py index 69d2e5af5d..92f0f9d9be 100644 --- a/tests/test_rand_spatial_crop_samples.py +++ b/tests/test_rand_spatial_crop_samples.py @@ -18,7 +18,7 @@ from monai.data.meta_tensor import MetaTensor from monai.transforms import RandSpatialCropSamples -from monai.transforms.lazy.functional import apply_transforms +from monai.transforms.lazy.functional import apply_pending from tests.croppers import CropTest from tests.utils import TEST_NDARRAYS_ALL, assert_allclose @@ -112,14 +112,14 @@ def test_pending_ops(self, input_param, input_shape, _expected_shape, _expected_ self.assertIsInstance(expected[0], MetaTensor) # lazy xform.set_random_state(1234) - xform.lazy_evaluation = True + xform.lazy = True pending_result = xform(image) for i, _pending_result in enumerate(pending_result): self.assertIsInstance(_pending_result, MetaTensor) assert_allclose(_pending_result.peek_pending_affine(), expected[i].affine) assert_allclose(_pending_result.peek_pending_shape(), expected[i].shape[1:]) # only support nearest - result = apply_transforms(_pending_result, mode="nearest", align_corners=False)[0] + result = apply_pending(_pending_result, overrides={"mode": "nearest", "align_corners": False})[0] # compare assert_allclose(result, expected[i], rtol=1e-5) diff --git a/tests/test_rand_spatial_crop_samplesd.py b/tests/test_rand_spatial_crop_samplesd.py index fc6e6c8c43..ec0d63cc50 100644 --- a/tests/test_rand_spatial_crop_samplesd.py +++ b/tests/test_rand_spatial_crop_samplesd.py @@ -18,7 +18,7 @@ from monai.data.meta_tensor import MetaTensor from monai.transforms import Compose, DivisiblePadd, RandSpatialCropSamplesd -from monai.transforms.lazy.functional import apply_transforms +from monai.transforms.lazy.functional import apply_pending from tests.utils import TEST_NDARRAYS_ALL, assert_allclose TEST_CASE_1 = [ @@ -122,15 +122,16 @@ def test_pending_ops(self, input_param, input_data, _expected_shape, _expected_l # lazy xform.set_random_state(1234) - xform.lazy_evaluation = True + xform.lazy = True pending_result = xform(input_data) for i, _pending_result in enumerate(pending_result): self.assertIsInstance(_pending_result["img"], MetaTensor) assert_allclose(_pending_result["img"].peek_pending_affine(), expected[i]["img"].affine) assert_allclose(_pending_result["img"].peek_pending_shape(), expected[i]["img"].shape[1:]) # only support nearest - result_img = apply_transforms(_pending_result["img"], mode="nearest", align_corners=False)[0] - result_seg = apply_transforms(_pending_result["seg"], mode="nearest", align_corners=False)[0] + overrides = {"mode": "nearest", "align_corners": False} + result_img = apply_pending(_pending_result["img"], overrides=overrides)[0] + result_seg = apply_pending(_pending_result["seg"], overrides=overrides)[0] # compare assert_allclose(result_img, expected[i]["img"], rtol=1e-5) assert_allclose(result_seg, expected[i]["seg"], rtol=1e-5) diff --git a/tests/test_rand_spatial_cropd.py b/tests/test_rand_spatial_cropd.py index 5114a45159..123459235f 100644 --- a/tests/test_rand_spatial_cropd.py +++ b/tests/test_rand_spatial_cropd.py @@ -18,7 +18,7 @@ from monai.data.meta_tensor import MetaTensor from monai.transforms import RandScaleCropd, RandSpatialCropd -from monai.transforms.lazy.functional import apply_transforms +from monai.transforms.lazy.functional import apply_pending from tests.croppers import 
CropTest from tests.utils import TEST_NDARRAYS_ALL, assert_allclose @@ -89,13 +89,13 @@ def test_random_shape(self, input_param, input_shape, expected_shape): # lazy # reset random seed to ensure the same results cropper.set_random_state(seed=123) - cropper.lazy_evaluation = True + cropper.lazy = True pending_result = cropper(input_data)["img"] self.assertIsInstance(pending_result, MetaTensor) assert_allclose(pending_result.peek_pending_affine(), expected.affine) assert_allclose(pending_result.peek_pending_shape(), expected.shape[1:]) # only support nearest - result = apply_transforms(pending_result, mode="nearest", align_corners=False)[0] + result = apply_pending(pending_result, overrides={"mode": "nearest", "align_corners": False})[0] # compare assert_allclose(result, expected, rtol=1e-5) diff --git a/tests/test_rand_weighted_crop.py b/tests/test_rand_weighted_crop.py index e279f29f68..47a8f3bfa2 100644 --- a/tests/test_rand_weighted_crop.py +++ b/tests/test_rand_weighted_crop.py @@ -18,7 +18,7 @@ from monai.data.meta_tensor import MetaTensor from monai.transforms.croppad.array import RandWeightedCrop -from monai.transforms.lazy.functional import apply_transforms +from monai.transforms.lazy.functional import apply_pending from tests.croppers import CropTest from tests.utils import TEST_NDARRAYS_ALL, NumpyImageTestCase2D, NumpyImageTestCase3D, assert_allclose @@ -178,14 +178,14 @@ def test_pending_ops(self, _, input_param, img, weight, expected_shape, expected self.assertIsInstance(expected[0], MetaTensor) # lazy crop.set_random_state(10) - crop.lazy_evaluation = True + crop.lazy = True pending_result = crop(img, weight) for i, _pending_result in enumerate(pending_result): self.assertIsInstance(_pending_result, MetaTensor) assert_allclose(_pending_result.peek_pending_affine(), expected[i].affine) assert_allclose(_pending_result.peek_pending_shape(), expected[i].shape[1:]) # only support nearest - result = apply_transforms(_pending_result, mode="nearest", align_corners=False)[0] + result = apply_pending(_pending_result, overrides={"mode": "nearest", "align_corners": False})[0] # compare assert_allclose(result, expected[i], rtol=1e-5) diff --git a/tests/test_rand_weighted_cropd.py b/tests/test_rand_weighted_cropd.py index 51e1b15c2c..9d37779613 100644 --- a/tests/test_rand_weighted_cropd.py +++ b/tests/test_rand_weighted_cropd.py @@ -18,7 +18,7 @@ from monai.data.meta_tensor import MetaTensor from monai.transforms.croppad.dictionary import RandWeightedCropd -from monai.transforms.lazy.functional import apply_transforms +from monai.transforms.lazy.functional import apply_pending from tests.utils import TEST_NDARRAYS_ALL, NumpyImageTestCase2D, NumpyImageTestCase3D, assert_allclose @@ -166,14 +166,14 @@ def test_pending_ops(self, _, input_param, input_data, expected_shape, expected_ self.assertIsInstance(expected[0]["img"], MetaTensor) # lazy crop.set_random_state(10) - crop.lazy_evaluation = True + crop.lazy = True pending_result = crop(input_data) for i, _pending_result in enumerate(pending_result): self.assertIsInstance(_pending_result["img"], MetaTensor) assert_allclose(_pending_result["img"].peek_pending_affine(), expected[i]["img"].affine) assert_allclose(_pending_result["img"].peek_pending_shape(), expected[i]["img"].shape[1:]) # only support nearest - result = apply_transforms(_pending_result["img"], mode="nearest", align_corners=False)[0] + result = apply_pending(_pending_result["img"], overrides={"mode": "nearest", "align_corners": False})[0] # compare assert_allclose(result, 
expected[i]["img"], rtol=1e-5) diff --git a/tests/test_rand_zoomd.py b/tests/test_rand_zoomd.py index f080056b63..bb0495c793 100644 --- a/tests/test_rand_zoomd.py +++ b/tests/test_rand_zoomd.py @@ -58,7 +58,7 @@ def test_correct_results(self, min_zoom, max_zoom, mode, align_corners, keep_siz test_resampler_lazy( random_zoom, zoomed, init_param, call_param, key, seed=1234, atol=1e-4 if USE_COMPILED else 1e-6 ) - random_zoom.lazy_evaluation = False + random_zoom.lazy = False test_local_inversion(random_zoom, zoomed, {key: im}, key) expected = [ diff --git a/tests/test_random_order.py b/tests/test_random_order.py index 9ed22d30ae..e5507fafca 100644 --- a/tests/test_random_order.py +++ b/tests/test_random_order.py @@ -12,9 +12,15 @@ from __future__ import annotations import unittest +from copy import deepcopy +import numpy as np +import torch from parameterized import parameterized +import monai.transforms.intensity.array as ia +import monai.transforms.spatial.array as sa +import monai.transforms.spatial.dictionary as sd from monai.data import MetaTensor from monai.transforms import RandomOrder from monai.transforms.compose import Compose @@ -98,5 +104,41 @@ def test_inverse(self, transform, invertible, use_metatensor): self.assertDictEqual(fwd_data[i], _fwd_inv_data) +TEST_RANDOM_ORDER_EXTENDED_TEST_CASES = [ + [None, tuple()], + [None, (sa.Rotate(np.pi / 8),)], + [None, (sa.Flip(0), sa.Flip(1), sa.Rotate90(1), sa.Zoom(0.8), ia.NormalizeIntensity())], + [("a",), (sd.Rotated(("a",), np.pi / 8),)], +] + + +class TestRandomOrderAPITests(unittest.TestCase): + @staticmethod + def data_from_keys(keys): + if keys is None: + data = torch.unsqueeze(torch.tensor(np.arange(12 * 16).reshape(12, 16)), dim=0) + else: + data = {} + for i_k, k in enumerate(keys): + data[k] = torch.unsqueeze(torch.tensor(np.arange(12 * 16)).reshape(12, 16) + i_k * 192, dim=0) + return data + + @parameterized.expand(TEST_RANDOM_ORDER_EXTENDED_TEST_CASES) + def test_execute_change_start_end(self, keys, pipeline): + data = self.data_from_keys(keys) + + c = RandomOrder(deepcopy(pipeline)) + with self.assertRaises(ValueError): + c(data, start=1) + with self.assertRaises(ValueError): + c(data, start=1) + + c = RandomOrder(deepcopy(pipeline)) + with self.assertRaises(ValueError): + c(data, end=1) + with self.assertRaises(ValueError): + c(data, end=1) + + if __name__ == "__main__": unittest.main() diff --git a/tests/test_resize_with_pad_or_crop.py b/tests/test_resize_with_pad_or_crop.py index 8c33643d1f..287df039b8 100644 --- a/tests/test_resize_with_pad_or_crop.py +++ b/tests/test_resize_with_pad_or_crop.py @@ -19,7 +19,7 @@ from monai.data.meta_tensor import MetaTensor from monai.transforms import ResizeWithPadOrCrop -from monai.transforms.lazy.functional import apply_transforms +from monai.transforms.lazy.functional import apply_pending from tests.utils import TEST_NDARRAYS_ALL, assert_allclose, pytorch_after TEST_CASES = [ @@ -79,18 +79,18 @@ def test_pending_ops(self, input_param, input_shape, _expected_data, align_corne expected = padcropper(image) self.assertIsInstance(expected, MetaTensor) # lazy - padcropper.lazy_evaluation = True + padcropper.lazy = True pending_result = padcropper(image) self.assertIsInstance(pending_result, MetaTensor) assert_allclose(pending_result.peek_pending_affine(), expected.affine) assert_allclose(pending_result.peek_pending_shape(), expected.shape[1:]) # only support nearest - result = apply_transforms( - pending_result, - mode="nearest", - padding_mode=TESTS_PENDING_MODE[input_param["mode"]], - 
align_corners=align_corners, - )[0] + overrides = { + "mode": "nearest", + "padding_mode": TESTS_PENDING_MODE[input_param["mode"]], + "align_corners": align_corners, + } + result = apply_pending(pending_result, overrides=overrides)[0] # compare assert_allclose(result, expected, rtol=1e-5) inverted = padcropper.inverse(result) diff --git a/tests/test_resize_with_pad_or_cropd.py b/tests/test_resize_with_pad_or_cropd.py index a71652375b..471144a609 100644 --- a/tests/test_resize_with_pad_or_cropd.py +++ b/tests/test_resize_with_pad_or_cropd.py @@ -20,7 +20,7 @@ from monai.data.meta_tensor import MetaTensor from monai.transforms import ResizeWithPadOrCropd -from monai.transforms.lazy.functional import apply_transforms +from monai.transforms.lazy.functional import apply_pending from tests.test_resize_with_pad_or_crop import TESTS_PENDING_MODE from tests.utils import TEST_NDARRAYS_ALL, assert_allclose, pytorch_after @@ -74,15 +74,18 @@ def test_pending_ops(self, input_param, input_data, _expected_data): expected = padcropper(input_data)["img"] self.assertIsInstance(expected, MetaTensor) # lazy - padcropper.lazy_evaluation = True + padcropper.lazy = True pending_result = padcropper(input_data)["img"] self.assertIsInstance(pending_result, MetaTensor) assert_allclose(pending_result.peek_pending_affine(), expected.affine) assert_allclose(pending_result.peek_pending_shape(), expected.shape[1:]) # only support nearest - result = apply_transforms( - pending_result, mode="nearest", padding_mode=TESTS_PENDING_MODE[input_param["mode"]], align_corners=True - )[0] + overrides = { + "mode": "nearest", + "padding_mode": TESTS_PENDING_MODE[input_param["mode"]], + "align_corners": True, + } + result = apply_pending(pending_result, overrides=overrides)[0] # compare assert_allclose(result, expected, rtol=1e-5) diff --git a/tests/test_rotate90.py b/tests/test_rotate90.py index fd54e7639f..0948469df9 100644 --- a/tests/test_rotate90.py +++ b/tests/test_rotate90.py @@ -18,7 +18,7 @@ from monai.data import MetaTensor, set_track_meta from monai.transforms import Affine, Rotate90 -from monai.transforms.lazy.functional import apply_transforms +from monai.transforms.lazy.functional import apply_pending from monai.utils import optional_import from tests.lazy_transforms_utils import test_resampler_lazy from tests.utils import ( @@ -41,7 +41,7 @@ def test_rotate90_default(self): # test lazy test_resampler_lazy(rotate, rotated, call_param=call_param) - rotate.lazy_evaluation = False + rotate.lazy = False test_local_inversion(rotate, rotated, im) expected = [np.rot90(channel, 1, (0, 1)) for channel in self.imt[0]] @@ -61,7 +61,7 @@ def test_k(self): # test lazy test_resampler_lazy(rotate, rotated, call_param=call_param) - rotate.lazy_evaluation = False + rotate.lazy = False test_local_inversion(rotate, rotated, im) expected = [np.rot90(channel, 2, (0, 1)) for channel in self.imt[0]] @@ -77,7 +77,7 @@ def test_spatial_axes(self): # test lazy test_resampler_lazy(rotate, rotated, call_param=call_param) - rotate.lazy_evaluation = False + rotate.lazy = False test_local_inversion(rotate, rotated, im) expected = [np.rot90(channel, 1, (0, -1)) for channel in self.imt[0]] @@ -93,7 +93,7 @@ def test_prob_k_spatial_axes(self): # test lazy test_resampler_lazy(rotate, rotated, call_param=call_param) - rotate.lazy_evaluation = False + rotate.lazy = False test_local_inversion(rotate, rotated, im) expected = [np.rot90(channel, 2, (0, 1)) for channel in self.imt[0]] @@ -111,7 +111,7 @@ def test_rotate90_default(self): # test lazy 
diff --git a/tests/test_rotate90.py b/tests/test_rotate90.py
index fd54e7639f..0948469df9 100644
--- a/tests/test_rotate90.py
+++ b/tests/test_rotate90.py
@@ -18,7 +18,7 @@
 
 from monai.data import MetaTensor, set_track_meta
 from monai.transforms import Affine, Rotate90
-from monai.transforms.lazy.functional import apply_transforms
+from monai.transforms.lazy.functional import apply_pending
 from monai.utils import optional_import
 from tests.lazy_transforms_utils import test_resampler_lazy
 from tests.utils import (
@@ -41,7 +41,7 @@ def test_rotate90_default(self):
 
         # test lazy
         test_resampler_lazy(rotate, rotated, call_param=call_param)
-        rotate.lazy_evaluation = False
+        rotate.lazy = False
         test_local_inversion(rotate, rotated, im)
 
         expected = [np.rot90(channel, 1, (0, 1)) for channel in self.imt[0]]
@@ -61,7 +61,7 @@ def test_k(self):
 
         # test lazy
         test_resampler_lazy(rotate, rotated, call_param=call_param)
-        rotate.lazy_evaluation = False
+        rotate.lazy = False
         test_local_inversion(rotate, rotated, im)
 
         expected = [np.rot90(channel, 2, (0, 1)) for channel in self.imt[0]]
@@ -77,7 +77,7 @@ def test_spatial_axes(self):
 
         # test lazy
         test_resampler_lazy(rotate, rotated, call_param=call_param)
-        rotate.lazy_evaluation = False
+        rotate.lazy = False
         test_local_inversion(rotate, rotated, im)
 
         expected = [np.rot90(channel, 1, (0, -1)) for channel in self.imt[0]]
@@ -93,7 +93,7 @@ def test_prob_k_spatial_axes(self):
 
         # test lazy
         test_resampler_lazy(rotate, rotated, call_param=call_param)
-        rotate.lazy_evaluation = False
+        rotate.lazy = False
         test_local_inversion(rotate, rotated, im)
 
         expected = [np.rot90(channel, 2, (0, 1)) for channel in self.imt[0]]
@@ -111,7 +111,7 @@ def test_rotate90_default(self):
 
         # test lazy
         test_resampler_lazy(rotate, rotated, call_param=call_param)
-        rotate.lazy_evaluation = False
+        rotate.lazy = False
         test_local_inversion(rotate, rotated, im)
 
         expected = [np.rot90(channel, 1, (0, 1)) for channel in self.imt[0]]
@@ -127,7 +127,7 @@ def test_k(self):
 
         # test lazy
         test_resampler_lazy(rotate, rotated, call_param=call_param)
-        rotate.lazy_evaluation = False
+        rotate.lazy = False
         test_local_inversion(rotate, rotated, im)
 
         expected = [np.rot90(channel, 2, (0, 1)) for channel in self.imt[0]]
@@ -143,7 +143,7 @@ def test_spatial_axes(self):
 
         # test lazy
         test_resampler_lazy(rotate, rotated, call_param=call_param)
-        rotate.lazy_evaluation = False
+        rotate.lazy = False
         test_local_inversion(rotate, rotated, im)
 
         expected = [np.rot90(channel, 1, (0, -1)) for channel in self.imt[0]]
@@ -159,7 +159,7 @@ def test_prob_k_spatial_axes(self):
 
         # test lazy
         test_resampler_lazy(rotate, rotated, call_param=call_param)
-        rotate.lazy_evaluation = False
+        rotate.lazy = False
         test_local_inversion(rotate, rotated, im)
 
         expected = [np.rot90(channel, 2, (0, 1)) for channel in self.imt[0]]
@@ -177,16 +177,16 @@ def test_affine_rot90(self, s):
 
         def method_0(im, ac):
             xform = Affine(align_corners=ac, affine=mat, image_only=True, spatial_size=s)
-            xform.lazy_evaluation = True
+            xform.lazy = True
             out = xform(im)
-            out = apply_transforms(out, padding_mode="border", align_corners=ac)[0]
+            out = apply_pending(out, overrides={"padding_mode": "border", "align_corners": ac})[0]
             return out
 
         def method_1(im, ac):
             xform = Affine(align_corners=ac, affine=mat, image_only=True, spatial_size=s)
-            xform.lazy_evaluation = True
+            xform.lazy = True
             out = xform(im)
-            out = apply_transforms(out, mode=1, padding_mode="nearest", align_corners=ac)[0]
+            out = apply_pending(out, overrides={"mode": 1, "padding_mode": "nearest", "align_corners": ac})[0]
             return out
 
         def method_2(im, ac):
diff --git a/tests/test_rotate90d.py b/tests/test_rotate90d.py
index 95d475d480..08d3a97498 100644
--- a/tests/test_rotate90d.py
+++ b/tests/test_rotate90d.py
@@ -33,7 +33,7 @@ def test_rotate90_default(self):
 
         # test lazy
         test_resampler_lazy(rotate, rotated, call_param=call_param, output_key=key)
-        rotate.lazy_evaluation = False
+        rotate.lazy = False
         test_local_inversion(rotate, rotated, {key: im}, key)
 
         expected = [np.rot90(channel, 1, (0, 1)) for channel in self.imt[0]]
@@ -54,7 +54,7 @@ def test_k(self):
 
         # test lazy
         test_resampler_lazy(rotate, rotated, call_param=call_param, output_key=key)
-        rotate.lazy_evaluation = False
+        rotate.lazy = False
         test_local_inversion(rotate, rotated, {key: im}, key)
 
         expected = [np.rot90(channel, 2, (0, 1)) for channel in self.imt[0]]
@@ -71,7 +71,7 @@ def test_spatial_axes(self):
 
         # test lazy
         test_resampler_lazy(rotate, rotated, call_param=call_param, output_key=key)
-        rotate.lazy_evaluation = False
+        rotate.lazy = False
         test_local_inversion(rotate, rotated, {key: im}, key)
 
         expected = [np.rot90(channel, 1, (0, 1)) for channel in self.imt[0]]
@@ -88,7 +88,7 @@ def test_prob_k_spatial_axes(self):
 
         # test lazy
         test_resampler_lazy(rotate, rotated, call_param=call_param, output_key=key)
-        rotate.lazy_evaluation = False
+        rotate.lazy = False
         test_local_inversion(rotate, rotated, {key: im}, key)
 
         expected = [np.rot90(channel, 2, (0, 1)) for channel in self.imt[0]]
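The ``method_0``/``method_1`` rewrites above show the calling-convention change most clearly: ``apply_transforms(out, mode=1, ...)`` becomes ``apply_pending(out, overrides={"mode": 1, ...})``. A hedged sketch of the same idea with a lazily applied ``Affine``, under the API shown in this diff (the rotation matrix and input tensor are illustrative)::

    import numpy as np
    import torch

    from monai.transforms import Affine
    from monai.transforms.lazy.functional import apply_pending

    # homogeneous matrix for a 90-degree in-plane rotation
    mat = np.array([[0.0, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])

    xform = Affine(affine=mat, image_only=True)
    xform.lazy = True
    out = xform(torch.arange(16.0).reshape(1, 4, 4))  # affine recorded as pending
    out = apply_pending(out, overrides={"mode": 1, "padding_mode": "border"})[0]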
diff --git a/tests/test_some_of.py b/tests/test_some_of.py
index 0cc903bb2d..8880c376b9 100644
--- a/tests/test_some_of.py
+++ b/tests/test_some_of.py
@@ -12,9 +12,15 @@
 from __future__ import annotations
 
 import unittest
+from copy import deepcopy
 
+import numpy as np
+import torch
 from parameterized import parameterized
 
+import monai.transforms.intensity.array as ia
+import monai.transforms.spatial.array as sa
+import monai.transforms.spatial.dictionary as sd
 from monai.data import MetaTensor
 from monai.transforms import TraceableTransform, Transform
 from monai.transforms.compose import Compose, SomeOf
@@ -206,5 +212,41 @@ def test_bad_num_transforms(self):
         self.assertRaises(ValueError, SomeOf, (A(), B(), C()), num_transforms=("a", 1))
 
 
+TEST_SOMEOF_EXTENDED_TEST_CASES = [
+    [None, tuple()],
+    [None, (sa.Rotate(np.pi / 8),)],
+    [None, (sa.Flip(0), sa.Flip(1), sa.Rotate90(1), sa.Zoom(0.8), ia.NormalizeIntensity())],
+    [("a",), (sd.Rotated(("a",), np.pi / 8),)],
+]
+
+
+class TestSomeOfAPITests(unittest.TestCase):
+    @staticmethod
+    def data_from_keys(keys):
+        if keys is None:
+            data = torch.unsqueeze(torch.tensor(np.arange(12 * 16).reshape(12, 16)), dim=0)
+        else:
+            data = {}
+            for i_k, k in enumerate(keys):
+                data[k] = torch.unsqueeze(torch.tensor(np.arange(12 * 16)).reshape(12, 16) + i_k * 192, dim=0)
+        return data
+
+    @parameterized.expand(TEST_SOMEOF_EXTENDED_TEST_CASES)
+    def test_execute_change_start_end(self, keys, pipeline):
+        data = self.data_from_keys(keys)
+
+        c = SomeOf(deepcopy(pipeline))
+        with self.assertRaises(ValueError):
+            c(data, start=1)
+        with self.assertRaises(ValueError):
+            c(data, start=1)
+
+        c = SomeOf(deepcopy(pipeline))
+        with self.assertRaises(ValueError):
+            c(data, end=1)
+        with self.assertRaises(ValueError):
+            c(data, end=1)
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/tests/test_spatial_combine_transforms.py b/tests/test_spatial_combine_transforms.py
index 74c03fc4ff..8594daed16 100644
--- a/tests/test_spatial_combine_transforms.py
+++ b/tests/test_spatial_combine_transforms.py
@@ -20,7 +20,7 @@
 import monai.transforms as mt
 from monai.data import create_test_image_2d, create_test_image_3d
 from monai.data.meta_tensor import MetaTensor
-from monai.transforms.lazy.functional import apply_transforms
+from monai.transforms.lazy.functional import apply_pending
 from monai.transforms.transform import MapTransform
 from monai.utils import set_determinism
 from tests.lazy_transforms_utils import get_apply_param
@@ -162,7 +162,7 @@ def test_combine_transforms(self, input_shape, funcs):
         # lazy
         pending_result = input_data
         for _func in _funcs:
-            _func.lazy_evaluation = True
+            _func.lazy = True
             if isinstance(_func, mt.Randomizable):
                 _func.set_random_state(seed=seed)
             pending_result = _func(pending_result)
@@ -175,7 +175,7 @@ def test_combine_transforms(self, input_shape, funcs):
         init_param = funcs[-1][1]
         call_param = {}
         apply_param = get_apply_param(init_param, call_param)
-        result = apply_transforms(pending_result, **apply_param)[0]
+        result = apply_pending(pending_result, overrides=apply_param)[0]
 
         match_ratio = np.sum(np.isclose(result.array, expected.array, atol=5e-1)) / np.prod(result.shape)
         self.assertGreater(match_ratio, 0.5)  # at least half of the images are very close
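``test_combine_transforms`` exercises the central claim of lazy resampling: consecutive lazily executed transforms accumulate pending operations on the tensor, and ``apply_pending`` fuses them into a single resample. A minimal sketch under the same API (the transform choice and input shape are illustrative)::

    import torch

    import monai.transforms as mt
    from monai.transforms.lazy.functional import apply_pending

    flip, rot = mt.Flip(0), mt.Rotate90(1)
    flip.lazy = True
    rot.lazy = True

    pending = rot(flip(torch.rand(1, 32, 32)))  # two queued ops, no resample yet
    print(len(pending.pending_operations))      # 2
    result = apply_pending(pending)[0]          # both fused into one resample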
diff --git a/tests/test_zoom.py b/tests/test_zoom.py
index b614acc9e4..e1ea3c25a3 100644
--- a/tests/test_zoom.py
+++ b/tests/test_zoom.py
@@ -20,7 +20,7 @@
 
 from monai.data import MetaTensor, set_track_meta
 from monai.transforms import Zoom
-from monai.transforms.lazy.functional import apply_transforms
+from monai.transforms.lazy.functional import apply_pending
 from tests.utils import (
     DEFAULT_TEST_AFFINE,
     TEST_NDARRAYS_ALL,
@@ -53,12 +53,13 @@ def test_pending_ops(self, zoom, mode, align_corners=False, keep_size=False):
         expected = zoom_fn(im)
         self.assertIsInstance(expected, MetaTensor)
         # lazy
-        zoom_fn.lazy_evaluation = True
+        zoom_fn.lazy = True
         pending_result = zoom_fn(im)
         self.assertIsInstance(pending_result, MetaTensor)
         assert_allclose(pending_result.peek_pending_affine(), expected.affine)
         assert_allclose(pending_result.peek_pending_shape(), expected.shape[1:])
-        result = apply_transforms(pending_result, mode="bilinear", dtype=np.float64, align_corners=align_corners)[0]
+        overrides = {"mode": "bilinear", "dtype": np.float64, "align_corners": align_corners}
+        result = apply_pending(pending_result, overrides=overrides)[0]
         # compare
         match_ratio = np.sum(np.isclose(result, expected)) / np.prod(result.shape)
         self.assertGreater(match_ratio, 0.95)
diff --git a/tests/test_zoomd.py b/tests/test_zoomd.py
index a76e43a6b4..1dcbf98572 100644
--- a/tests/test_zoomd.py
+++ b/tests/test_zoomd.py
@@ -57,7 +57,7 @@ def test_correct_results(self, zoom, mode, keep_size, align_corners=None):
         test_resampler_lazy(
             zoom_fn, zoomed, init_param, call_param, output_key=key, atol=1e-4 if USE_COMPILED else 1e-6
         )
-        zoom_fn.lazy_evaluation = False
+        zoom_fn.lazy = False
         test_local_inversion(zoom_fn, zoomed, {key: im}, key)
 
         _order = 0
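``test_zoomd`` follows the shared pattern for spatial transforms: verify lazy execution via ``test_resampler_lazy``, then flip ``lazy`` back to ``False`` so ``test_local_inversion`` can check the eager inverse. A short sketch of the equivalent zoom round trip, assuming the API in this diff (the key name and zoom factor are illustrative)::

    import numpy as np
    import torch

    from monai.transforms import Zoomd
    from monai.transforms.lazy.functional import apply_pending

    zoom = Zoomd(keys="img", zoom=0.8, keep_size=False)
    zoom.lazy = True
    pending = zoom({"img": torch.rand(1, 32, 32)})["img"]  # zoom queued, not applied
    result = apply_pending(pending, overrides={"mode": "bilinear", "dtype": np.float64})[0]

    zoom.lazy = False
    eager = zoom({"img": torch.rand(1, 32, 32)})["img"]    # immediate resample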