Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementation of boundary-aware ShapeLoss #4205

Open
wants to merge 5 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/source/losses.rst
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,11 @@ Segmentation Losses
.. autoclass:: ContrastiveLoss
:members:

`ShapeLoss`
~~~~~~~~~~~~~~~~~
.. autoclass:: ShapeLoss
:members:

Registration Losses
-------------------

Expand Down
1 change: 1 addition & 0 deletions monai/losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,5 +27,6 @@
from .focal_loss import FocalLoss
from .image_dissimilarity import GlobalMutualInformationLoss, LocalNormalizedCrossCorrelationLoss
from .multi_scale import MultiScaleLoss
from .shape import ShapeLoss
from .spatial_mask import MaskedLoss
from .tversky import TverskyLoss
172 changes: 172 additions & 0 deletions monai/losses/shape.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from itertools import product
from typing import Callable, Optional, Union

import numpy as np
import torch
from scipy import ndimage
from torch.nn.modules.loss import _Loss

from monai.networks import one_hot
from monai.utils import LossReduction


class ShapeLoss(_Loss):
    """
    Boundary-aware shape loss inspired by
    [Huang et al., 2021](https://ieeexplore.ieee.org/document/9433775).

    For every (batch, channel) pair a soft boundary map is derived from the
    ground-truth mask (see :py:meth:`distance_map`). The loss for that pair is
    the summed absolute difference between the map and the prediction, divided
    by the summed prediction.
    """

    # TODO(review): add a unit test for this class; the reviewer suggested
    # tests/test_local_normalized_cross_correlation_loss.py as a template.

    def __init__(
        self,
        include_background: bool = True,
        to_onehot_y: bool = False,
        sigmoid: bool = False,
        softmax: bool = False,
        other_act: Optional[Callable] = None,
        smooth_nr: float = 1e-8,
        smooth_k: float = 2e-1,
        reduction: Union[LossReduction, str] = LossReduction.MEAN,
    ) -> None:
        """
        Args:
            include_background: if False, channel index 0 (background category) is excluded from the calculation.
                if the non-background segmentations are small compared to the total image size they can get
                overwhelmed by the signal from the background so excluding it in such cases helps convergence.
            to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
            sigmoid: if True, apply a sigmoid function to the prediction.
            softmax: if True, apply a softmax function to the prediction.
            other_act: if don't want to use `sigmoid` or `softmax`, use other callable function to execute
                other activation layers, Defaults to ``None``. for example:
                `other_act = torch.tanh`.
            smooth_nr: a small constant added to denominators (the distance-transform normalisation and
                the per-channel prediction sum) to avoid division by zero.
            smooth_k: smoothness factor used in the Heaviside (logistic) function.
            reduction: {``"none"``, ``"mean"``, ``"sum"``}
                Specifies the reduction to apply to the output. Defaults to ``"mean"``.

                - ``"none"``: no reduction will be applied.
                - ``"mean"``: the sum of the output will be divided by the number of elements in the output.
                - ``"sum"``: the output will be summed.

        Raises:
            TypeError: When ``other_act`` is not an ``Optional[Callable]``.
            ValueError: When more than 1 of [``sigmoid=True``, ``softmax=True``, ``other_act is not None``].
                Incompatible values.
        """
        super().__init__(reduction=LossReduction(reduction).value)
        if other_act is not None and not callable(other_act):
            raise TypeError(f"other_act must be None or callable but is {type(other_act).__name__}.")
        if int(sigmoid) + int(softmax) + int(other_act is not None) > 1:
            raise ValueError("Incompatible values: more than 1 of [sigmoid=True, softmax=True, other_act is not None].")
        self.include_background = include_background
        self.to_onehot_y = to_onehot_y
        self.sigmoid = sigmoid
        self.softmax = softmax
        self.other_act = other_act
        self.smooth_nr = float(smooth_nr)
        self.smooth_k = float(smooth_k)

    def distance_map(self, mask: torch.Tensor) -> torch.Tensor:
        """Creates a soft boundary distance map from a binary mask.

        The Euclidean distance transform of the mask is normalised by its
        maximum, passed through ``1 / (1 + exp(-(1 - dt) / smooth_k))`` and
        multiplied by the mask, so everything outside the mask is zero.

        Args:
            mask: binary mask of shape H[WD] (spatial dimensions only).

        Returns:
            Boundary distance map with the same shape as ``mask``, placed on
            ``mask``'s device and using torch's default floating point dtype.
        """
        # SciPy needs a CPU NumPy array; detach so tensors that track
        # gradients (or live on an accelerator) do not raise in ``.numpy()``.
        roi: np.ndarray = mask.detach().cpu().numpy()
        out_dtype = torch.get_default_dtype()

        if roi.size == 0:
            # nothing to transform; ``dt.max()`` would raise on an empty array
            return torch.zeros(mask.shape, dtype=out_dtype, device=mask.device)

        # Compute normalized distance transform
        dt: np.ndarray = ndimage.distance_transform_edt(roi)
        dt /= dt.max() + self.smooth_nr

        # apply Heaviside function to softly normalize into [0, 1]
        result: np.ndarray = 1 / (1 + np.exp(-(1 - dt) / self.smooth_k))
        # mask using region of interest
        result *= roi

        return torch.as_tensor(result, dtype=out_dtype, device=mask.device)

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Args:
            input: the shape should be BNH[WD], where N is the number of classes.
            target: the shape should be BNH[WD] or B1H[WD], where N is the number of classes.

        Raises:
            AssertionError: When input and target (after one hot transform if set)
                have different shapes.
            ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"].

        Example:
            >>> import torch
            >>> from monai.losses import ShapeLoss
            >>> from monai.networks import one_hot
            >>> B, C, H, W = 7, 5, 3, 2
            >>> input = torch.rand(B, C, H, W)
            >>> target_idx = torch.randint(low=0, high=C - 1, size=(B, H, W)).long()
            >>> target = one_hot(target_idx[:, None, ...], num_classes=C)
            >>> loss_fn = ShapeLoss(reduction='none')
            >>> loss = loss_fn(input, target)
            >>> tuple(loss.shape)
            (7, 5, 1, 1)
        """
        if self.sigmoid:
            input = torch.sigmoid(input)

        n_pred_ch = input.shape[1]
        if self.softmax:
            if n_pred_ch == 1:
                warnings.warn("single channel prediction, `softmax=True` ignored.")
            else:
                input = torch.softmax(input, 1)

        if self.other_act is not None:
            input = self.other_act(input)

        if self.to_onehot_y:
            if n_pred_ch == 1:
                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
            else:
                target = one_hot(target, num_classes=n_pred_ch)

        if not self.include_background:
            if n_pred_ch == 1:
                warnings.warn("single channel prediction, `include_background=False` ignored.")
            else:
                # if skipping background, removing first channel
                target = target[:, 1:]
                input = input[:, 1:]

        if target.shape != input.shape:
            raise AssertionError(f"ground truth has different shape ({target.shape}) from input ({input.shape})")

        # Allocate directly on the prediction's device instead of creating on
        # the CPU and copying afterwards.
        distance_maps = torch.empty(target.shape, device=input.device)

        # The distance transform is computed per image and per channel.
        n_batch, n_channels = target.shape[:2]
        for im, ch in product(range(n_batch), range(n_channels)):
            distance_maps[im, ch] = self.distance_map(target[im, ch]).to(distance_maps.device)

        # Reduce over every spatial axis so 1D, 2D and 3D inputs are all
        # handled, not only BNHW.
        spatial_dims = tuple(range(2, input.dim()))
        numerator = (distance_maps - input).abs().sum(dim=spatial_dims)
        # smooth_nr keeps the ratio finite when a channel's prediction sums to zero
        denominator = input.sum(dim=spatial_dims) + self.smooth_nr
        f = numerator / denominator

        if self.reduction == LossReduction.MEAN.value:
            f = torch.mean(f)  # the batch and channel average
        elif self.reduction == LossReduction.SUM.value:
            f = torch.sum(f)  # sum over the batch and channel dims
        elif self.reduction == LossReduction.NONE.value:
            # If we are not computing voxelwise loss components at least
            # make sure a none reduction maintains a broadcastable shape
            broadcast_shape = list(f.shape[0:2]) + [1] * (len(input.shape) - 2)
            f = f.view(broadcast_shape)
        else:
            raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')

        return f