From 00436d0d0caa9e08d3e2a1c596c29bba3e825687 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Fri, 12 Apr 2024 14:37:52 +0800 Subject: [PATCH 01/29] Add `kind` property in `MetaTensor` (#7488) Part of https://github.com/Project-MONAI/MONAI/issues/7486 ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/data/meta_tensor.py | 16 +++++++++++++++- monai/transforms/io/array.py | 2 ++ monai/utils/enums.py | 14 ++++++++++++++ tests/test_meta_tensor.py | 6 ++++-- 4 files changed, 35 insertions(+), 3 deletions(-) diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py index cad0851a8e..b845c6cd8f 100644 --- a/monai/data/meta_tensor.py +++ b/monai/data/meta_tensor.py @@ -24,7 +24,7 @@ from monai.data.meta_obj import MetaObj, get_track_meta from monai.data.utils import affine_to_spacing, decollate_batch, list_data_collate, remove_extra_metadata from monai.utils import look_up_option -from monai.utils.enums import LazyAttr, MetaKeys, PostFix, SpaceKeys +from monai.utils.enums import KindKeys, LazyAttr, MetaKeys, PostFix, SpaceKeys from monai.utils.type_conversion import convert_data_type, convert_to_dst_type, convert_to_numpy, convert_to_tensor __all__ = ["MetaTensor"] @@ -345,6 +345,10 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): def get_default_affine(dtype=torch.float64) -> torch.Tensor: return torch.eye(4, device=torch.device("cpu"), dtype=dtype) + @staticmethod + def get_default_kind() -> str: + return KindKeys.PIXEL + def as_tensor(self) -> torch.Tensor: """ Return the `MetaTensor` as a `torch.Tensor`. @@ -469,6 +473,16 @@ def affine(self, d: NdarrayTensor) -> None: """Set the affine.""" self.meta[MetaKeys.AFFINE] = torch.as_tensor(d, device=torch.device("cpu"), dtype=torch.float64) + @property + def kind(self) -> str: + """Get the data kind. Defaults to ``KindKeys.PIXEL``""" + return self.meta.get(MetaKeys.KIND, self.get_default_kind()) # type: ignore + + @kind.setter + def kind(self, d: str) -> None: + """Set the data kind.""" + self.meta[MetaKeys.KIND] = d + @property def pixdim(self): """Get the spacing""" diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index 7222a26fc3..c9a79e60a6 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -46,6 +46,7 @@ from monai.utils import GridSamplePadMode from monai.utils import ImageMetaKey as Key from monai.utils import OptionalImportError, convert_to_dst_type, ensure_tuple, look_up_option, optional_import +from monai.utils.enums import KindKeys, MetaKeys nib, _ = optional_import("nibabel") Image, _ = optional_import("PIL.Image") @@ -280,6 +281,7 @@ def __call__(self, filename: Sequence[PathLike] | PathLike, reader: ImageReader img_array: NdarrayOrTensor img_array, meta_data = reader.get_data(img) + meta_data[MetaKeys.KIND] = KindKeys.PIXEL img_array = convert_to_dst_type(img_array, dst=img_array, dtype=self.dtype)[0] if not isinstance(meta_data, dict): raise ValueError(f"`meta_data` must be a dict, got type {type(meta_data)}.") diff --git a/monai/utils/enums.py b/monai/utils/enums.py index b786e92151..2624e6dd5a 100644 --- a/monai/utils/enums.py +++ b/monai/utils/enums.py @@ -533,6 +533,19 @@ class SpaceKeys(StrEnum): LPS = "LPS" +class KindKeys(StrEnum): + """ + This class provides an effective way to reference data types such as pixel-based data + and point-like data that consists of point coordinates. + + - PIXEL: Represents data that corresponds to pixel-based data. + - POINT: Represents data consisting of the coordinates of points. + """ + + PIXEL = "pixel" + POINT = "point" + + class MetaKeys(StrEnum): """ Typical keys for MetaObj.meta @@ -543,6 +556,7 @@ class MetaKeys(StrEnum): SPATIAL_SHAPE = "spatial_shape" # optional key for the length in each spatial dimension SPACE = "space" # possible values of space type are defined in `SpaceKeys` ORIGINAL_CHANNEL_DIM = "original_channel_dim" # an integer or float("nan") + KIND = "kind" # possible values of data kind type are defined in `KindKeys` class ColorOrder(StrEnum): diff --git a/tests/test_meta_tensor.py b/tests/test_meta_tensor.py index 1e0f188b63..1815d3ea73 100644 --- a/tests/test_meta_tensor.py +++ b/tests/test_meta_tensor.py @@ -32,7 +32,7 @@ from monai.data.meta_tensor import MetaTensor from monai.data.utils import decollate_batch, list_data_collate from monai.transforms import BorderPadd, Compose, DivisiblePadd, FromMetaTensord, ToMetaTensord -from monai.utils.enums import PostFix +from monai.utils.enums import KindKeys, PostFix from monai.utils.module import pytorch_after from tests.utils import TEST_DEVICES, SkipIfBeforePyTorchVersion, assert_allclose, skip_if_no_cuda @@ -50,7 +50,6 @@ def rand_string(min_len=5, max_len=10): class TestMetaTensor(unittest.TestCase): - @staticmethod def get_im(shape=None, dtype=None, device=None): if shape is None: @@ -303,6 +302,9 @@ def test_collate(self, device, dtype): self.assertTupleEqual(tuple(collated.affine.shape), expected_shape) self.assertEqual(len(collated.applied_operations), numel) + # data kind + self.assertEqual(collated.kind, KindKeys.PIXEL) + @parameterized.expand(TESTS) def test_dataset(self, device, dtype): ims = [self.get_im(device=device, dtype=dtype)[0] for _ in range(4)] From b2ab07eab2b4603f56651c59aecd1ac3c2900c8c Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Thu, 11 Jul 2024 13:05:04 +0100 Subject: [PATCH 02/29] Adding apply_to_geometry function Signed-off-by: Ben Murray --- monai/transforms/lazy/functional.py | 51 ++++++++++++++++++++++++++++- 1 file changed, 50 insertions(+), 1 deletion(-) diff --git a/monai/transforms/lazy/functional.py b/monai/transforms/lazy/functional.py index 6b95027832..3f2f6221b6 100644 --- a/monai/transforms/lazy/functional.py +++ b/monai/transforms/lazy/functional.py @@ -17,6 +17,7 @@ from monai.apps.utils import get_logger from monai.config import NdarrayOrTensor +from monai.data.meta_obj import MetaObj from monai.data.meta_tensor import MetaTensor from monai.data.utils import to_affine_nd from monai.transforms.lazy.utils import ( @@ -28,7 +29,7 @@ ) from monai.transforms.traits import LazyTrait from monai.transforms.transform import MapTransform -from monai.utils import LazyAttr, look_up_option +from monai.utils import LazyAttr, MetaKeys, convert_to_tensor, look_up_option __all__ = ["apply_pending_transforms", "apply_pending_transforms_in_order", "apply_pending"] @@ -293,3 +294,51 @@ def apply_pending(data: torch.Tensor | MetaTensor, pending: list | None = None, for p in pending: data.push_applied_operation(p) return data, pending + + +def apply_to_geometry( + data: torch.Tensor, + meta_info: dict | MetaObj = None, + transform: torch.Tensor | None = None, +): + """ + Apply an affine geometric transform or deformation field to geometry. + At present this is limited to the transformation of points. + + The points must be provided as a tensor and must be compatible with a homogeneous + transform. This means that: + - 2D points are of the form (x, y, 1) + - 3D points are of the form (x, y, z, 1) + + The affine transform or deformation field is applied to the the points and a tensor of + the same shape as the input tensor is returned. + + Args: + data: the tensor of points to be transformed. + meta_info: the metadata containing the affine transformation + """ + + if meta_info is None and transform is None: + raise ValueError("either meta_info or transform must be provided") + if meta_info is not None and transform is not None: + raise ValueError("only one of meta_info or transform can be provided") + + if not isinstance(data, (torch.Tensor, MetaTensor)): + raise TypeError(f"data {type(data)} must be a torch.Tensor or MetaTensor") + + data = convert_to_tensor(data, track_meta=get_track_meta()) + + if meta_info is not None: + transform_ = meta_info.meta[MetaKeys.AFFINE] + else: + transform_ = transform + + if transform_.dtype != data.dtype: + transform_ = transform_.to(data.dtype) + + if data.shape[1] != transform_.shape[0]: + raise ValueError(f"second element of data.shape {data.shape} must match transform shape {transform_.shape}") + + result = torch.matmul(data, transform_.T) + + return result \ No newline at end of file From 7c7ec30ad26d436b96e632c6347aa1c92ebdb0b8 Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Thu, 11 Jul 2024 13:07:34 +0100 Subject: [PATCH 03/29] Adding load_geometry function Signed-off-by: Ben Murray --- monai/transforms/io/functional.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) create mode 100644 monai/transforms/io/functional.py diff --git a/monai/transforms/io/functional.py b/monai/transforms/io/functional.py new file mode 100644 index 0000000000..a4cefca4c4 --- /dev/null +++ b/monai/transforms/io/functional.py @@ -0,0 +1,29 @@ +import json + + +def load_geometry(file, image, origin): + """ + Load geometry from a file and optionally map it to another coordinate space. + """ + with open(file, "r") as f: + geometry = json.load(f) + geometry_schema = geometry.get("schema", None) + if geometry_schema is None: + raise ValueError("Geometry import issue: missing 'schema' entry") + elif "geometry" not in geometry_schema: + raise ValueError(f"Geometry import issue: 'schema' entry must contain 'geometry' key, got: {geometry_schema}") + + if "points" not in geometry: + raise ValueError("Geometry import issue: missing 'points' entry") + + points = geometry["points"] + if not isinstance(points, list): + raise ValueError(f"Geometry import issue: 'points' entry must be a list, got: {type(points)}") + + if len(points) > 0: + first_len = None + for p in points: + if first_len is None: + first_len = len(p) + if len(p) != first_len: + raise ValueError(f"Geometry import issue: 'points' entry contains inconsistent point lengths") From 807775903070551a9d2fba1f78f1695c13798824 Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Thu, 11 Jul 2024 13:10:54 +0100 Subject: [PATCH 04/29] Adding missing import for apply_to_geometry Signed-off-by: Ben Murray --- monai/transforms/lazy/functional.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/transforms/lazy/functional.py b/monai/transforms/lazy/functional.py index 3f2f6221b6..5eb65a1aac 100644 --- a/monai/transforms/lazy/functional.py +++ b/monai/transforms/lazy/functional.py @@ -17,7 +17,7 @@ from monai.apps.utils import get_logger from monai.config import NdarrayOrTensor -from monai.data.meta_obj import MetaObj +from monai.data.meta_obj import get_track_meta, MetaObj from monai.data.meta_tensor import MetaTensor from monai.data.utils import to_affine_nd from monai.transforms.lazy.utils import ( From 7edb44eaf02d2958b905c36b4026935fbae7d314 Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Thu, 11 Jul 2024 13:11:23 +0100 Subject: [PATCH 05/29] Adding KindKeys to __all__ in monai.utils.enums.py Signed-off-by: Ben Murray --- monai/utils/enums.py | 1 + 1 file changed, 1 insertion(+) diff --git a/monai/utils/enums.py b/monai/utils/enums.py index 2624e6dd5a..665f14e1cd 100644 --- a/monai/utils/enums.py +++ b/monai/utils/enums.py @@ -50,6 +50,7 @@ "GridPatchSort", "FastMRIKeys", "SpaceKeys", + "KindKeys", "MetaKeys", "ColorOrder", "EngineStatsKeys", From 61cb5c38be6ab66007a6206783cd7cce23440076 Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Thu, 11 Jul 2024 13:14:14 +0100 Subject: [PATCH 06/29] Adding tests for apply_to_geometry Signed-off-by: Ben Murray --- tests/test_apply.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/tests/test_apply.py b/tests/test_apply.py index ca37e945ba..53012b5352 100644 --- a/tests/test_apply.py +++ b/tests/test_apply.py @@ -16,7 +16,7 @@ import numpy as np import torch -from monai.transforms.lazy.functional import apply_pending +from monai.transforms.lazy.functional import apply_pending, apply_to_geometry from monai.transforms.utils import create_rotate from monai.utils import LazyAttr, convert_to_tensor from tests.utils import get_arange_img @@ -72,5 +72,22 @@ def test_apply_single_transform_metatensor_override(self): self._test_apply_metatensor_impl(*case, True) +class TestApplyToGeometry(unittest.TestCase): + + def test_apply_to_geometry_2d(self): + t = torch.as_tensor([[0, 0, 1], [0, 1, 1], [1, 1, 1], [1, 0, 1]], dtype=torch.float32) + rot45 = torch.as_tensor(create_rotate(2, np.pi / 4)) + actual = apply_to_geometry(t, transform=rot45) + print(actual) + + def test_apply_to_geometry_3d(self): + t = torch.as_tensor( + [[0, 0, 0, 1], [0, 0, 1, 1], [0, 1, 1, 1], [0, 1, 0, 1], [1, 1, 0, 1], [1, 1, 1, 1], [0, 0, 1, 1], [0, 0, 0, 1]], + dtype=torch.float32) + rot45 = torch.as_tensor(create_rotate(3, np.pi / 4)) + actual = apply_to_geometry(t, transform=rot45) + print(actual) + + if __name__ == "__main__": unittest.main() From 1e83ec8833c0f0786ea4c9717752d0eebbebc64b Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Thu, 11 Jul 2024 14:00:37 +0100 Subject: [PATCH 07/29] Adding apply_to_geometry to __all__ in lazy/functional.py Signed-off-by: Ben Murray --- monai/transforms/lazy/functional.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/transforms/lazy/functional.py b/monai/transforms/lazy/functional.py index 5eb65a1aac..fbcfcda963 100644 --- a/monai/transforms/lazy/functional.py +++ b/monai/transforms/lazy/functional.py @@ -31,7 +31,7 @@ from monai.transforms.transform import MapTransform from monai.utils import LazyAttr, MetaKeys, convert_to_tensor, look_up_option -__all__ = ["apply_pending_transforms", "apply_pending_transforms_in_order", "apply_pending"] +__all__ = ["apply_pending_transforms", "apply_pending_transforms_in_order", "apply_pending", "apply_to_geometry"] __override_keywords = {"mode", "padding_mode", "dtype", "align_corners", "resample_mode", "device"} @@ -341,4 +341,4 @@ def apply_to_geometry( result = torch.matmul(data, transform_.T) - return result \ No newline at end of file + return result From e503d6acdc12e6db6f7fbbc2db1ec3e1bab656a8 Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Thu, 11 Jul 2024 14:01:39 +0100 Subject: [PATCH 08/29] Adding flip functionality for geometric_tensors Signed-off-by: Ben Murray --- monai/transforms/spatial/functional.py | 67 +++++++++++++++++++++++++- 1 file changed, 65 insertions(+), 2 deletions(-) diff --git a/monai/transforms/spatial/functional.py b/monai/transforms/spatial/functional.py index add4e7f5ea..78bca190a0 100644 --- a/monai/transforms/spatial/functional.py +++ b/monai/transforms/spatial/functional.py @@ -24,16 +24,18 @@ import monai from monai.config import USE_COMPILED from monai.config.type_definitions import NdarrayOrTensor -from monai.data.meta_obj import get_track_meta +from monai.data.meta_obj import get_track_meta, MetaObj from monai.data.meta_tensor import MetaTensor from monai.data.utils import AFFINE_TOL, compute_shape_offset, to_affine_nd from monai.networks.layers import AffineTransform from monai.transforms.croppad.array import ResizeWithPadOrCrop from monai.transforms.intensity.array import GaussianSmooth from monai.transforms.inverse import TraceableTransform +from monai.transforms.lazy.functional import apply_to_geometry from monai.transforms.utils import create_rotate, create_translate, resolves_modes, scale_affine from monai.transforms.utils_pytorch_numpy_unification import allclose from monai.utils import ( + KindKeys, LazyAttr, TraceKeys, convert_to_dst_type, @@ -229,7 +231,42 @@ def orientation(img, original_affine, spatial_ornt, lazy, transform_info) -> tor return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out # type: ignore -def flip(img, sp_axes, lazy, transform_info): +# def flip(img, sp_axes, lazy, transform_info): +# """ +# Functional implementation of flip. +# This function operates eagerly or lazily according to +# ``lazy`` (default ``False``). + +# Args: +# img: data to be changed, assuming `img` is channel-first. +# sp_axes: spatial axes along which to flip over. +# If None, will flip over all of the axes of the input array. +# If axis is negative it counts from the last to the first axis. +# If axis is a tuple of ints, flipping is performed on all of the axes +# specified in the tuple. +# lazy: a flag that indicates whether the operation should be performed lazily or not +# transform_info: a dictionary with the relevant information pertaining to an applied transform. +# """ +# sp_size = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:] +# sp_size = convert_to_numpy(sp_size, wrap_sequence=True).tolist() +# extra_info = {"axes": sp_axes} # track the spatial axes +# axes = monai.transforms.utils.map_spatial_axes(img.ndim, sp_axes) # use the axes with channel dim +# rank = img.peek_pending_rank() if isinstance(img, MetaTensor) else torch.tensor(3.0, dtype=torch.double) +# # axes include the channel dim +# xform = torch.eye(int(rank) + 1, dtype=torch.double) +# for axis in axes: +# sp = axis - 1 +# xform[sp, sp], xform[sp, -1] = xform[sp, sp] * -1, sp_size[sp] - 1 +# meta_info = TraceableTransform.track_transform_meta( +# img, sp_size=sp_size, affine=xform, extra_info=extra_info, transform_info=transform_info, lazy=lazy +# ) +# out = _maybe_new_metatensor(img) +# if lazy: +# return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info +# out = torch.flip(out, axes) +# return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out + +def flip_impl(img, sp_axes, lazy, transform_info): """ Functional implementation of flip. This function operates eagerly or lazily according to @@ -258,6 +295,11 @@ def flip(img, sp_axes, lazy, transform_info): meta_info = TraceableTransform.track_transform_meta( img, sp_size=sp_size, affine=xform, extra_info=extra_info, transform_info=transform_info, lazy=lazy ) + return axes, meta_info + + +def flip_raster(img, sp_axes, lazy, transform_info): + axes, meta_info = flip_impl(img, sp_axes, lazy, transform_info) out = _maybe_new_metatensor(img) if lazy: return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info @@ -265,6 +307,27 @@ def flip(img, sp_axes, lazy, transform_info): return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out +def flip_geom(img, sp_axes, lazy, transform_info): + _, meta_info = flip_impl(img, sp_axes, lazy, transform_info) + out = _maybe_new_metatensor(img) + if lazy: + return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info + out = apply_to_geometry(out, meta_info) + return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out + + +def flip(image, sp_axes, lazy, transform_info): + """ + Flip the tensor / MetaTensor according to `sp_axes`. + """ + + if isinstance(image, MetaTensor): + if image.kind == KindKeys.RASTER: + return flip_raster(image, sp_axes, lazy, transform_info) + elif image.kind == KindKeys.GEOM: + return flip_geom(image, sp_axes, lazy, transform_info) + + def resize( img, out_size, mode, align_corners, dtype, input_ndim, anti_aliasing, anti_aliasing_sigma, lazy, transform_info ): From f345fc939ba04fc05000e5265aa6dfd6b1400750 Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Thu, 11 Jul 2024 14:02:41 +0100 Subject: [PATCH 09/29] Adding KindKeys to utils/__init__ Signed-off-by: Ben Murray --- monai/utils/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index 2c32eb2cf4..1f4717a4e0 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -42,6 +42,7 @@ HoVerNetMode, InterpolateMode, JITMetadataKeys, + KindKeys, LazyAttr, LossReduction, MetaKeys, From edc4bb78a48128695667010b92ac5a80a8ecfaa1 Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Thu, 11 Jul 2024 14:08:41 +0100 Subject: [PATCH 10/29] Adding | None to meta_info for apply_to_geometry Signed-off-by: Ben Murray --- monai/transforms/lazy/functional.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/transforms/lazy/functional.py b/monai/transforms/lazy/functional.py index fbcfcda963..1513d016ab 100644 --- a/monai/transforms/lazy/functional.py +++ b/monai/transforms/lazy/functional.py @@ -298,7 +298,7 @@ def apply_pending(data: torch.Tensor | MetaTensor, pending: list | None = None, def apply_to_geometry( data: torch.Tensor, - meta_info: dict | MetaObj = None, + meta_info: dict | MetaObj | None = None, transform: torch.Tensor | None = None, ): """ From b8baae992844dfc2988bf6992f71767efa3155d4 Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Thu, 11 Jul 2024 14:10:40 +0100 Subject: [PATCH 11/29] Adding resize functionality for geometric tensors Signed-off-by: Ben Murray --- monai/transforms/spatial/functional.py | 145 ++++++++++++++++++++++++- 1 file changed, 143 insertions(+), 2 deletions(-) diff --git a/monai/transforms/spatial/functional.py b/monai/transforms/spatial/functional.py index 78bca190a0..025f8eea67 100644 --- a/monai/transforms/spatial/functional.py +++ b/monai/transforms/spatial/functional.py @@ -328,7 +328,81 @@ def flip(image, sp_axes, lazy, transform_info): return flip_geom(image, sp_axes, lazy, transform_info) -def resize( +# def resize( +# img, out_size, mode, align_corners, dtype, input_ndim, anti_aliasing, anti_aliasing_sigma, lazy, transform_info +# ): +# """ +# Functional implementation of resize. +# This function operates eagerly or lazily according to +# ``lazy`` (default ``False``). + +# Args: +# img: data to be changed, assuming `img` is channel-first. +# out_size: expected shape of spatial dimensions after resize operation. +# mode: {``"nearest"``, ``"nearest-exact"``, ``"linear"``, +# ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``} +# The interpolation mode. +# See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html +# align_corners: This only has an effect when mode is +# 'linear', 'bilinear', 'bicubic' or 'trilinear'. +# dtype: data type for resampling computation. If None, use the data type of input data. +# input_ndim: number of spatial dimensions. +# anti_aliasing: whether to apply a Gaussian filter to smooth the image prior +# to downsampling. It is crucial to filter when downsampling +# the image to avoid aliasing artifacts. See also ``skimage.transform.resize`` +# anti_aliasing_sigma: {float, tuple of floats}, optional +# Standard deviation for Gaussian filtering used when anti-aliasing. +# lazy: a flag that indicates whether the operation should be performed lazily or not +# transform_info: a dictionary with the relevant information pertaining to an applied transform. +# """ +# img = convert_to_tensor(img, track_meta=get_track_meta()) +# orig_size = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:] +# extra_info = { +# "mode": mode, +# "align_corners": align_corners if align_corners is not None else TraceKeys.NONE, +# "dtype": str(dtype)[6:], # dtype as string; remove "torch": torch.float32 -> float32 +# "new_dim": len(orig_size) - input_ndim, +# } +# meta_info = TraceableTransform.track_transform_meta( +# img, +# sp_size=out_size, +# affine=scale_affine(orig_size, out_size), +# extra_info=extra_info, +# orig_size=orig_size, +# transform_info=transform_info, +# lazy=lazy, +# ) +# if lazy: +# if anti_aliasing and lazy: +# warnings.warn("anti-aliasing is not compatible with lazy evaluation.") +# out = _maybe_new_metatensor(img) +# return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info +# if tuple(convert_to_numpy(orig_size)) == out_size: +# out = _maybe_new_metatensor(img, dtype=torch.float32) +# return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out +# out = _maybe_new_metatensor(img) +# img_ = convert_to_tensor(out, dtype=dtype, track_meta=False) # convert to a regular tensor +# if anti_aliasing and any(x < y for x, y in zip(out_size, img_.shape[1:])): +# factors = torch.div(torch.Tensor(list(img_.shape[1:])), torch.Tensor(out_size)) +# if anti_aliasing_sigma is None: +# # if sigma is not given, use the default sigma in skimage.transform.resize +# anti_aliasing_sigma = torch.maximum(torch.zeros(factors.shape), (factors - 1) / 2).tolist() +# else: +# # if sigma is given, use the given value for downsampling axis +# anti_aliasing_sigma = list(ensure_tuple_rep(anti_aliasing_sigma, len(out_size))) +# for axis in range(len(out_size)): +# anti_aliasing_sigma[axis] = anti_aliasing_sigma[axis] * int(factors[axis] > 1) +# anti_aliasing_filter = GaussianSmooth(sigma=anti_aliasing_sigma) +# img_ = convert_to_tensor(anti_aliasing_filter(img_), track_meta=False) +# _, _m, _, _ = resolves_modes(mode, torch_interpolate_spatial_nd=len(img_.shape) - 1) +# resized = torch.nn.functional.interpolate( +# input=img_.unsqueeze(0), size=out_size, mode=_m, align_corners=align_corners +# ) +# out, *_ = convert_to_dst_type(resized.squeeze(0), out, dtype=torch.float32) +# return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out + + +def resize_impl( img, out_size, mode, align_corners, dtype, input_ndim, anti_aliasing, anti_aliasing_sigma, lazy, transform_info ): """ @@ -363,23 +437,36 @@ def resize( "dtype": str(dtype)[6:], # dtype as string; remove "torch": torch.float32 -> float32 "new_dim": len(orig_size) - input_ndim, } + affine = scale_affine(orig_size, out_size) meta_info = TraceableTransform.track_transform_meta( img, sp_size=out_size, - affine=scale_affine(orig_size, out_size), + affine=affine, extra_info=extra_info, orig_size=orig_size, transform_info=transform_info, lazy=lazy, ) + + return affine, orig_size, meta_info + + +def resize_raster( + img, out_size, mode, align_corners, dtype, input_ndim, anti_aliasing, anti_aliasing_sigma, lazy, transform_info +): + _, orig_size, meta_info = resize_impl( + img, out_size, mode, align_corners, dtype, input_ndim, anti_aliasing, anti_aliasing_sigma, lazy, transform_info + ) if lazy: if anti_aliasing and lazy: warnings.warn("anti-aliasing is not compatible with lazy evaluation.") out = _maybe_new_metatensor(img) return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info + if tuple(convert_to_numpy(orig_size)) == out_size: out = _maybe_new_metatensor(img, dtype=torch.float32) return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out + out = _maybe_new_metatensor(img) img_ = convert_to_tensor(out, dtype=dtype, track_meta=False) # convert to a regular tensor if anti_aliasing and any(x < y for x, y in zip(out_size, img_.shape[1:])): @@ -402,6 +489,60 @@ def resize( return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out +def resize_geom( + img, out_size, mode, align_corners, dtype, input_ndim, anti_aliasing, anti_aliasing_sigma, lazy, transform_info +): + _, _, meta_info = resize_impl( + img, out_size, mode, align_corners, dtype, input_ndim, anti_aliasing, anti_aliasing_sigma, lazy, transform_info + ) + out = _maybe_new_metatensor(img) + if lazy: + return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info + + out = apply_to_geometry(out, meta_info) + return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out + + +def resize( + img, out_size, mode, align_corners, dtype, input_ndim, anti_aliasing, anti_aliasing_sigma, lazy, transform_info +): + """ + Functional implementation of resize. + This function operates eagerly or lazily according to + ``lazy`` (default ``False``). + + Args: + img: data to be changed, assuming `img` is channel-first. + out_size: expected shape of spatial dimensions after resize operation. + mode: {``"nearest"``, ``"nearest-exact"``, ``"linear"``, + ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``} + The interpolation mode. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html + align_corners: This only has an effect when mode is + 'linear', 'bilinear', 'bicubic' or 'trilinear'. + dtype: data type for resampling computation. If None, use the data type of input data. + input_ndim: number of spatial dimensions. + anti_aliasing: whether to apply a Gaussian filter to smooth the image prior + to downsampling. It is crucial to filter when downsampling + the image to avoid aliasing artifacts. See also ``skimage.transform.resize`` + anti_aliasing_sigma: {float, tuple of floats}, optional + Standard deviation for Gaussian filtering used when anti-aliasing. + lazy: a flag that indicates whether the operation should be performed lazily or not + transform_info: a dictionary with the relevant information pertaining to an applied transform. + """ + if isinstance(img, MetaTensor): + if img.kind == KindKeys.PIXEL: + return resize_raster( + img, out_size, mode, align_corners, dtype, input_ndim, anti_aliasing, anti_aliasing_sigma, lazy, transform_info + ) + elif img.kind == KindKeys.GEOMETRY: + return resize_geom( + img, out_size, mode, align_corners, dtype, input_ndim, anti_aliasing, anti_aliasing_sigma, lazy, transform_info + ) + else: + raise ValueError(f"Unsupported value for 'kind': {img.kind}") + + def rotate(img, angle, output_shape, mode, padding_mode, align_corners, dtype, lazy, transform_info): """ Functional implementation of rotate. From e59b4a983a2946086f3ad7cc92bb027565d20ae4 Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Thu, 11 Jul 2024 14:13:08 +0100 Subject: [PATCH 12/29] Adding rotate functionality for geometric tensors Signed-off-by: Ben Murray --- monai/transforms/spatial/functional.py | 123 ++++++++++++++++++++++++- 1 file changed, 122 insertions(+), 1 deletion(-) diff --git a/monai/transforms/spatial/functional.py b/monai/transforms/spatial/functional.py index 025f8eea67..1c758bb6ad 100644 --- a/monai/transforms/spatial/functional.py +++ b/monai/transforms/spatial/functional.py @@ -543,7 +543,78 @@ def resize( raise ValueError(f"Unsupported value for 'kind': {img.kind}") -def rotate(img, angle, output_shape, mode, padding_mode, align_corners, dtype, lazy, transform_info): +# def rotate(img, angle, output_shape, mode, padding_mode, align_corners, dtype, lazy, transform_info): +# """ +# Functional implementation of rotate. +# This function operates eagerly or lazily according to +# ``lazy`` (default ``False``). + +# Args: +# img: data to be changed, assuming `img` is channel-first. +# angle: Rotation angle(s) in radians. should a float for 2D, three floats for 3D. +# output_shape: output shape of the rotated data. +# mode: {``"bilinear"``, ``"nearest"``} +# Interpolation mode to calculate output values. +# See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html +# padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} +# Padding mode for outside grid values. +# See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html +# align_corners: See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html +# dtype: data type for resampling computation. +# If None, use the data type of input data. To be compatible with other modules, +# the output data type is always ``float32``. +# lazy: a flag that indicates whether the operation should be performed lazily or not +# transform_info: a dictionary with the relevant information pertaining to an applied transform. + +# """ + +# im_shape = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:] +# input_ndim = len(im_shape) +# if input_ndim not in (2, 3): +# raise ValueError(f"Unsupported image dimension: {input_ndim}, available options are [2, 3].") +# _angle = ensure_tuple_rep(angle, 1 if input_ndim == 2 else 3) +# transform = create_rotate(input_ndim, _angle) +# if output_shape is None: +# corners = np.asarray(np.meshgrid(*[(0, dim) for dim in im_shape], indexing="ij")).reshape((len(im_shape), -1)) +# corners = transform[:-1, :-1] @ corners # type: ignore +# output_shape = np.asarray(corners.ptp(axis=1) + 0.5, dtype=int) +# else: +# output_shape = np.asarray(output_shape, dtype=int) +# shift = create_translate(input_ndim, ((np.array(im_shape) - 1) / 2).tolist()) +# shift_1 = create_translate(input_ndim, (-(np.asarray(output_shape, dtype=int) - 1) / 2).tolist()) +# transform = shift @ transform @ shift_1 +# extra_info = { +# "rot_mat": transform, +# "mode": mode, +# "padding_mode": padding_mode, +# "align_corners": align_corners if align_corners is not None else TraceKeys.NONE, +# "dtype": str(dtype)[6:], # dtype as string; remove "torch": torch.float32 -> float32 +# } +# meta_info = TraceableTransform.track_transform_meta( +# img, +# sp_size=output_shape, +# affine=transform, +# extra_info=extra_info, +# orig_size=im_shape, +# transform_info=transform_info, +# lazy=lazy, +# ) +# out = _maybe_new_metatensor(img) +# if lazy: +# return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info +# _, _m, _p, _ = resolves_modes(mode, padding_mode) +# xform = AffineTransform( +# normalized=False, mode=_m, padding_mode=_p, align_corners=align_corners, reverse_indexing=True +# ) +# img_t = out.to(dtype) +# transform_t, *_ = convert_to_dst_type(transform, img_t) +# output: torch.Tensor = xform(img_t.unsqueeze(0), transform_t, spatial_size=tuple(int(i) for i in output_shape)) +# output = output.float().squeeze(0) +# out, *_ = convert_to_dst_type(output, dst=out, dtype=torch.float32) +# return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out + + +def rotate_impl(img, angle, output_shape, mode, padding_mode, align_corners, dtype, lazy, transform_info): """ Functional implementation of rotate. This function operates eagerly or lazily according to @@ -599,6 +670,14 @@ def rotate(img, angle, output_shape, mode, padding_mode, align_corners, dtype, l transform_info=transform_info, lazy=lazy, ) + return transform, meta_info + + +def rotate_raster(img, angle, output_shape, mode, padding_mode, align_corners, dtype, lazy, transform_info): + """ + Raster-specific rotation functionality + """ + transform, meta_info = rotate_impl(img, angle, output_shape, mode, padding_mode, align_corners, dtype, lazy, transform_info) out = _maybe_new_metatensor(img) if lazy: return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info @@ -614,6 +693,48 @@ def rotate(img, angle, output_shape, mode, padding_mode, align_corners, dtype, l return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out +def rotate_geom(img, angle, output_shape, mode, padding_mode, align_corners, dtype, lazy, transform_info): + """ + Geometry-specific rotation functionality + """ + _, meta_info = rotate_impl(img, angle, output_shape, mode, padding_mode, align_corners, dtype, lazy, transform_info) + out = _maybe_new_metatensor(img) + if lazy: + return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info + out = apply_to_geometry(out, meta_info) + return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out + + +def rotate(img, angle, output_shape, mode, padding_mode, align_corners, dtype, lazy, transform_info): + """ + Functional implementation of rotate. + This function operates eagerly or lazily according to + ``lazy`` (default ``False``). + + Args: + img: data to be changed, assuming `img` is channel-first. + angle: Rotation angle(s) in radians. should a float for 2D, three floats for 3D. + output_shape: output shape of the rotated data. + mode: {``"bilinear"``, ``"nearest"``} + Interpolation mode to calculate output values. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} + Padding mode for outside grid values. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + align_corners: See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + dtype: data type for resampling computation. + If None, use the data type of input data. To be compatible with other modules, + the output data type is always ``float32``. + lazy: a flag that indicates whether the operation should be performed lazily or not + transform_info: a dictionary with the relevant information pertaining to an applied transform. + """ + if isinstance(img, MetaTensor): + if img.kind == KindKeys.PIXEL: + return rotate_raster(img, angle, output_shape, mode, padding_mode, align_corners, dtype, lazy, transform_info) + elif img.kind == KindKeys.GEOMETRY: + return rotate_geom(img, angle, output_shape, mode, padding_mode, align_corners, dtype, lazy, transform_info) + + def zoom(img, scale_factor, keep_size, mode, padding_mode, align_corners, dtype, lazy, transform_info): """ Functional implementation of zoom. From a0b768efbb59f71e3dd118611b736a882dc0914f Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Thu, 11 Jul 2024 14:17:53 +0100 Subject: [PATCH 13/29] Fixing line endings for monai/transforms/io/functional.py Signed-off-by: Ben Murray --- monai/transforms/io/functional.py | 58 +++++++++++++++---------------- 1 file changed, 29 insertions(+), 29 deletions(-) diff --git a/monai/transforms/io/functional.py b/monai/transforms/io/functional.py index a4cefca4c4..83487e6261 100644 --- a/monai/transforms/io/functional.py +++ b/monai/transforms/io/functional.py @@ -1,29 +1,29 @@ -import json - - -def load_geometry(file, image, origin): - """ - Load geometry from a file and optionally map it to another coordinate space. - """ - with open(file, "r") as f: - geometry = json.load(f) - geometry_schema = geometry.get("schema", None) - if geometry_schema is None: - raise ValueError("Geometry import issue: missing 'schema' entry") - elif "geometry" not in geometry_schema: - raise ValueError(f"Geometry import issue: 'schema' entry must contain 'geometry' key, got: {geometry_schema}") - - if "points" not in geometry: - raise ValueError("Geometry import issue: missing 'points' entry") - - points = geometry["points"] - if not isinstance(points, list): - raise ValueError(f"Geometry import issue: 'points' entry must be a list, got: {type(points)}") - - if len(points) > 0: - first_len = None - for p in points: - if first_len is None: - first_len = len(p) - if len(p) != first_len: - raise ValueError(f"Geometry import issue: 'points' entry contains inconsistent point lengths") +import json + + +def load_geometry(file, image, origin): + """ + Load geometry from a file and optionally map it to another coordinate space. + """ + with open(file, "r") as f: + geometry = json.load(f) + geometry_schema = geometry.get("schema", None) + if geometry_schema is None: + raise ValueError("Geometry import issue: missing 'schema' entry") + elif "geometry" not in geometry_schema: + raise ValueError(f"Geometry import issue: 'schema' entry must contain 'geometry' key, got: {geometry_schema}") + + if "points" not in geometry: + raise ValueError("Geometry import issue: missing 'points' entry") + + points = geometry["points"] + if not isinstance(points, list): + raise ValueError(f"Geometry import issue: 'points' entry must be a list, got: {type(points)}") + + if len(points) > 0: + first_len = None + for p in points: + if first_len is None: + first_len = len(p) + if len(p) != first_len: + raise ValueError(f"Geometry import issue: 'points' entry contains inconsistent point lengths") From 4e6ae70ffa76037b5576071023fe826cc73a52a6 Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Fri, 19 Jul 2024 13:42:17 +0100 Subject: [PATCH 14/29] Fixed rotate to that output_shape is returned. Made resize consistent with rotate Signed-off-by: Ben Murray --- monai/transforms/spatial/functional.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/monai/transforms/spatial/functional.py b/monai/transforms/spatial/functional.py index 1c758bb6ad..9bb2cbbb70 100644 --- a/monai/transforms/spatial/functional.py +++ b/monai/transforms/spatial/functional.py @@ -448,13 +448,13 @@ def resize_impl( lazy=lazy, ) - return affine, orig_size, meta_info + return affine, meta_info, orig_size def resize_raster( img, out_size, mode, align_corners, dtype, input_ndim, anti_aliasing, anti_aliasing_sigma, lazy, transform_info ): - _, orig_size, meta_info = resize_impl( + _, meta_info, orig_size = resize_impl( img, out_size, mode, align_corners, dtype, input_ndim, anti_aliasing, anti_aliasing_sigma, lazy, transform_info ) if lazy: @@ -492,7 +492,7 @@ def resize_raster( def resize_geom( img, out_size, mode, align_corners, dtype, input_ndim, anti_aliasing, anti_aliasing_sigma, lazy, transform_info ): - _, _, meta_info = resize_impl( + _1, meta_info, _2 = resize_impl( img, out_size, mode, align_corners, dtype, input_ndim, anti_aliasing, anti_aliasing_sigma, lazy, transform_info ) out = _maybe_new_metatensor(img) @@ -535,7 +535,7 @@ def resize( return resize_raster( img, out_size, mode, align_corners, dtype, input_ndim, anti_aliasing, anti_aliasing_sigma, lazy, transform_info ) - elif img.kind == KindKeys.GEOMETRY: + elif img.kind == KindKeys.POINT: return resize_geom( img, out_size, mode, align_corners, dtype, input_ndim, anti_aliasing, anti_aliasing_sigma, lazy, transform_info ) @@ -670,14 +670,14 @@ def rotate_impl(img, angle, output_shape, mode, padding_mode, align_corners, dty transform_info=transform_info, lazy=lazy, ) - return transform, meta_info + return transform, meta_info, output_shape def rotate_raster(img, angle, output_shape, mode, padding_mode, align_corners, dtype, lazy, transform_info): """ Raster-specific rotation functionality """ - transform, meta_info = rotate_impl(img, angle, output_shape, mode, padding_mode, align_corners, dtype, lazy, transform_info) + transform, meta_info, output_shape = rotate_impl(img, angle, output_shape, mode, padding_mode, align_corners, dtype, lazy, transform_info) out = _maybe_new_metatensor(img) if lazy: return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info @@ -697,7 +697,7 @@ def rotate_geom(img, angle, output_shape, mode, padding_mode, align_corners, dty """ Geometry-specific rotation functionality """ - _, meta_info = rotate_impl(img, angle, output_shape, mode, padding_mode, align_corners, dtype, lazy, transform_info) + _1, meta_info, _2 = rotate_impl(img, angle, output_shape, mode, padding_mode, align_corners, dtype, lazy, transform_info) out = _maybe_new_metatensor(img) if lazy: return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info @@ -731,7 +731,7 @@ def rotate(img, angle, output_shape, mode, padding_mode, align_corners, dtype, l if isinstance(img, MetaTensor): if img.kind == KindKeys.PIXEL: return rotate_raster(img, angle, output_shape, mode, padding_mode, align_corners, dtype, lazy, transform_info) - elif img.kind == KindKeys.GEOMETRY: + elif img.kind == KindKeys.POINT: return rotate_geom(img, angle, output_shape, mode, padding_mode, align_corners, dtype, lazy, transform_info) From f763c8f03f9e38151912efcb22a9ce356b20d459 Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Fri, 19 Jul 2024 13:44:49 +0100 Subject: [PATCH 15/29] Work towards geometry tests for rotate Signed-off-by: Ben Murray --- tests/test_rotate.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/test_rotate.py b/tests/test_rotate.py index 19fbd1409f..edbb5d9956 100644 --- a/tests/test_rotate.py +++ b/tests/test_rotate.py @@ -21,6 +21,7 @@ from monai.config import USE_COMPILED from monai.data import MetaTensor, set_track_meta from monai.transforms import Rotate +from monai.utils.enums import KindKeys from tests.lazy_transforms_utils import test_resampler_lazy from tests.utils import HAS_CUPY, TEST_NDARRAYS_ALL, NumpyImageTestCase2D, NumpyImageTestCase3D, test_local_inversion @@ -89,6 +90,16 @@ def test_correct_results(self, im_type, angle, keep_size, mode, padding_mode, al good = np.sum(np.isclose(expected, rotated, atol=1e-3)) self.assertLessEqual(np.abs(good - expected.size), 5, "diff at most 5 pixels") + def test_pure_geometry(self): + rotate_fn = Rotate(np.pi / 2, True) + geom = torch.tensor([[[0, 0, 1], [0, 1, 1], [1, 0, 1], [1, 1, 1]]], dtype=torch.float32) + geom = MetaTensor(geom, requires_grad=False) + geom.kind = KindKeys.POINT + actual = rotate_fn(geom) + expected = torch.tensor([[[0, 0, 1], [-1, 0, 1], [0, 1, 1], [-1, 1, 1]]], dtype=torch.float32) + # also test inversion + self.assertTrue(torch.allclose(actual.data, expected.data)) + class TestRotate3D(NumpyImageTestCase3D): From fc8e0b8b07c02c39b5f6102fc6359c5dedf22176 Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Fri, 19 Jul 2024 13:54:44 +0100 Subject: [PATCH 16/29] Fixed KindKey types for flip functionality Signed-off-by: Ben Murray --- monai/transforms/spatial/functional.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/transforms/spatial/functional.py b/monai/transforms/spatial/functional.py index 9bb2cbbb70..a052d4f82d 100644 --- a/monai/transforms/spatial/functional.py +++ b/monai/transforms/spatial/functional.py @@ -322,9 +322,9 @@ def flip(image, sp_axes, lazy, transform_info): """ if isinstance(image, MetaTensor): - if image.kind == KindKeys.RASTER: + if image.kind == KindKeys.PIXEL: return flip_raster(image, sp_axes, lazy, transform_info) - elif image.kind == KindKeys.GEOM: + elif image.kind == KindKeys.POINT: return flip_geom(image, sp_axes, lazy, transform_info) From 0857ac7cae85e3fea0ec2fbc65cab4c0a748e3cf Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Fri, 26 Jul 2024 15:21:23 +0100 Subject: [PATCH 17/29] Fix to handle 2d data being multiplied by a 3d transform from the metatensor api Signed-off-by: Ben Murray --- monai/transforms/lazy/functional.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/monai/transforms/lazy/functional.py b/monai/transforms/lazy/functional.py index 1513d016ab..d77070f21b 100644 --- a/monai/transforms/lazy/functional.py +++ b/monai/transforms/lazy/functional.py @@ -335,9 +335,14 @@ def apply_to_geometry( if transform_.dtype != data.dtype: transform_ = transform_.to(data.dtype) - - if data.shape[1] != transform_.shape[0]: - raise ValueError(f"second element of data.shape {data.shape} must match transform shape {transform_.shape}") + if data.shape[-1] == 3 and transform_.shape[0] == 4: + transform_[2, 0:2] = transform_[3, 0:2] + transform_[2, 2] = transform_[3, 3] + transform_[0:2, 2] = transform_[0:2, 3] + transform_ = transform_[:-1, :-1] + + if data.shape[-1] != transform_.shape[0]: + raise ValueError(f"final element of data.shape {data.shape} must match transform shape {transform_.shape}") result = torch.matmul(data, transform_.T) From aa8392691a51f3b6b8da9388d6a3581d3e9ed5fa Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Fri, 26 Jul 2024 15:24:33 +0100 Subject: [PATCH 18/29] Bug fixes to make all related unit tests pass Signed-off-by: Ben Murray --- monai/transforms/spatial/functional.py | 57 ++++++++++++++------------ 1 file changed, 31 insertions(+), 26 deletions(-) diff --git a/monai/transforms/spatial/functional.py b/monai/transforms/spatial/functional.py index a052d4f82d..d795156cbd 100644 --- a/monai/transforms/spatial/functional.py +++ b/monai/transforms/spatial/functional.py @@ -291,7 +291,9 @@ def flip_impl(img, sp_axes, lazy, transform_info): xform = torch.eye(int(rank) + 1, dtype=torch.double) for axis in axes: sp = axis - 1 - xform[sp, sp], xform[sp, -1] = xform[sp, sp] * -1, sp_size[sp] - 1 + xform[sp, sp] = xform[sp, sp] * -1 + if not isinstance(img, MetaTensor) or img.kind == KindKeys.PIXEL: + xform[sp, -1] = sp_size[sp] - 1 meta_info = TraceableTransform.track_transform_meta( img, sp_size=sp_size, affine=xform, extra_info=extra_info, transform_info=transform_info, lazy=lazy ) @@ -321,11 +323,10 @@ def flip(image, sp_axes, lazy, transform_info): Flip the tensor / MetaTensor according to `sp_axes`. """ - if isinstance(image, MetaTensor): - if image.kind == KindKeys.PIXEL: - return flip_raster(image, sp_axes, lazy, transform_info) - elif image.kind == KindKeys.POINT: - return flip_geom(image, sp_axes, lazy, transform_info) + if isinstance(image, MetaTensor) and image.kind == KindKeys.POINT: + return flip_geom(image, sp_axes, lazy, transform_info) + else: + return flip_raster(image, sp_axes, lazy, transform_info) # def resize( @@ -530,17 +531,14 @@ def resize( lazy: a flag that indicates whether the operation should be performed lazily or not transform_info: a dictionary with the relevant information pertaining to an applied transform. """ - if isinstance(img, MetaTensor): - if img.kind == KindKeys.PIXEL: - return resize_raster( - img, out_size, mode, align_corners, dtype, input_ndim, anti_aliasing, anti_aliasing_sigma, lazy, transform_info - ) - elif img.kind == KindKeys.POINT: - return resize_geom( - img, out_size, mode, align_corners, dtype, input_ndim, anti_aliasing, anti_aliasing_sigma, lazy, transform_info - ) - else: - raise ValueError(f"Unsupported value for 'kind': {img.kind}") + if isinstance(img, MetaTensor) and img.kind == KindKeys.POINT: + return resize_geom( + img, out_size, mode, align_corners, dtype, input_ndim, anti_aliasing, anti_aliasing_sigma, lazy, transform_info + ) + else: + return resize_raster( + img, out_size, mode, align_corners, dtype, input_ndim, anti_aliasing, anti_aliasing_sigma, lazy, transform_info + ) # def rotate(img, angle, output_shape, mode, padding_mode, align_corners, dtype, lazy, transform_info): @@ -640,7 +638,11 @@ def rotate_impl(img, angle, output_shape, mode, padding_mode, align_corners, dty """ im_shape = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:] - input_ndim = len(im_shape) + if isinstance(img, MetaTensor) and img.kind == KindKeys.POINT: + input_ndim = img.shape[-1] - 1 + else: + input_ndim = len(img.shape) - 1 + if input_ndim not in (2, 3): raise ValueError(f"Unsupported image dimension: {input_ndim}, available options are [2, 3].") _angle = ensure_tuple_rep(angle, 1 if input_ndim == 2 else 3) @@ -651,9 +653,13 @@ def rotate_impl(img, angle, output_shape, mode, padding_mode, align_corners, dty output_shape = np.asarray(corners.ptp(axis=1) + 0.5, dtype=int) else: output_shape = np.asarray(output_shape, dtype=int) - shift = create_translate(input_ndim, ((np.array(im_shape) - 1) / 2).tolist()) - shift_1 = create_translate(input_ndim, (-(np.asarray(output_shape, dtype=int) - 1) / 2).tolist()) - transform = shift @ transform @ shift_1 + # TODO: this is needed for the raster case but not for geometry + if not isinstance(img, MetaTensor) or img.kind == KindKeys.PIXEL: + shift = create_translate(input_ndim, ((np.array(im_shape) - 1) / 2).tolist()) + shift_1 = create_translate(input_ndim, (-(np.asarray(output_shape, dtype=int) - 1) / 2).tolist()) + transform = shift @ transform @ shift_1 + else: + transform = transform extra_info = { "rot_mat": transform, "mode": mode, @@ -728,11 +734,10 @@ def rotate(img, angle, output_shape, mode, padding_mode, align_corners, dtype, l lazy: a flag that indicates whether the operation should be performed lazily or not transform_info: a dictionary with the relevant information pertaining to an applied transform. """ - if isinstance(img, MetaTensor): - if img.kind == KindKeys.PIXEL: - return rotate_raster(img, angle, output_shape, mode, padding_mode, align_corners, dtype, lazy, transform_info) - elif img.kind == KindKeys.POINT: - return rotate_geom(img, angle, output_shape, mode, padding_mode, align_corners, dtype, lazy, transform_info) + if isinstance(img, MetaTensor) and img.kind == KindKeys.POINT: + return rotate_geom(img, angle, output_shape, mode, padding_mode, align_corners, dtype, lazy, transform_info) + else: + return rotate_raster(img, angle, output_shape, mode, padding_mode, align_corners, dtype, lazy, transform_info) def zoom(img, scale_factor, keep_size, mode, padding_mode, align_corners, dtype, lazy, transform_info): From 55edf777eef4c3b73a2d300a6fdabb48939d010b Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Fri, 26 Jul 2024 15:26:35 +0100 Subject: [PATCH 19/29] Removed ndims (not ndim) from MetaTensor Signed-off-by: Ben Murray --- monai/data/meta_tensor.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py index b845c6cd8f..119581cf06 100644 --- a/monai/data/meta_tensor.py +++ b/monai/data/meta_tensor.py @@ -490,6 +490,13 @@ def pixdim(self): return [affine_to_spacing(a) for a in self.affine] return affine_to_spacing(self.affine) + # @property + # def ndims(self): + # # TODO: this will be wrong when there are batches; review + # if self.kind == KindKeys.POINT: + # return self.shape[2] - 1 + # return len(self.shape) + def peek_pending_shape(self): """ Get the currently expected spatial shape as if all the pending operations are executed. From 7928f3246f4cf258030938a03dff449db22bd0da Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Fri, 26 Jul 2024 15:27:08 +0100 Subject: [PATCH 20/29] Adding tests for rotate / flip Signed-off-by: Ben Murray --- tests/test_flip.py | 80 ++++++++++++++++++++++++++++++++++++++++++++ tests/test_rotate.py | 26 +++++++++++++- 2 files changed, 105 insertions(+), 1 deletion(-) diff --git a/tests/test_flip.py b/tests/test_flip.py index 789ec86920..df26fd751e 100644 --- a/tests/test_flip.py +++ b/tests/test_flip.py @@ -20,6 +20,7 @@ from monai.data.meta_obj import set_track_meta from monai.data.meta_tensor import MetaTensor from monai.transforms import Flip +from monai.utils.enums import KindKeys from tests.lazy_transforms_utils import test_resampler_lazy from tests.utils import TEST_DEVICES, TEST_NDARRAYS_ALL, NumpyImageTestCase2D, assert_allclose, test_local_inversion @@ -73,6 +74,85 @@ def test_torch(self, spatial_axis, img: torch.Tensor, track_meta: bool, device): with self.assertRaisesRegex(ValueError, "MetaTensor"): xform.inverse(res) + def test_pure_geometry_flip_2d_x(self): + flip_fn = Flip(spatial_axis=[0]) + geom = torch.tensor([[[0, 0, 1], [0, 1, 1], [1, 0, 1], [1, 1, 1]]], dtype=torch.float32) + geom = MetaTensor(geom, requires_grad=False) + geom.kind = KindKeys.POINT + actual = flip_fn(geom) + expected = torch.tensor([[[0, 0, 1], [0, 1, 1], [-1, 0, 1], [-1, 1, 1]]], dtype=torch.float32) + # also test inversion + self.assertTrue(torch.allclose(actual.data, expected.data)) + + def test_pure_geometry_flip_2d_y(self): + flip_fn = Flip(spatial_axis=[1]) + geom = torch.tensor([[[0, 0, 1], [0, 1, 1], [1, 0, 1], [1, 1, 1]]], dtype=torch.float32) + geom = MetaTensor(geom, requires_grad=False) + geom.kind = KindKeys.POINT + actual = flip_fn(geom) + expected = torch.tensor([[[0, 0, 1], [0, -1, 1], [1, 0, 1], [1, -1, 1]]], dtype=torch.float32) + # also test inversion + self.assertTrue(torch.allclose(actual.data, expected.data)) + + def test_pure_geometry_flip_2d_all(self): + flip_fn = Flip(spatial_axis=[0, 1]) + geom = torch.tensor([[[0, 0, 1], [0, 1, 1], [1, 0, 1], [1, 1, 1]]], dtype=torch.float32) + geom = MetaTensor(geom, requires_grad=False) + geom.kind = KindKeys.POINT + actual = flip_fn(geom) + expected = torch.tensor([[[0, 0, 1], [0, -1, 1], [-1, 0, 1], [-1, -1, 1]]], dtype=torch.float32) + # also test inversion + self.assertTrue(torch.allclose(actual.data, expected.data)) + + def test_pure_geometry_flip_3d_x(self): + flip_fn = Flip(spatial_axis=[0]) + geom = torch.tensor([[[0, 0, 0, 1], [1, 0, 0, 1], [0, 1, 0, 1], [1, 1, 0, 1], + [0, 0, 1, 1], [1, 0, 1, 1], [0, 1, 1, 1], [1, 1, 1, 1]]], dtype=torch.float32) + geom = MetaTensor(geom, requires_grad=False) + geom.kind = KindKeys.POINT + actual = flip_fn(geom) + expected = torch.tensor([[[0, 0, 0, 1], [-1, 0, 0, 1], [0, 1, 0, 1], [-1, 1, 0, 1], + [0, 0, 1, 1], [-1, 0, 1, 1], [0, 1, 1, 1], [-1, 1, 1, 1]]], dtype=torch.float32) + # also test inversion + self.assertTrue(torch.allclose(actual.data, expected.data)) + + def test_pure_geometry_flip_3d_y(self): + flip_fn = Flip(spatial_axis=[1]) + geom = torch.tensor([[[0, 0, 0, 1], [1, 0, 0, 1], [0, 1, 0, 1], [1, 1, 0, 1], + [0, 0, 1, 1], [1, 0, 1, 1], [0, 1, 1, 1], [1, 1, 1, 1]]], dtype=torch.float32) + geom = MetaTensor(geom, requires_grad=False) + geom.kind = KindKeys.POINT + actual = flip_fn(geom) + expected = torch.tensor([[[0, 0, 0, 1], [1, 0, 0, 1], [0, -1, 0, 1], [1, -1, 0, 1], + [0, 0, 1, 1], [1, 0, 1, 1], [0, -1, 1, 1], [1, -1, 1, 1]]], dtype=torch.float32) + # also test inversion + self.assertTrue(torch.allclose(actual.data, expected.data)) + + def test_pure_geometry_flip_3d_z(self): + flip_fn = Flip(spatial_axis=[2]) + geom = torch.tensor([[[0, 0, 0, 1], [1, 0, 0, 1], [0, 1, 0, 1], [1, 1, 0, 1], + [0, 0, 1, 1], [1, 0, 1, 1], [0, 1, 1, 1], [1, 1, 1, 1]]], dtype=torch.float32) + geom = MetaTensor(geom, requires_grad=False) + geom.kind = KindKeys.POINT + actual = flip_fn(geom) + expected = torch.tensor([[[0, 0, 0, 1], [1, 0, 0, 1], [0, 1, 0, 1], [1, 1, 0, 1], + [0, 0, -1, 1], [1, 0, -1, 1], [0, 1, -1, 1], [1, 1, -1, 1]]], dtype=torch.float32) + # also test inversion + self.assertTrue(torch.allclose(actual.data, expected.data)) + + + def test_pure_geometry_flip_3d_all(self): + flip_fn = Flip(spatial_axis=[0, 1, 2]) + geom = torch.tensor([[[0, 0, 0, 1], [1, 0, 0, 1], [0, 1, 0, 1], [1, 1, 0, 1], + [0, 0, 1, 1], [1, 0, 1, 1], [0, 1, 1, 1], [1, 1, 1, 1]]], dtype=torch.float32) + geom = MetaTensor(geom, requires_grad=False) + geom.kind = KindKeys.POINT + actual = flip_fn(geom) + expected = torch.tensor([[[0, 0, 0, 1], [-1, 0, 0, 1], [0, -1, 0, 1], [-1, -1, 0, 1], + [0, 0, -1, 1], [-1, 0, -1, 1], [0, -1, -1, 1], [-1, -1, -1, 1]]], dtype=torch.float32) + # also test inversion + self.assertTrue(torch.allclose(actual.data, expected.data)) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_rotate.py b/tests/test_rotate.py index edbb5d9956..a627d94d42 100644 --- a/tests/test_rotate.py +++ b/tests/test_rotate.py @@ -90,7 +90,7 @@ def test_correct_results(self, im_type, angle, keep_size, mode, padding_mode, al good = np.sum(np.isclose(expected, rotated, atol=1e-3)) self.assertLessEqual(np.abs(good - expected.size), 5, "diff at most 5 pixels") - def test_pure_geometry(self): + def test_pure_geometry_2d(self): rotate_fn = Rotate(np.pi / 2, True) geom = torch.tensor([[[0, 0, 1], [0, 1, 1], [1, 0, 1], [1, 1, 1]]], dtype=torch.float32) geom = MetaTensor(geom, requires_grad=False) @@ -162,6 +162,30 @@ def test_ill_case(self): with self.assertRaises(ValueError): # wrong mode rotate_fn(p(self.imt[0]), mode="trilinear_spell_error") + def test_pure_geometry_3d_x(self): + rotate_fn = Rotate((np.pi / 2, 0, 0), True) + geom = torch.tensor([[[0, 0, 0, 1], [1, 0, 0, 1], [0, 1, 0, 1], [1, 1, 0, 1], + [0, 0, 1, 1], [1, 0, 1, 1], [0, 1, 1, 1], [1, 1, 1, 1]]], dtype=torch.float32) + geom = MetaTensor(geom, requires_grad=False) + geom.kind = KindKeys.POINT + actual = rotate_fn(geom) + expected = torch.tensor([[[0, 0, 0, 1], [1, 0, 0, 1], [0, 0, 1, 1], [1, 0, 1, 1], + [0, -1, 0, 1], [1, -1, 0, 1], [0, -1, 1, 1], [1, -1, 1, 1]]], dtype=torch.float32) + # also test inversion + self.assertTrue(torch.allclose(actual.data, expected.data)) + + def test_pure_geometry_3d_z(self): + rotate_fn = Rotate((0, 0, np.pi / 2), True) + geom = torch.tensor([[[0, 0, 0, 1], [1, 0, 0, 1], [0, 1, 0, 1], [1, 1, 0, 1], + [0, 0, 1, 1], [1, 0, 1, 1], [0, 1, 1, 1], [1, 1, 1, 1]]], dtype=torch.float32) + geom = MetaTensor(geom, requires_grad=False) + geom.kind = KindKeys.POINT + actual = rotate_fn(geom) + expected = torch.tensor([[[0, 0, 0, 1], [0, 1, 0, 1], [-1, 0, 0, 1], [-1, 1, 0, 1], + [0, 0, 1, 1], [0, 1, 1, 1], [-1, 0, 1, 1], [-1, 1, 1, 1]]], dtype=torch.float32) + # also test inversion + self.assertTrue(torch.allclose(actual.data, expected.data)) + if __name__ == "__main__": unittest.main() From 1a216e9553eba37233456000525b5516844ad547 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 26 Jul 2024 14:30:02 +0000 Subject: [PATCH 21/29] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/transforms/io/functional.py | 4 ++-- monai/transforms/spatial/functional.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/transforms/io/functional.py b/monai/transforms/io/functional.py index 83487e6261..4d7723b3ff 100644 --- a/monai/transforms/io/functional.py +++ b/monai/transforms/io/functional.py @@ -5,7 +5,7 @@ def load_geometry(file, image, origin): """ Load geometry from a file and optionally map it to another coordinate space. """ - with open(file, "r") as f: + with open(file) as f: geometry = json.load(f) geometry_schema = geometry.get("schema", None) if geometry_schema is None: @@ -26,4 +26,4 @@ def load_geometry(file, image, origin): if first_len is None: first_len = len(p) if len(p) != first_len: - raise ValueError(f"Geometry import issue: 'points' entry contains inconsistent point lengths") + raise ValueError("Geometry import issue: 'points' entry contains inconsistent point lengths") diff --git a/monai/transforms/spatial/functional.py b/monai/transforms/spatial/functional.py index d795156cbd..48c5c2dcbf 100644 --- a/monai/transforms/spatial/functional.py +++ b/monai/transforms/spatial/functional.py @@ -24,7 +24,7 @@ import monai from monai.config import USE_COMPILED from monai.config.type_definitions import NdarrayOrTensor -from monai.data.meta_obj import get_track_meta, MetaObj +from monai.data.meta_obj import get_track_meta from monai.data.meta_tensor import MetaTensor from monai.data.utils import AFFINE_TOL, compute_shape_offset, to_affine_nd from monai.networks.layers import AffineTransform From 700459880377956916147d5ba9236ee5e695fbf9 Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Fri, 2 Aug 2024 10:43:03 +0100 Subject: [PATCH 22/29] load_geometry functionality and tests Signed-off-by: Ben Murray --- monai/transforms/io/functional.py | 60 ++++++++++++++--------- tests/test_load_geometry.py | 81 +++++++++++++++++++++++++++++++ 2 files changed, 119 insertions(+), 22 deletions(-) create mode 100644 tests/test_load_geometry.py diff --git a/monai/transforms/io/functional.py b/monai/transforms/io/functional.py index 83487e6261..f2ce405ec9 100644 --- a/monai/transforms/io/functional.py +++ b/monai/transforms/io/functional.py @@ -1,29 +1,45 @@ import json +import numpy as np + +import torch + +from monai.data.meta_tensor import MetaTensor + +from monai.utils.enums import KindKeys + def load_geometry(file, image, origin): """ Load geometry from a file and optionally map it to another coordinate space. """ - with open(file, "r") as f: - geometry = json.load(f) - geometry_schema = geometry.get("schema", None) - if geometry_schema is None: - raise ValueError("Geometry import issue: missing 'schema' entry") - elif "geometry" not in geometry_schema: - raise ValueError(f"Geometry import issue: 'schema' entry must contain 'geometry' key, got: {geometry_schema}") - - if "points" not in geometry: - raise ValueError("Geometry import issue: missing 'points' entry") - - points = geometry["points"] - if not isinstance(points, list): - raise ValueError(f"Geometry import issue: 'points' entry must be a list, got: {type(points)}") - - if len(points) > 0: - first_len = None - for p in points: - if first_len is None: - first_len = len(p) - if len(p) != first_len: - raise ValueError(f"Geometry import issue: 'points' entry contains inconsistent point lengths") + + geometry = json.load(file) + geometry_schema = geometry.get("schema", None) + if geometry_schema is None: + raise ValueError("Geometry import issue: missing 'schema' entry") + elif "geometry" not in geometry_schema: + raise ValueError(f"Geometry import issue: 'schema' entry must contain 'geometry' key, got: {geometry_schema}") + + if "points" not in geometry: + raise ValueError("Geometry import issue: missing 'points' entry") + + points = geometry["points"] + if not isinstance(points, list): + raise ValueError(f"Geometry import issue: 'points' entry must be a list, got: {type(points)}") + + if len(points) > 0: + first_len = None + for p in points: + if first_len is None: + first_len = len(p) + if len(p) != first_len: + raise ValueError(f"Geometry import issue: 'points' entry contains inconsistent point lengths") + + points = np.asarray(points) + points = np.concatenate((points, np.ones((points.shape[0], 1))), axis=1) + points = torch.as_tensor(points, dtype=torch.float32) + points = MetaTensor(points) + points.kind = KindKeys.POINT + + return points diff --git a/tests/test_load_geometry.py b/tests/test_load_geometry.py new file mode 100644 index 0000000000..f94bd753f6 --- /dev/null +++ b/tests/test_load_geometry.py @@ -0,0 +1,81 @@ +import unittest + +import json +from io import StringIO + +import torch + +from monai.transforms.io.functional import load_geometry +from monai.utils.enums import KindKeys + + +class TestLoadGeometry(unittest.TestCase): + + def test_load_geometry_2d(self): + + entry = { + "schema": { + "geometry": "point" + }, + "points": [ + [0, 0], + [1, 0], + [0, 1], + [1, 1] + ] + } + + expected = torch.tensor( + [[0., 0., 1.], [1., 0., 1.], [0., 1., 1.], [1., 1., 1.]], + dtype=torch.float32 + ) + file = StringIO(json.dumps(entry)) + + points = load_geometry(file, None, None) + + self.assertEqual(points.shape, (4, 3)) + + self.assertEqual(points.kind, KindKeys.POINT) + + self.assertEqual(points.dtype, torch.float32) + + self.assertTrue(torch.allclose(points.data, expected)) + + + def test_load_geometry_3d(self): + + entry = { + "schema": { + "geometry": "point" + }, + "points": [ + [0, 0, 0], + [0, 1, 0], + [0, 0, 1], + [0, 1, 1], + [1, 0, 0], + [1, 1, 0], + [1, 0, 1], + [1, 1, 1], + ] + } + + expected = torch.tensor( + [[0., 0., 0., 1.], [0., 1., 0., 1.], [0., 0., 1., 1.], [0., 1., 1., 1.], + [1., 0., 0., 1.], [1., 1., 0., 1.], [1., 0., 1., 1.], [1., 1., 1., 1.]], + dtype=torch.float32 + ) + file = StringIO(json.dumps(entry)) + + points = load_geometry(file, None, None) + + self.assertEqual(points.shape, (8, 4)) + + self.assertEqual(points.kind, KindKeys.POINT) + + self.assertEqual(points.dtype, torch.float32) + + self.assertTrue(torch.allclose(points.data, expected)) + +if __name__ == '__main__': + unittest.main() \ No newline at end of file From 62f506bc15a1af30a57a4a23d322f346afb6ef24 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 2 Aug 2024 09:46:01 +0000 Subject: [PATCH 23/29] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/transforms/io/functional.py | 2 +- tests/test_load_geometry.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/transforms/io/functional.py b/monai/transforms/io/functional.py index 6cdfbbece5..9f35709ad1 100644 --- a/monai/transforms/io/functional.py +++ b/monai/transforms/io/functional.py @@ -33,7 +33,7 @@ def load_geometry(file, image, origin): if first_len is None: first_len = len(p) if len(p) != first_len: - raise ValueError(f"Geometry import issue: 'points' entry contains inconsistent point lengths") + raise ValueError("Geometry import issue: 'points' entry contains inconsistent point lengths") points = np.asarray(points) points = np.concatenate((points, np.ones((points.shape[0], 1))), axis=1) diff --git a/tests/test_load_geometry.py b/tests/test_load_geometry.py index f94bd753f6..710473db97 100644 --- a/tests/test_load_geometry.py +++ b/tests/test_load_geometry.py @@ -78,4 +78,4 @@ def test_load_geometry_3d(self): self.assertTrue(torch.allclose(points.data, expected)) if __name__ == '__main__': - unittest.main() \ No newline at end of file + unittest.main() From 4740533ca683a132e5dce1e639e9e6ddddb5b910 Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Fri, 2 Aug 2024 12:08:15 +0100 Subject: [PATCH 24/29] Adding spatial_dims_from_tensorlike function to handle point tensors Signed-off-by: Ben Murray --- monai/transforms/utils.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index e282ecff24..e35668dc7e 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -26,6 +26,7 @@ import monai from monai.config import DtypeLike, IndexSelection from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor +from monai.data.meta_tensor import MetaTensor from monai.networks.layers import GaussianFilter from monai.networks.utils import meshgrid_ij from monai.transforms.compose import Compose @@ -65,7 +66,7 @@ optional_import, pytorch_after, ) -from monai.utils.enums import TransformBackends +from monai.utils.enums import KindKeys, TransformBackends from monai.utils.type_conversion import ( convert_data_type, convert_to_cupy, @@ -94,6 +95,7 @@ "create_scale", "create_shear", "create_translate", + "spatial_dims_from_tensorlike", "extreme_points_to_image", "fill_holes", "Fourier", @@ -953,6 +955,22 @@ def _create_translate( return array_func(affine) # type: ignore +def spatial_dims_from_tensorlike(data: NdarrayOrTensor) -> int: + """ + Get the spatial dimensions of the input data. + + Args: + data: input data to infer the spatial dimensions. + + Returns: + spatial dimensions of the input data. + + """ + if isinstance(data, MetaTensor) and data.kind == KindKeys.POINT: + return data.shape[-1] - 1 + return len(data.shape) - 1 + + @deprecated_arg_default("allow_smaller", old_default=True, new_default=False, since="1.2", replaced="1.5") def generate_spatial_bounding_box( img: NdarrayOrTensor, From 2ff1d73d91a2828aeb40264df108875130075db6 Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Fri, 2 Aug 2024 12:09:05 +0100 Subject: [PATCH 25/29] Zoom modified to use spatial_dims_from_tensorlike Signed-off-by: Ben Murray --- monai/transforms/spatial/array.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 094afdd3c4..7a358ec6a1 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -54,6 +54,7 @@ map_spatial_axes, resolves_modes, scale_affine, + spatial_dims_from_tensorlike, ) from monai.transforms.utils_pytorch_numpy_unification import argsort, argwhere, linalg_inv, moveaxis from monai.utils import ( @@ -1089,7 +1090,8 @@ def __call__( during initialization for this call. Defaults to None. """ img = convert_to_tensor(img, track_meta=get_track_meta()) - _zoom = ensure_tuple_rep(self.zoom, img.ndim - 1) # match the spatial image dim + spatial_dims = spatial_dims_from_tensorlike(img) + _zoom = ensure_tuple_rep(self.zoom, spatial_dims) # match the spatial image dim _mode = self.mode if mode is None else mode _padding_mode = padding_mode or self.padding_mode _align_corners = self.align_corners if align_corners is None else align_corners From 645e0c6c04eb4d0472a4a9caa3cc469fda120da6 Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Fri, 2 Aug 2024 12:10:02 +0100 Subject: [PATCH 26/29] Added and tested point zoom functionality Signed-off-by: Ben Murray --- monai/transforms/spatial/functional.py | 63 +++++++++++++++++++++++++- tests/test_zoom.py | 24 ++++++++++ 2 files changed, 85 insertions(+), 2 deletions(-) diff --git a/monai/transforms/spatial/functional.py b/monai/transforms/spatial/functional.py index 48c5c2dcbf..ba4816547a 100644 --- a/monai/transforms/spatial/functional.py +++ b/monai/transforms/spatial/functional.py @@ -740,7 +740,7 @@ def rotate(img, angle, output_shape, mode, padding_mode, align_corners, dtype, l return rotate_raster(img, angle, output_shape, mode, padding_mode, align_corners, dtype, lazy, transform_info) -def zoom(img, scale_factor, keep_size, mode, padding_mode, align_corners, dtype, lazy, transform_info): +def zoom_impl(img, scale_factor, keep_size, mode, padding_mode, align_corners, dtype, lazy, transform_info): """ Functional implementation of zoom. This function operates eagerly or lazily according to @@ -768,7 +768,13 @@ def zoom(img, scale_factor, keep_size, mode, padding_mode, align_corners, dtype, """ im_shape = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:] output_size = [int(math.floor(float(i) * z)) for i, z in zip(im_shape, scale_factor)] - xform = scale_affine(im_shape, output_size) + if not isinstance(img, MetaTensor) or img.kind == KindKeys.PIXEL: + xform = scale_affine(im_shape, output_size) + else: + spatial_dims = im_shape[1] - 1 + old_shape = [1 for _ in range(spatial_dims)] + + xform = scale_affine(old_shape, scale_factor, False) extra_info = { "mode": mode, "align_corners": align_corners if align_corners is not None else TraceKeys.NONE, @@ -798,6 +804,14 @@ def zoom(img, scale_factor, keep_size, mode, padding_mode, align_corners, dtype, transform_info=transform_info, lazy=lazy, ) + return xform, meta_info, output_size + + +def zoom_raster(img, scale_factor, keep_size, mode, padding_mode, align_corners, dtype, lazy, transform_info): + """ + Raster-specific zoom functionality + """ + transform, meta_info, output_size = zoom_impl(img, scale_factor, keep_size, mode, padding_mode, align_corners, dtype, lazy, transform_info) out = _maybe_new_metatensor(img) if lazy: return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info @@ -824,6 +838,51 @@ def zoom(img, scale_factor, keep_size, mode, padding_mode, align_corners, dtype, return out +def zoom_geom(img, scale_factor, keep_size, mode, padding_mode, align_corners, dtype, lazy, transform_info): + """ + Geometry-specific zoom functionality + """ + transform, meta_info, output_size = zoom_impl(img, scale_factor, keep_size, mode, padding_mode, align_corners, dtype, lazy, transform_info) + out = _maybe_new_metatensor(img) + if lazy: + return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info + out = apply_to_geometry(out, meta_info) + return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out + + return out + + +def zoom(img, scale_factor, keep_size, mode, padding_mode, align_corners, dtype, lazy, transform_info): + """ + Functional implementation of zoom. + This function operates eagerly or lazily according to + ``lazy`` (default ``False``). + + Args: + img: data to be changed, assuming `img` is channel-first. + scale_factor: The zoom factor along the spatial axes. + If a float, zoom is the same for each spatial axis. + If a sequence, zoom should contain one value for each spatial axis. + keep_size: Whether keep original size (padding/slicing if needed). + mode: {``"bilinear"``, ``"nearest"``} + Interpolation mode to calculate output values. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} + Padding mode for outside grid values. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + align_corners: See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + dtype: data type for resampling computation. + If None, use the data type of input data. To be compatible with other modules, + the output data type is always ``float32``. + lazy: a flag that indicates whether the operation should be performed lazily or not + transform_info: a dictionary with the relevant information pertaining to an applied transform. + """ + if isinstance(img, MetaTensor) and img.kind == KindKeys.POINT: + return zoom_geom(img, scale_factor, keep_size, mode, padding_mode, align_corners, dtype, lazy, transform_info) + else: + return zoom_raster(img, scale_factor, keep_size, mode, padding_mode, align_corners, dtype, lazy, transform_info) + + def rotate90(img, axes, k, lazy, transform_info): """ Functional implementation of rotate90. diff --git a/tests/test_zoom.py b/tests/test_zoom.py index 2db2df4486..8d2532d504 100644 --- a/tests/test_zoom.py +++ b/tests/test_zoom.py @@ -28,6 +28,8 @@ assert_allclose, test_local_inversion, ) +from monai.utils.enums import KindKeys + VALID_CASES = [ (1.5, "nearest", True), @@ -116,6 +118,28 @@ def test_padding_mode(self): expected = p([[[0.0, 0.0, 0.0, 0.0], [0.0, 1.0, 1.0, 0.0], [0.0, 1.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]]]) assert_allclose(zoomed, expected, type_test=False) + def test_pure_geometry_2d(self): + zoom_fn = Zoom(2) + geom = torch.tensor([[[0, 0, 1], [1, 0, 1], [0, 1, 1], [1, 1, 1]]], dtype=torch.float32) + geom = MetaTensor(geom, requires_grad=False) + geom.kind = KindKeys.POINT + actual = zoom_fn(geom) + expected = torch.tensor([[[0, 0, 1], [0.5, 0, 1], [0, 0.5, 1], [0.5, 0.5, 1]]], dtype=torch.float32) + # also test inversion + self.assertTrue(torch.allclose(actual.data, expected.data)) + + def test_pure_geometry_3d(self): + zoom_fn = Zoom(2) + geom = torch.tensor([[[0, 0, 0, 1], [0, 0, 1, 1], [0, 1, 0, 1], [0, 1, 1, 1], + [1, 0, 0, 1], [1, 0, 1, 1], [1, 1, 0, 1], [1, 1, 1, 1]]], dtype=torch.float32) + geom = MetaTensor(geom, requires_grad=False) + geom.kind = KindKeys.POINT + actual = zoom_fn(geom) + expected = torch.tensor([[[0, 0, 0, 1], [0, 0, 0.5, 1], [0, 0.5, 0, 1], [0, 0.5, 0.5, 1], + [0.5, 0, 0, 1], [0.5, 0, 0.5, 1], [0.5, 0.5, 0, 1], [0.5, 0.5, 0.5, 1]]], dtype=torch.float32) + # also test inversion + self.assertTrue(torch.allclose(actual.data, expected.data)) + if __name__ == "__main__": unittest.main() From f2e24267ea4d4206a06dd7c48fa1cc2aad59df9d Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Fri, 2 Aug 2024 14:22:31 +0100 Subject: [PATCH 27/29] Adding traced_no_op function for use in transforms that no-op geometry data (like Spacing) Signed-off-by: Ben Murray --- monai/transforms/utility/functional.py | 27 ++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) create mode 100644 monai/transforms/utility/functional.py diff --git a/monai/transforms/utility/functional.py b/monai/transforms/utility/functional.py new file mode 100644 index 0000000000..5f5be57808 --- /dev/null +++ b/monai/transforms/utility/functional.py @@ -0,0 +1,27 @@ +import torch + +from monai.data.meta_obj import get_track_meta +from monai.data.meta_tensor import MetaTensor +from monai.transforms.utils import convert_to_tensor, create_translate, spatial_dims_from_tensorlike +from monai.transforms.inverse import TraceableTransform + +def traced_no_op(data, lazy, transform_info): + data = convert_to_tensor(data, track_meta=get_track_meta()) + spatial_dims = spatial_dims_from_tensorlike(data) + meta_info = TraceableTransform.track_transform_meta( + data, + sp_size=None, + affine=create_translate(spatial_dims, [0] * spatial_dims), + extra_info={}, + orig_size=None, + transform_info=transform_info, + lazy=lazy, + ) + # TODO: what is this for? If it is needed, it should be part of utility transforms, + # not spatial transforms + # out = _maybe_new_metatensor(data) + out = data + if lazy: + return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info + out = torch.clone(data) + return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out From 4898b881f997ce244fa33dc8e6d3838d697703b9 Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Fri, 2 Aug 2024 14:22:57 +0100 Subject: [PATCH 28/29] Adding geometry support to Spacing transform Signed-off-by: Ben Murray --- monai/transforms/spatial/array.py | 10 ++++++++-- tests/test_spacing.py | 12 +++++++++++- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 7a358ec6a1..5df78a2f80 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -44,6 +44,7 @@ ) from monai.transforms.traits import MultiSampleTrait from monai.transforms.transform import LazyTransform, Randomizable, RandomizableTransform, Transform +from monai.transforms.utility.functional import traced_no_op from monai.transforms.utils import ( create_control_grid, create_grid, @@ -73,7 +74,7 @@ issequenceiterable, optional_import, ) -from monai.utils.enums import GridPatchSort, PatchKeys, TraceKeys, TransformBackends +from monai.utils.enums import GridPatchSort, KindKeys, PatchKeys, TraceKeys, TransformBackends from monai.utils.misc import ImageMetaKey as Key from monai.utils.module import look_up_option from monai.utils.type_conversion import convert_data_type, get_equivalent_dtype, get_torch_dtype_from_string @@ -483,6 +484,12 @@ def __call__( data tensor or MetaTensor (resampled into `self.pixdim`). """ + lazy_ = self.lazy if lazy is None else lazy + if isinstance(data_array, MetaTensor) and data_array.kind == KindKeys.POINT: + warnings.warn("Spacing transform is not applied to point data.") + data_array = traced_no_op(data_array, lazy_, self.get_transform_info()) + return data_array + original_spatial_shape = ( data_array.peek_pending_shape() if isinstance(data_array, MetaTensor) else data_array.shape[1:] ) @@ -522,7 +529,6 @@ def __call__( new_affine[:sr, -1] = offset[:sr] actual_shape = list(output_shape) if output_spatial_shape is None else output_spatial_shape - lazy_ = self.lazy if lazy is None else lazy data_array = self.sp_resample( data_array, dst_affine=torch.as_tensor(new_affine), diff --git a/tests/test_spacing.py b/tests/test_spacing.py index c9a6291c78..a4803ba047 100644 --- a/tests/test_spacing.py +++ b/tests/test_spacing.py @@ -22,7 +22,7 @@ from monai.data.meta_tensor import MetaTensor from monai.data.utils import affine_to_spacing from monai.transforms import Spacing -from monai.utils import fall_back_tuple +from monai.utils import fall_back_tuple, KindKeys from tests.lazy_transforms_utils import test_resampler_lazy from tests.utils import TEST_DEVICES, TEST_NDARRAYS_ALL, assert_allclose, skip_if_quick @@ -378,5 +378,15 @@ def test_property_no_change(self): assert_allclose(tr.pixdim, [1.0, 1.0, 1.0], type_test=False) + def test_pure_geometry_2d(self): + spacing_fn = Spacing([1, 1]) + geom = torch.tensor([[[0, 0, 1], [0, 1, 1], [1, 0, 1], [1, 1, 1]]], dtype=torch.float32) + geom = MetaTensor(geom, requires_grad=False) + geom.kind = KindKeys.POINT + actual = spacing_fn(geom) + expected = torch.tensor([[[0, 0, 1], [0, 1, 1], [1, 0, 1], [1, 1, 1]]], dtype=torch.float32) + # also test inversion + self.assertTrue(torch.allclose(actual.data, expected.data)) + if __name__ == "__main__": unittest.main() From 2cda79aa29daf81df65d0af1257f94b95d357688 Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Fri, 2 Aug 2024 14:49:52 +0100 Subject: [PATCH 29/29] Adding save_geometry function plus test Signed-off-by: Ben Murray --- monai/transforms/io/functional.py | 40 +++++++++++++++++++++++++++++++ tests/test_load_geometry.py | 37 +++++++++++++++++++++++++++- 2 files changed, 76 insertions(+), 1 deletion(-) diff --git a/monai/transforms/io/functional.py b/monai/transforms/io/functional.py index 9f35709ad1..458ce0011b 100644 --- a/monai/transforms/io/functional.py +++ b/monai/transforms/io/functional.py @@ -42,3 +42,43 @@ def load_geometry(file, image, origin): points.kind = KindKeys.POINT return points + +""" +{ + "schema": { + "geometry": "point" + }, + "points": [ + [0, 0, 0], + [0, 1, 0], + [0, 0, 1], + [0, 1, 1], + [1, 0, 0], + [1, 1, 0], + [1, 0, 1], + [1, 1, 1], + ] + } +""" + +def save_geometry(data, file, image, origin): + """ + Load geometry from a file and optionally map it to another coordinate space. + """ + if not isinstance(data, MetaTensor): + raise ValueError(f"Geometry export issue: data must be a MetaTensor, got: {type(data)}") + if data.kind != KindKeys.POINT: + raise ValueError(f"Geometry export issue: geometry must be a point {KindKeys.POINT}") + geometry = data.detach().cpu().numpy() + geometry = geometry[:, :-1].tolist() + + schema = { + "schema": { + "geometry": "point" + }, + "points": + geometry + } + + geometry = json.dump(schema, file) + return None diff --git a/tests/test_load_geometry.py b/tests/test_load_geometry.py index 710473db97..2a9a3f6b23 100644 --- a/tests/test_load_geometry.py +++ b/tests/test_load_geometry.py @@ -5,7 +5,8 @@ import torch -from monai.transforms.io.functional import load_geometry +from monai.data.meta_tensor import MetaTensor +from monai.transforms.io.functional import load_geometry, save_geometry from monai.utils.enums import KindKeys @@ -77,5 +78,39 @@ def test_load_geometry_3d(self): self.assertTrue(torch.allclose(points.data, expected)) + +class TestSaveGeometry(unittest.TestCase): + + def test_save_geometry_2d(self): + + entry = { + "schema": { + "geometry": "point" + }, + "points": [ + [0.0, 0.0], + [1.0, 0.0], + [0.0, 1.0], + [1.0, 1.0] + ] + } + + padded_points = [[0, 0, 1], [1, 0, 1], [0, 1, 1], [1, 1, 1]] + data = torch.as_tensor(padded_points, dtype=torch.float32) + data = MetaTensor(data) + data.kind = KindKeys.POINT + + # expected = torch.tensor( + # [[0., 0., 1.], [1., 0., 1.], [0., 1., 1.], [1., 1., 1.]], + # dtype=torch.float32 + # ) + file = StringIO() + + save_geometry(data, file, None, None) + actual = file.getvalue() + expected = json.dumps(entry) + self.assertEqual(actual, expected) + + if __name__ == '__main__': unittest.main()