diff --git a/docs/source/losses.rst b/docs/source/losses.rst index dfd8ce2ddb..0ce7185f15 100644 --- a/docs/source/losses.rst +++ b/docs/source/losses.rst @@ -68,6 +68,11 @@ Segmentation Losses .. autoclass:: ContrastiveLoss :members: +`ShapeLoss` +~~~~~~~~~~~~~~~~~ +.. autoclass:: ShapeLoss + :members: + Registration Losses ------------------- diff --git a/monai/losses/__init__.py b/monai/losses/__init__.py index 1922996fb6..35709c6621 100644 --- a/monai/losses/__init__.py +++ b/monai/losses/__init__.py @@ -27,5 +27,6 @@ from .focal_loss import FocalLoss from .image_dissimilarity import GlobalMutualInformationLoss, LocalNormalizedCrossCorrelationLoss from .multi_scale import MultiScaleLoss +from .shape import ShapeDistLoss from .spatial_mask import MaskedLoss from .tversky import TverskyLoss diff --git a/monai/losses/shape.py b/monai/losses/shape.py new file mode 100644 index 0000000000..e545127515 --- /dev/null +++ b/monai/losses/shape.py @@ -0,0 +1,173 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from itertools import product +from typing import Callable, Optional, Union + +import numpy as np +import torch +from torch.nn.modules.loss import _Loss + +from monai.networks import one_hot +from monai.utils import LossReduction, convert_data_type, optional_import + +distance_transform_edt, _ = optional_import("scipy.ndimage", name="distance_transform_edt") + + +class ShapeDistLoss(_Loss): + def __init__( + self, + include_background: bool = True, + to_onehot_y: bool = False, + sigmoid: bool = False, + softmax: bool = False, + other_act: Optional[Callable] = None, + smooth_nr: float = 1e-8, + smooth_k: float = 2e-1, + reduction: Union[LossReduction, str] = LossReduction.MEAN, + ) -> None: + """ + Shape Information Loss inspired by [Huang et al., 2021](https://ieeexplore.ieee.org/document/9433775) + + Args: + include_background: if False, channel index 0 (background category) is excluded from the calculation. + if the non-background segmentations are small compared to the total image size they can get overwhelmed + by the signal from the background so excluding it in such cases helps convergence. + to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False. + sigmoid: if True, apply a sigmoid function to the prediction. + softmax: if True, apply a softmax function to the prediction. + other_act: if don't want to use `sigmoid` or `softmax`, use other callable function to execute + other activation layers, Defaults to ``None``. for example: + `other_act = torch.tanh`. + smooth_nr: a small constant added to the numerator to avoid zero. + smooth_k: smoothness factor used in the Heaviside function + reduction: {``"none"``, ``"mean"``, ``"sum"``} + Specifies the reduction to apply to the output. Defaults to ``"mean"``. + + - ``"none"``: no reduction will be applied. + - ``"mean"``: the sum of the output will be divided by the number of elements in the output. + - ``"sum"``: the output will be summed. + + Raises: + TypeError: When ``other_act`` is not an ``Optional[Callable]``. + ValueError: When more than 1 of [``sigmoid=True``, ``softmax=True``, ``other_act is not None``]. + Incompatible values. + """ + super().__init__(reduction=LossReduction(reduction).value) + if other_act is not None and not callable(other_act): + raise TypeError(f"other_act must be None or callable but is {type(other_act).__name__}.") + if int(sigmoid) + int(softmax) + int(other_act is not None) > 1: + raise ValueError("Incompatible values: more than 1 of [sigmoid=True, softmax=True, other_act is not None].") + self.include_background = include_background + self.to_onehot_y = to_onehot_y + self.sigmoid = sigmoid + self.softmax = softmax + self.other_act = other_act + self.smooth_nr = float(smooth_nr) + self.smooth_k = float(smooth_k) + + def distance_map(self, mask: torch.Tensor) -> torch.Tensor: + """Creates a distance map from a 2D binary mask. + + Args: + mask (torch.Tensor): Binary mask of shape [WD]. + + Returns: + torch.Tensor: Boundary distance map of shape [WD] + """ + # Convert to NumPy to use with SciPy + roi, _, _ = convert_data_type(mask, np.ndarray) + + # Compute normalized distance transform + dt: np.ndarray = distance_transform_edt(roi) + dt /= dt.max() + self.smooth_nr + + # apply Heaviside function to softly normalize into [0, 1] + result: np.ndarray = 1 / (1 + np.exp(-(1 - dt) / self.smooth_k)) + # mask using region of interest + result *= roi + + return torch.Tensor(result) + + def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """ + Args: + input: the shape should be BNH[WD], where N is the number of classes. + target: the shape should be BNH[WD] or B1H[WD], where N is the number of classes. + + Raises: + AssertionError: When input and target (after one hot transform if set) + have different shapes. + ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"]. + + Example: + >>> from monai.losses.dice import * # NOQA + >>> import torch + >>> from monai.losses.dice import ShapeLoss + >>> B, C, H, W = 7, 5, 3, 2 + >>> input = torch.rand(B, C, H, W) + >>> target_idx = torch.randint(low=0, high=C - 1, size=(B, H, W)).long() + >>> target = one_hot(target_idx[:, None, ...], num_classes=C) + >>> self = ShapeLoss(reduction='none') + >>> loss = self(input, target) + >>> assert np.broadcast_shapes(loss.shape, input.shape) == input.shape + """ + if self.sigmoid: + input = torch.sigmoid(input) + + n_pred_ch = input.shape[1] + if self.softmax: + if n_pred_ch == 1: + warnings.warn("single channel prediction, `softmax=True` ignored.") + else: + input = torch.softmax(input, 1) + + if self.other_act is not None: + input = self.other_act(input) + + if self.to_onehot_y: + if n_pred_ch == 1: + warnings.warn("single channel prediction, `to_onehot_y=True` ignored.") + else: + target = one_hot(target, num_classes=n_pred_ch) + + if not self.include_background: + if n_pred_ch == 1: + warnings.warn("single channel prediction, `include_background=False` ignored.") + else: + # if skipping background, removing first channel + target = target[:, 1:] + input = input[:, 1:] + + if target.shape != input.shape: + raise AssertionError(f"ground truth has different shape ({target.shape}) from input ({input.shape})") + + distance_maps = torch.empty(size=target.size()).to(input.device) + + for im, ch in product(*map(range, target.shape[:2])): + distance_maps[im, ch] = self.distance_map(target[im, ch]) + + f = (distance_maps - input).abs().sum(dim=(2, 3)) / input.sum(dim=(2, 3)) + + if self.reduction == LossReduction.MEAN.value: + f = torch.mean(f) # the batch and channel average + elif self.reduction == LossReduction.SUM.value: + f = torch.sum(f) # sum over the batch and channel dims + elif self.reduction == LossReduction.NONE.value: + # If we are not computing voxelwise loss components at least + # make sure a none reduction maintains a broadcastable shape + broadcast_shape = list(f.shape[0:2]) + [1] * (len(input.shape) - 2) + f = f.view(broadcast_shape) + else: + raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') + + return f