Skip to content

Commit

Permalink
Restoring scribbles using torchmaxflow (Project-MONAI#731)
Browse files Browse the repository at this point in the history
* Fix preload config (Project-MONAI#728)

* Fix preload config

Signed-off-by: SACHIDANAND ALLE <[email protected]>

* Fix preload config

Signed-off-by: SACHIDANAND ALLE <[email protected]>
Signed-off-by: masadcv <[email protected]>

* restoring scribbles with torchmaxflow

Signed-off-by: masadcv <[email protected]>

* update to torchmaxflow 0.0.4rc2

Signed-off-by: masadcv <[email protected]>

* fix torch import issue

Signed-off-by: masadcv <[email protected]>

* fix scribbles label issue for roi and histogram tx

Signed-off-by: masadcv <[email protected]>

* increase complexity of histogram to handle difficult cases

Signed-off-by: masadcv <[email protected]>

* Add spatial size argment to infer file (Project-MONAI#730)

* Add spatial size argment to infer file

Signed-off-by: Andres Diaz-Pinto <[email protected]>

* Update segmentation App

Signed-off-by: Andres Diaz-Pinto <[email protected]>

* Update segmentation App - json arg

Signed-off-by: Andres Diaz-Pinto <[email protected]>
Signed-off-by: masadcv <[email protected]>

* no collapse scribbles on nextsamp, if user is scribbling

Signed-off-by: masadcv <[email protected]>

* update to torchmaxflow 0.0.5

Signed-off-by: masadcv <[email protected]>

* drop python 3.6 support (Project-MONAI#735)

Monai has dropped Python 3.6 support (Project-MONAI/MONAI@e655b4e). PyTorch dropped Python 3.6 starting in version 1.11.0 (pytorch/pytorch@dc5cda0).

Python 3.6 official end of support date was 23rd Dec 2021.

Signed-off-by: James Butler <[email protected]>
Signed-off-by: masadcv <[email protected]>

* update copyright headers dropping specific year (Project-MONAI#737)

This is based on the same changes contributed to the Monai toolkit repo in Project-MONAI/MONAI@1516ca7.

Signed-off-by: James Butler <[email protected]>
Signed-off-by: masadcv <[email protected]>

* update PY_REQUIRED_MINOR to reflect python 3.7 minimum (Project-MONAI#738)

* update PY_REQUIRED_MINOR to reflect python 3.7 minimum

This should have been originally included with Project-MONAI@92e0fec.

Signed-off-by: James Butler <[email protected]>

* drop torch 1.5 support

Monai dropped torch 1.5 support in Project-MONAI/MONAI@2e83cd2.

Signed-off-by: James Butler <[email protected]>
Signed-off-by: masadcv <[email protected]>

* Linting with pre-commit ci (Project-MONAI#736)

* Fix Flake8 E501: line too long

https://www.flake8rules.com/rules/E501.html

Signed-off-by: James Butler <[email protected]>

* Fix Flake8 E711: Comparison to none should be 'if cond is none:'

https://www.flake8rules.com/rules/E711.html

Signed-off-by: James Butler <[email protected]>

* Fix Flake8 E741: Do not use variables named 'l', 'o', or 'i'

https://www.flake8rules.com/rules/E741.html

Signed-off-by: James Butler <[email protected]>

* Fix Flake8 F841: Local variable name is assigned to but never used

https://www.flake8rules.com/rules/F841.html

Signed-off-by: James Butler <[email protected]>

* Consolidate lint checks to cross platform pre-commit framework

This is to be paired with GitHub application pre-commit ci https://github.com/marketplace/pre-commit-ci which is a continuous integration service for the pre-commit framework.
https://github.com/pre-commit/action is the deprecated GitHub actions version.

Signed-off-by: James Butler <[email protected]>

* trim trailing whitespace

Signed-off-by: James Butler <[email protected]>

* Upgrade python syntax to 3.7 and newer

monailabel currently has the requirement python_requires = >= 3.7

Signed-off-by: James Butler <[email protected]>

* Update CI to latest version of "action/checkout" GitHub actions

See https://github.com/actions/checkout/releases/tag/v3.0.0

Signed-off-by: James Butler <[email protected]>

* Update CI to latest version of "action/setup-python" GitHub actions

See https://github.com/actions/setup-python/releases/tag/v3.0.0

Signed-off-by: James Butler <[email protected]>

* Add PR testing ci on python 3.9

Monai uses python 3.8 and 3D Slicer uses Python 3.9.

Signed-off-by: James Butler <[email protected]>

* Fix simpleitk building whl from source in CI

Signed-off-by: SACHIDANAND ALLE <[email protected]>

* Add MyPy and fix Azure Pipeline

Signed-off-by: SACHIDANAND ALLE <[email protected]>

* Remove mypy from runtest.sh and runtests.bat not needed anymore

Signed-off-by: SACHIDANAND ALLE <[email protected]>

Co-authored-by: SACHIDANAND ALLE <[email protected]>
Signed-off-by: masadcv <[email protected]>

* Revert "Linting with pre-commit ci (Project-MONAI#736)"

This reverts commit ba9e89e so that changes can be organized in separate commits and re-committed.

Signed-off-by: James Butler <[email protected]>
Signed-off-by: masadcv <[email protected]>

* Fix Flake8 E501: line too long

https://www.flake8rules.com/rules/E501.html

Signed-off-by: James Butler <[email protected]>
Signed-off-by: masadcv <[email protected]>

* Fix Flake8 E711: Comparison to none should be 'if cond is none:'

https://www.flake8rules.com/rules/E711.html

Signed-off-by: James Butler <[email protected]>
Signed-off-by: masadcv <[email protected]>

* Fix Flake8 E741: Do not use variables named 'l', 'o', or 'i'

https://www.flake8rules.com/rules/E741.html

Signed-off-by: James Butler <[email protected]>
Signed-off-by: masadcv <[email protected]>

* Fix Flake8 F841: Local variable name is assigned to but never used

https://www.flake8rules.com/rules/F841.html

Signed-off-by: James Butler <[email protected]>
Signed-off-by: masadcv <[email protected]>

* Consolidate lint checks to cross platform pre-commit framework

This is to be paired with GitHub application pre-commit ci https://github.com/marketplace/pre-commit-ci which is a continuous integration service for the pre-commit framework.
https://github.com/pre-commit/action is the deprecated GitHub actions version.

Signed-off-by: James Butler <[email protected]>
Signed-off-by: masadcv <[email protected]>

* trim trailing whitespace

Signed-off-by: James Butler <[email protected]>
Signed-off-by: masadcv <[email protected]>

* Upgrade python syntax to 3.7 and newer

monailabel currently has the requirement python_requires = >= 3.7

Signed-off-by: James Butler <[email protected]>
Signed-off-by: masadcv <[email protected]>

* Update CI to latest version of "action/checkout" GitHub actions

See https://github.com/actions/checkout/releases/tag/v3.0.0

Signed-off-by: James Butler <[email protected]>
Signed-off-by: masadcv <[email protected]>

* Update CI to latest version of "action/setup-python" GitHub actions

See https://github.com/actions/setup-python/releases/tag/v3.0.0

Signed-off-by: James Butler <[email protected]>
Signed-off-by: masadcv <[email protected]>

* Add PR testing ci on python 3.9

Monai uses python 3.8 and 3D Slicer uses Python 3.9.

Signed-off-by: James Butler <[email protected]>
Signed-off-by: masadcv <[email protected]>

* Fix simpleitk building whl from source in CI

Signed-off-by: SACHIDANAND ALLE <[email protected]>
Signed-off-by: masadcv <[email protected]>

* Add MyPy and fix Azure Pipeline

Signed-off-by: SACHIDANAND ALLE <[email protected]>
Signed-off-by: masadcv <[email protected]>

* Remove mypy from runtest.sh and runtests.bat not needed anymore

Signed-off-by: SACHIDANAND ALLE <[email protected]>
Signed-off-by: masadcv <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: masadcv <[email protected]>

* fix module name for pre-commit (Project-MONAI#742)

Signed-off-by: SACHIDANAND ALLE <[email protected]>
Signed-off-by: masadcv <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: masadcv <[email protected]>

* adds codeql (Project-MONAI#741)

* adds codeql

Signed-off-by: Wenqi Li <[email protected]>

* temp test

Signed-off-by: Wenqi Li <[email protected]>

* python tests

Signed-off-by: Wenqi Li <[email protected]>

* fix log filename - security fix

Signed-off-by: SACHIDANAND ALLE <[email protected]>

Co-authored-by: SACHIDANAND ALLE <[email protected]>
Signed-off-by: masadcv <[email protected]>

* update to torchmaxflow 0.0.6rc1

Signed-off-by: masadcv <[email protected]>

* update to torchmaxflow 0.0.6rc1

Signed-off-by: masadcv <[email protected]>

* adding multilabel support + dynamic scribbles label support

Signed-off-by: masadcv <[email protected]>

Co-authored-by: SACHIDANAND ALLE <[email protected]>
Co-authored-by: Andres Diaz-Pinto <[email protected]>
Co-authored-by: James Butler <[email protected]>
Co-authored-by: James Butler <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Wenqi Li <[email protected]>
  • Loading branch information
7 people authored Apr 13, 2022
1 parent 139af6b commit bf25465
Show file tree
Hide file tree
Showing 9 changed files with 468 additions and 81 deletions.
130 changes: 130 additions & 0 deletions monailabel/scribbles/infer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from monai.transforms import Compose, EnsureChannelFirstd, LoadImaged, Orientationd, ScaleIntensityRanged, Spacingd

from monailabel.interfaces.tasks.infer import InferTask, InferType
from monailabel.scribbles.transforms import (
AddBackgroundScribblesFromROId,
ApplyGraphCutOptimisationd,
MakeISegUnaryd,
MakeLikelihoodFromScribblesHistogramd,
)
from monailabel.transform.post import BoundingBoxd, Restored


class HistogramBasedGraphCut(InferTask):
"""
Defines histogram-based GraphCut task for Generic segmentation from the following paper:
Wang, Guotai, et al. "Interactive medical image segmentation using deep learning with image-specific fine tuning."
IEEE transactions on medical imaging 37.7 (2018): 1562-1573. (preprint: https://arxiv.org/pdf/1710.04043.pdf)
This task takes as input 1) original image volume and 2) scribbles from user
indicating foreground and background regions. A likelihood volume is generated using histogram method.
User-scribbles are incorporated using Equation 7 on page 4 of the paper.
SimpleCRF's GraphCut layer is used to optimise Equation 5 from the paper, where unaries come from Equation 7
and pairwise is the original input volume.
"""

def __init__(
self,
dimension=3,
description="A post processing step with histogram-based GraphCut for Generic segmentation",
intensity_range=(-300, 200, 0.0, 1.0, True),
pix_dim=(2.5, 2.5, 5.0),
lamda=1.0,
sigma=0.1,
labels=None,
config=None,
):
if config:
config.update({"lamda": lamda, "sigma": sigma})
else:
config = {"lamda": lamda, "sigma": sigma}
super().__init__(
path=None,
network=None,
labels=labels,
type=InferType.SCRIBBLES,
dimension=dimension,
description=description,
config=config,
)
self.intensity_range = intensity_range
self.pix_dim = pix_dim
self.lamda = lamda
self.sigma = sigma

# set default scribbles labels
self.scribbles_bg_label = 2 if not self.labels else len(self.labels) + 1
self.scribbles_fg_label = 3 if not self.labels else len(self.labels) + 2

def pre_transforms(self, data):
return [
LoadImaged(keys=["image", "label"]),
EnsureChannelFirstd(keys=["image", "label"]),
AddBackgroundScribblesFromROId(
scribbles="label",
scribbles_bg_label=self.scribbles_bg_label,
scribbles_fg_label=self.scribbles_fg_label,
),
# at the moment optimisers are bottleneck taking a long time,
# therefore scaling non-isotropic with big spacing
Spacingd(keys=["image", "label"], pixdim=self.pix_dim, mode=["bilinear", "nearest"]),
Orientationd(keys=["image", "label"], axcodes="RAS"),
ScaleIntensityRanged(
keys="image",
a_min=self.intensity_range[0],
a_max=self.intensity_range[1],
b_min=self.intensity_range[2],
b_max=self.intensity_range[3],
clip=self.intensity_range[4],
),
MakeLikelihoodFromScribblesHistogramd(
image="image",
scribbles="label",
post_proc_label="prob",
scribbles_bg_label=self.scribbles_bg_label,
scribbles_fg_label=self.scribbles_fg_label,
normalise=True,
),
]

def inferer(self, data):
return Compose(
[
# unary term maker
MakeISegUnaryd(
image="image",
logits="prob",
scribbles="label",
unary="unary",
scribbles_bg_label=self.scribbles_bg_label,
scribbles_fg_label=self.scribbles_fg_label,
),
# optimiser
ApplyGraphCutOptimisationd(
unary="unary",
pairwise="image",
post_proc_label="pred",
lamda=self.lamda,
sigma=self.sigma,
),
]
)

def post_transforms(self, data):
return [
Restored(keys="pred", ref_image="image"),
BoundingBoxd(keys="pred", result="result", bbox="bbox"),
]
107 changes: 106 additions & 1 deletion monailabel/scribbles/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from monailabel.transform.writer import Writer

from .utils import make_iseg_unary, make_likelihood_image_histogram
from .utils import make_iseg_unary, make_likelihood_image_histogram, maxflow

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -64,6 +64,19 @@ def _copy_affine(self, d, src, dst):

return d

def _set_scribbles_idx_from_labelinfo(self, d):
label_info = d.get("label_info", [])
for lb in label_info:
if lb.get("name", None) == "background_scribbles":
id = lb.get("id", self.scribbles_bg_label)
self.scribbles_bg_label = id
logging.info("Loading background scribbles labels from: {} with index: {}".format(lb.get("name"), id))

if lb.get("name", None) == "foreground_scribbles":
id = lb.get("id", self.scribbles_fg_label)
self.scribbles_fg_label = id
logging.info("Loading foreground scribbles labels from: {} with index: {}".format(lb.get("name"), id))


#######################################
#######################################
Expand All @@ -89,6 +102,9 @@ def __init__(
def __call__(self, data):
d = dict(data)

# load scribbles idx from labels_info (if available)
self._set_scribbles_idx_from_labelinfo(d)

# read relevant terms from data
scribbles = self._fetch_data(d, self.scribbles)

Expand Down Expand Up @@ -158,6 +174,9 @@ def __init__(
def __call__(self, data):
d = dict(data)

# load scribbles idx from labels_info (if available)
self._set_scribbles_idx_from_labelinfo(d)

# copy affine meta data from image input
d = self._copy_affine(d, src=self.image, dst=self.post_proc_label)

Expand Down Expand Up @@ -277,6 +296,9 @@ def __init__(
def __call__(self, data):
d = dict(data)

# load scribbles idx from labels_info (if available)
self._set_scribbles_idx_from_labelinfo(d)

# copy affine meta data from image input
self._copy_affine(d, self.image, self.unary)

Expand Down Expand Up @@ -311,6 +333,89 @@ def __call__(self, data):
#######################
# Optimiser Transforms
#######################
class ApplyGraphCutOptimisationd(InteractiveSegmentationTransform):
"""
Generic GraphCut optimisation transform.
This can be used in conjuction with any Make*Unaryd transform
(e.g. MakeISegUnaryd from above for implementing ISeg unary term).
It optimises a typical energy function for interactive segmentation methods using SimpleCRF's GraphCut method,
e.g. Equation 5 from https://arxiv.org/pdf/1710.04043.pdf.
Usage Example::
Compose(
[
# unary term maker
MakeISegUnaryd(
image="image",
logits="logits",
scribbles="label",
unary="unary",
scribbles_bg_label=2,
scribbles_fg_label=3,
),
# optimiser
ApplyGraphCutOptimisationd(
unary="unary",
pairwise="image",
post_proc_label="pred",
lamda=10.0,
sigma=15.0,
),
]
)
"""

def __init__(
self,
unary: str,
pairwise: str,
meta_key_postfix: str = "meta_dict",
post_proc_label: str = "pred",
lamda: float = 8.0,
sigma: float = 0.1,
) -> None:
super().__init__(meta_key_postfix)
self.unary = unary
self.pairwise = pairwise
self.post_proc_label = post_proc_label
self.lamda = lamda
self.sigma = sigma

def __call__(self, data):
d = dict(data)

# attempt to fetch algorithmic parameters from app if present
self.lamda = d.get("lamda", self.lamda)
self.sigma = d.get("sigma", self.sigma)

# copy affine meta data from pairwise input
self._copy_affine(d, self.pairwise, self.post_proc_label)

# read relevant terms from data
unary_term = self._fetch_data(d, self.unary)
pairwise_term = self._fetch_data(d, self.pairwise)

# check if input unary is compatible with GraphCut opt
if unary_term.shape[0] > 2:
raise ValueError(f"GraphCut can only be applied to binary probabilities, received {unary_term.shape[0]}")

# # attempt to unfold probability term
# unary_term = self._unfold_prob(unary_term, axis=0)

# prepare data for SimpleCRF's GraphCut
unary_term = torch.from_numpy(unary_term).unsqueeze(0)
pairwise_term = torch.from_numpy(pairwise_term).unsqueeze(0)

# run GraphCut
post_proc_label = maxflow(pairwise_term, unary_term, lamda=self.lamda, sigma=self.sigma).squeeze(0).numpy()

d[self.post_proc_label] = post_proc_label

return d


class ApplyCRFOptimisationd(InteractiveSegmentationTransform):
"""
Generic MONAI CRF optimisation transform.
Expand Down
13 changes: 12 additions & 1 deletion monailabel/scribbles/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@
import logging

import numpy as np
from monai.utils import optional_import

# torch import is needed to execute torchmaxflow
optional_import("torch")
import torchmaxflow

logger = logging.getLogger(__name__)

Expand All @@ -19,6 +24,12 @@ def get_eps(data):
return np.finfo(data.dtype).eps


def maxflow(image, prob, lamda=5, sigma=0.1):
# lamda: weight of smoothing term
# sigma: std of intensity values
return torchmaxflow.maxflow(image, prob, lamda, sigma)


def make_iseg_unary(
prob,
scribbles,
Expand Down Expand Up @@ -139,7 +150,7 @@ def make_likelihood_image_histogram(image, scrib, scribbles_bg_label, scribbles_

# generate histograms for background/foreground
bg_hist, fg_hist, bin_edges = make_histograms(
image, scrib, scribbles_bg_label, scribbles_fg_label, alpha_bg=1, alpha_fg=1, bins=32
image, scrib, scribbles_bg_label, scribbles_fg_label, alpha_bg=1, alpha_fg=1, bins=64
)

# lookup values for each voxel for generating background/foreground probabilities
Expand Down
Loading

0 comments on commit bf25465

Please sign in to comment.