Inverse Affined and RandAffined (#1781)
Signed-off-by: Richard Brown <[email protected]>
rijobro committed Mar 18, 2021
1 parent f063a47 commit 9e65010
Showing 7 changed files with 290 additions and 22 deletions.
2 changes: 1 addition & 1 deletion monai/transforms/__init__.py
@@ -138,7 +138,7 @@
ThresholdIntensityD,
ThresholdIntensityDict,
)
from .inverse import InvertibleTransform
from .inverse import InvertibleTransform, NonRigidTransform
from .io.array import LoadImage, SaveImage
from .io.dictionary import LoadImaged, LoadImageD, LoadImageDict, SaveImaged, SaveImageD, SaveImageDict
from .post.array import (
131 changes: 131 additions & 0 deletions monai/transforms/inverse.py
@@ -9,12 +9,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from typing import Dict, Hashable, Optional, Tuple

import numpy as np
import torch

from monai.transforms.transform import RandomizableTransform, Transform
from monai.utils.enums import InverseKeys
from monai.utils.module import optional_import

sitk, has_sitk = optional_import("SimpleITK")
vtk, has_vtk = optional_import("vtk")
vtk_numpy_support, _ = optional_import("vtk.util.numpy_support")

__all__ = ["InvertibleTransform"]

@@ -111,3 +118,127 @@ def inverse(self, data: dict) -> Dict[Hashable, np.ndarray]:
"""
raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")


class NonRigidTransform(Transform):
@staticmethod
def _get_disp_to_def_arr(shape, spacing):
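# Returns the identity coordinate grid for `shape`, following the MONAI
# convention of placing the origin at the centre of the image and optionally
# scaling by `spacing` (e.g., for control point grids). Subtracting it from a
# deformation field gives a displacement field; adding it back converts a
# displacement field into a deformation field.
# Example: shape=(3,), spacing=None -> [[-1., 0., 1.]].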
def_to_disp = np.mgrid[[slice(0, i) for i in shape]].astype(np.float64)
for idx, i in enumerate(shape):
# shift for origin (in MONAI, center of image)
def_to_disp[idx] -= (i - 1) / 2
# if supplied, account for spacing (e.g., for control point grids)
if spacing is not None:
def_to_disp[idx] *= spacing[idx]
return def_to_disp

@staticmethod
def _inv_disp_w_sitk(fwd_disp, num_iters):
fwd_disp_sitk = sitk.GetImageFromArray(fwd_disp, isVector=True)
inv_disp_sitk = sitk.InvertDisplacementField(fwd_disp_sitk, num_iters)
inv_disp = sitk.GetArrayFromImage(inv_disp_sitk)
return inv_disp

@staticmethod
def _inv_disp_w_vtk(fwd_disp):
orig_shape = fwd_disp.shape
required_num_tensor_components = 3
# VTK requires 3 tensor components, so if shape was (H, W, 2), make it
# (H, W, 1, 3) (i.e., depth 1 with a 3rd tensor component of 0s)
while fwd_disp.shape[-1] < required_num_tensor_components:
fwd_disp = np.append(fwd_disp, np.zeros(fwd_disp.shape[:-1] + (1,)), axis=-1)
fwd_disp = fwd_disp[..., None, :]

# Create VTKDoubleArray. Shape needs to be (H*W*D, 3)
fwd_disp_flattened = fwd_disp.reshape(-1, required_num_tensor_components) # need to keep this in memory
vtk_data_array = vtk_numpy_support.numpy_to_vtk(fwd_disp_flattened)

# Generating the vtkImageData
fwd_disp_vtk = vtk.vtkImageData()
fwd_disp_vtk.SetOrigin(0, 0, 0)
fwd_disp_vtk.SetSpacing(1, 1, 1)
fwd_disp_vtk.SetDimensions(*fwd_disp.shape[:-1][::-1])  # VTK dimensions are in the reverse order of numpy's
fwd_disp_vtk.GetPointData().SetScalars(vtk_data_array)

if __debug__:
fwd_disp_vtk_np = vtk_numpy_support.vtk_to_numpy(fwd_disp_vtk.GetPointData().GetArray(0))
assert fwd_disp_vtk_np.size == fwd_disp.size
assert fwd_disp_vtk_np.min() == fwd_disp.min()
assert fwd_disp_vtk_np.max() == fwd_disp.max()
assert fwd_disp_vtk.GetNumberOfScalarComponents() == required_num_tensor_components

# create b-spline coefficients for the displacement grid
bspline_filter = vtk.vtkImageBSplineCoefficients()
bspline_filter.SetInputData(fwd_disp_vtk)
bspline_filter.Update()

# use these b-spline coefficients to create a transform
bspline_transform = vtk.vtkBSplineTransform()
bspline_transform.SetCoefficientData(bspline_filter.GetOutput())
bspline_transform.Update()

# invert the b-spline transform onto a new grid
grid_maker = vtk.vtkTransformToGrid()
grid_maker.SetInput(bspline_transform.GetInverse())
grid_maker.SetGridOrigin(fwd_disp_vtk.GetOrigin())
grid_maker.SetGridSpacing(fwd_disp_vtk.GetSpacing())
grid_maker.SetGridExtent(fwd_disp_vtk.GetExtent())
grid_maker.SetGridScalarTypeToFloat()
grid_maker.Update()

# Get inverse displacement as an image
inv_disp_vtk = grid_maker.GetOutput()

# Convert back to numpy and reshape
inv_disp = vtk_numpy_support.vtk_to_numpy(inv_disp_vtk.GetPointData().GetArray(0))
# if there were originally < 3 tensor components, remove the zeros we added at the start
inv_disp = inv_disp[..., : orig_shape[-1]]
# reshape to original
inv_disp = inv_disp.reshape(orig_shape)

return inv_disp

@staticmethod
def compute_inverse_deformation(
num_spatial_dims, fwd_def_orig, spacing=None, num_iters: int = 100, use_package: str = "vtk"
):
"""Package can be vtk or sitk."""
if use_package.lower() == "vtk" and not has_vtk:
warnings.warn("Please install VTK to estimate inverse of non-rigid transforms. Data has not been modified")
return None
if use_package.lower() == "sitk" and not has_sitk:
warnings.warn(
"Please install SimpleITK to estimate inverse of non-rigid transforms. Data has not been modified"
)
return None

# Convert to numpy if necessary
if isinstance(fwd_def_orig, torch.Tensor):
fwd_def_orig = fwd_def_orig.cpu().numpy()
# Remove any extra dimensions (we'll add them back in at the end)
fwd_def = fwd_def_orig[:num_spatial_dims]
# Def -> disp
def_to_disp = NonRigidTransform._get_disp_to_def_arr(fwd_def.shape[1:], spacing)
fwd_disp = fwd_def - def_to_disp
# move tensor component to end (T,H,W,[D])->(H,W,[D],T)
fwd_disp = np.moveaxis(fwd_disp, 0, -1)

# If using vtk...
if use_package.lower() == "vtk":
inv_disp = NonRigidTransform._inv_disp_w_vtk(fwd_disp)
# If using sitk...
elif use_package.lower() == "sitk":
inv_disp = NonRigidTransform._inv_disp_w_sitk(fwd_disp, num_iters)
else:
raise RuntimeError("Enter vtk or sitk for inverse calculation")

# move tensor component back to beginning
inv_disp = np.moveaxis(inv_disp, -1, 0)
# Disp -> def
inv_def = inv_disp + def_to_disp
# Add back in any removed dimensions
ndim_in = fwd_def_orig.shape[0]
ndim_out = inv_def.shape[0]
inv_def = np.concatenate([inv_def, fwd_def_orig[ndim_out:ndim_in]])

return inv_def
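
For orientation, a minimal usage sketch of the new NonRigidTransform.compute_inverse_deformation helper, assuming SimpleITK is installed (the image size and perturbation below are illustrative only, not taken from this change):

    import numpy as np
    from monai.transforms.inverse import NonRigidTransform
    from monai.transforms.utils import create_grid

    # Identity deformation grid for a 64x64 image: shape (3, 64, 64),
    # i.e. two coordinate channels plus a homogeneous channel of ones.
    fwd_def = create_grid(spatial_size=(64, 64))
    # Perturb the coordinate channels with a smooth offset to make it non-rigid.
    fwd_def[:2] += 0.5 * np.sin(fwd_def[:2] / 8.0)

    inv_def = NonRigidTransform.compute_inverse_deformation(
        num_spatial_dims=2, fwd_def_orig=fwd_def, use_package="sitk", num_iters=50
    )
    # None (with a warning) if SimpleITK is missing; otherwise same shape as fwd_def.
    print(None if inv_def is None else inv_def.shape)
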
126 changes: 107 additions & 19 deletions monai/transforms/spatial/dictionary.py
@@ -25,7 +25,7 @@
from monai.networks.layers import AffineTransform
from monai.networks.layers.simplelayers import GaussianFilter
from monai.transforms.croppad.array import CenterSpatialCrop, SpatialPad
from monai.transforms.inverse import InvertibleTransform
from monai.transforms.inverse import InvertibleTransform, NonRigidTransform
from monai.transforms.spatial.array import (
Affine,
AffineGrid,
@@ -50,9 +50,9 @@
ensure_tuple,
ensure_tuple_rep,
fall_back_tuple,
optional_import,
)
from monai.utils.enums import InverseKeys
from monai.utils.module import optional_import

nib, _ = optional_import("nibabel")

@@ -730,7 +730,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
return d


class Rand2DElasticd(RandomizableTransform, MapTransform):
class Rand2DElasticd(RandomizableTransform, MapTransform, InvertibleTransform, NonRigidTransform):
"""
Dictionary-based wrapper of :py:class:`monai.transforms.Rand2DElastic`.
"""
@@ -822,6 +822,17 @@ def randomize(self, spatial_size: Sequence[int]) -> None:
super().randomize(None)
self.rand_2d_elastic.randomize(spatial_size)

@staticmethod
def cpg_to_dvf(cpg, spacing, output_shape):
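# Upsample a coarse control-point grid (cpg) to a dense deformation field:
# bicubic interpolation by a factor of `spacing` along each spatial dimension,
# followed by a centre crop back to `output_shape`.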
grid = torch.nn.functional.interpolate(
recompute_scale_factor=True,
input=cpg.unsqueeze(0),
scale_factor=ensure_tuple_rep(spacing, 2),
mode=InterpolateMode.BICUBIC.value,
align_corners=False,
)
return CenterSpatialCrop(roi_size=output_shape)(grid[0])

def __call__(
self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor]]
) -> Dict[Hashable, Union[np.ndarray, torch.Tensor]]:
@@ -831,25 +842,64 @@ def __call__(
self.randomize(spatial_size=sp_size)

if self._do_transform:
grid = self.rand_2d_elastic.deform_grid(spatial_size=sp_size)
grid = self.rand_2d_elastic.rand_affine_grid(grid=grid)
grid = torch.nn.functional.interpolate( # type: ignore
recompute_scale_factor=True,
input=grid.unsqueeze(0),
scale_factor=ensure_tuple_rep(self.rand_2d_elastic.deform_grid.spacing, 2),
mode=InterpolateMode.BICUBIC.value,
align_corners=False,
)
grid = CenterSpatialCrop(roi_size=sp_size)(grid[0])
cpg = self.rand_2d_elastic.deform_grid(spatial_size=sp_size)
cpg_w_affine = self.rand_2d_elastic.rand_affine_grid(grid=cpg)
affine = self.rand_2d_elastic.rand_affine_grid.get_transformation_matrix()
grid = self.cpg_to_dvf(cpg_w_affine, self.rand_2d_elastic.deform_grid.spacing, sp_size)
extra_info: Optional[Dict] = {"cpg": deepcopy(cpg), "affine": deepcopy(affine)}
else:
grid = create_grid(spatial_size=sp_size)
extra_info = None

for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode):
self.push_transform(d, key, extra_info=extra_info)
d[key] = self.rand_2d_elastic.resampler(d[key], grid, mode=mode, padding_mode=padding_mode)
return d

def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
d = deepcopy(dict(data))
# Will only be set to a non-None value if the inverse deformation can be estimated (vtk or sitk installed)
inv_def_no_affine = None

for idx, (key, mode, padding_mode) in enumerate(self.key_iterator(d, self.mode, self.padding_mode)):
transform = self.get_most_recent_transform(d, key)
# Create inverse transform
if transform[InverseKeys.DO_TRANSFORM.value]:
orig_size = transform[InverseKeys.ORIG_SIZE.value]
# Only need to calculate inverse deformation once as it is the same for all keys
if idx == 0:
# If magnitude == 0, the non-rigid component is the identity -- so just use an identity grid
if self.rand_2d_elastic.deform_grid.magnitude == (0.0, 0.0):
inv_def_no_affine = create_grid(spatial_size=orig_size)
else:
fwd_cpg_no_affine = transform[InverseKeys.EXTRA_INFO.value]["cpg"]
fwd_def_no_affine = self.cpg_to_dvf(
fwd_cpg_no_affine, self.rand_2d_elastic.deform_grid.spacing, orig_size
)
inv_def_no_affine = self.compute_inverse_deformation(len(orig_size), fwd_def_no_affine)
# if the inverse could not be estimated (vtk or sitk not installed), the data is left unchanged.
if inv_def_no_affine is not None:
fwd_affine = transform[InverseKeys.EXTRA_INFO.value]["affine"]
inv_affine = np.linalg.inv(fwd_affine)
inv_def_w_affine_wrong_size = AffineGrid(affine=inv_affine, as_tensor_output=False)(
grid=inv_def_no_affine
)
# Back to original size
inv_def_w_affine = CenterSpatialCrop(roi_size=orig_size)(inv_def_w_affine_wrong_size) # type: ignore
# Apply inverse transform
if inv_def_no_affine is not None:
out = self.rand_2d_elastic.resampler(d[key], inv_def_w_affine, mode, padding_mode)
d[key] = out.cpu().numpy() if isinstance(out, torch.Tensor) else out

else:
d[key] = CenterSpatialCrop(roi_size=orig_size)(d[key])
# Remove the applied transform
self.pop_transform(d, key)

return d


class Rand3DElasticd(RandomizableTransform, MapTransform):
class Rand3DElasticd(RandomizableTransform, MapTransform, InvertibleTransform, NonRigidTransform):
"""
Dictionary-based wrapper of :py:class:`monai.transforms.Rand3DElastic`.
"""
@@ -949,17 +999,55 @@ def __call__(
sp_size = fall_back_tuple(self.rand_3d_elastic.spatial_size, data[self.keys[0]].shape[1:])

self.randomize(grid_size=sp_size)
grid = create_grid(spatial_size=sp_size)
grid_no_affine = create_grid(spatial_size=sp_size)
affine = np.eye(4)
if self._do_transform:
device = self.rand_3d_elastic.device
grid = torch.tensor(grid).to(device)
grid_no_affine = torch.tensor(grid_no_affine).to(device)
gaussian = GaussianFilter(spatial_dims=3, sigma=self.rand_3d_elastic.sigma, truncated=3.0).to(device)
offset = torch.tensor(self.rand_3d_elastic.rand_offset, device=device).unsqueeze(0)
grid[:3] += gaussian(offset)[0] * self.rand_3d_elastic.magnitude
grid = self.rand_3d_elastic.rand_affine_grid(grid=grid)
grid_no_affine[:3] += gaussian(offset)[0] * self.rand_3d_elastic.magnitude
grid_w_affine = self.rand_3d_elastic.rand_affine_grid(grid=grid_no_affine)
affine = self.rand_3d_elastic.rand_affine_grid.get_transformation_matrix()
else:
grid_w_affine = grid_no_affine
affine = np.eye(len(sp_size) + 1)

for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode):
d[key] = self.rand_3d_elastic.resampler(d[key], grid, mode=mode, padding_mode=padding_mode)
self.push_transform(d, key, extra_info={"grid_no_affine": grid_no_affine, "affine": affine})
d[key] = self.rand_3d_elastic.resampler(d[key], grid_w_affine, mode=mode, padding_mode=padding_mode)
return d

def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
d = deepcopy(dict(data))

for idx, (key, mode, padding_mode) in enumerate(self.key_iterator(d, self.mode, self.padding_mode)):
transform = self.get_most_recent_transform(d, key)
# Create inverse transform
if transform[InverseKeys.DO_TRANSFORM.value]:
orig_size = transform[InverseKeys.ORIG_SIZE.value]
# Only need to calculate inverse deformation once as it is the same for all keys
if idx == 0:
fwd_def_no_affine = transform[InverseKeys.EXTRA_INFO.value]["grid_no_affine"]
inv_def_no_affine = self.compute_inverse_deformation(len(orig_size), fwd_def_no_affine)
# if the inverse could not be estimated (vtk or sitk not installed), the data is left unchanged.
if inv_def_no_affine is not None:
fwd_affine = transform[InverseKeys.EXTRA_INFO.value]["affine"]
inv_affine = np.linalg.inv(fwd_affine)
inv_def_w_affine_wrong_size = AffineGrid(affine=inv_affine, as_tensor_output=False)(
grid=inv_def_no_affine
)
# Back to original size
inv_def_w_affine = CenterSpatialCrop(roi_size=orig_size)(inv_def_w_affine_wrong_size) # type: ignore
# Apply inverse transform
if inv_def_no_affine is not None:
out = self.rand_3d_elastic.resampler(d[key], inv_def_w_affine, mode, padding_mode)
d[key] = out.cpu().numpy() if isinstance(out, torch.Tensor) else out
else:
d[key] = CenterSpatialCrop(roi_size=orig_size)(d[key])
# Remove the applied transform
self.pop_transform(d, key)

return d
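
For the dictionary transforms above, the round trip enabled by these changes looks roughly like the following hedged sketch (the key, sizes and parameter values are illustrative only, VTK is assumed to be installed, and Rand3DElasticd follows the same pattern):

    import numpy as np
    from monai.transforms import Rand2DElasticd

    transform = Rand2DElasticd(
        keys="img",
        spacing=(20, 20),           # control-point spacing
        magnitude_range=(1.0, 2.0),
        prob=1.0,                   # always deform
        spatial_size=(64, 64),
    )
    transform.set_random_state(0)

    data = {"img": np.arange(64 * 64, dtype=np.float32).reshape(1, 64, 64)}
    deformed = transform(data)              # forward pass records cpg + affine in the trace
    restored = transform.inverse(deformed)  # inverse deformation estimated via VTK
    print(restored["img"].shape)            # (1, 64, 64); approximately the original image
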


1 change: 1 addition & 0 deletions requirements-dev.txt
@@ -32,3 +32,4 @@ sphinx-autodoc-typehints==1.11.1
sphinx-rtd-theme==0.5.0
cucim==0.18.1
openslide-python==1.1.2
vtk