Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Restoring scribbles using torchmaxflow #731

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
8b7f869
Fix preload config (#728)
SachidanandAlle Apr 6, 2022
ee72604
restoring scribbles with torchmaxflow
masadcv Apr 6, 2022
f413bfc
update to torchmaxflow 0.0.4rc2
masadcv Apr 7, 2022
9caaf9b
fix torch import issue
masadcv Apr 7, 2022
bfdad6b
fix scribbles label issue for roi and histogram tx
masadcv Apr 7, 2022
f5f34b1
increase complexity of histogram to handle difficult cases
masadcv Apr 8, 2022
22851cb
Add spatial size argment to infer file (#730)
diazandr3s Apr 7, 2022
d420c8e
no collapse scribbles on nextsamp, if user is scribbling
masadcv Apr 8, 2022
717023c
update to torchmaxflow 0.0.5
masadcv Apr 8, 2022
828eca1
drop python 3.6 support (#735)
jamesobutler Apr 10, 2022
f98d4dc
update copyright headers dropping specific year (#737)
jamesobutler Apr 10, 2022
9b8d2a6
update PY_REQUIRED_MINOR to reflect python 3.7 minimum (#738)
jamesobutler Apr 11, 2022
628b806
Linting with pre-commit ci (#736)
jamesobutler Apr 11, 2022
5d6bdea
Revert "Linting with pre-commit ci (#736)"
jamesobutler Apr 11, 2022
59f9665
Fix Flake8 E501: line too long
jamesobutler Apr 9, 2022
177c81a
Fix Flake8 E711: Comparison to none should be 'if cond is none:'
jamesobutler Apr 9, 2022
6f9dbb1
Fix Flake8 E741: Do not use variables named 'l', 'o', or 'i'
jamesobutler Apr 9, 2022
7509102
Fix Flake8 F841: Local variable name is assigned to but never used
jamesobutler Apr 9, 2022
e41aa37
Consolidate lint checks to cross platform pre-commit framework
jamesobutler Apr 9, 2022
00618f6
trim trailing whitespace
jamesobutler Apr 11, 2022
683a5aa
Upgrade python syntax to 3.7 and newer
jamesobutler Apr 9, 2022
3512847
Update CI to latest version of "action/checkout" GitHub actions
jamesobutler Apr 9, 2022
5b76278
Update CI to latest version of "action/setup-python" GitHub actions
jamesobutler Apr 9, 2022
1474afc
Add PR testing ci on python 3.9
jamesobutler Apr 10, 2022
799bcde
Fix simpleitk building whl from source in CI
SachidanandAlle Apr 11, 2022
a769d8f
Add MyPy and fix Azure Pipeline
SachidanandAlle Apr 11, 2022
cd408d8
Remove mypy from runtest.sh and runtests.bat not needed anymore
SachidanandAlle Apr 11, 2022
434c30a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 11, 2022
1e39a08
fix module name for pre-commit (#742)
SachidanandAlle Apr 11, 2022
e829bbb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 11, 2022
b884669
adds codeql (#741)
wyli Apr 11, 2022
3da200d
update to torchmaxflow 0.0.6rc1
masadcv Apr 12, 2022
a9c4f82
update to torchmaxflow 0.0.6rc1
masadcv Apr 12, 2022
e22b3c0
Merge branch 'main' into restore-scribbles-with-torchmaxflow
masadcv Apr 12, 2022
80a92a1
Merge branch 'main' into restore-scribbles-with-torchmaxflow
masadcv Apr 12, 2022
a5b3cc5
adding multilabel support + dynamic scribbles label support
masadcv Apr 13, 2022
d7e763e
Merge branch 'main' into restore-scribbles-with-torchmaxflow
SachidanandAlle Apr 13, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 130 additions & 0 deletions monailabel/scribbles/infer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from monai.transforms import Compose, EnsureChannelFirstd, LoadImaged, Orientationd, ScaleIntensityRanged, Spacingd

from monailabel.interfaces.tasks.infer import InferTask, InferType
from monailabel.scribbles.transforms import (
AddBackgroundScribblesFromROId,
ApplyGraphCutOptimisationd,
MakeISegUnaryd,
MakeLikelihoodFromScribblesHistogramd,
)
from monailabel.transform.post import BoundingBoxd, Restored


class HistogramBasedGraphCut(InferTask):
"""
Defines histogram-based GraphCut task for Generic segmentation from the following paper:

Wang, Guotai, et al. "Interactive medical image segmentation using deep learning with image-specific fine tuning."
IEEE transactions on medical imaging 37.7 (2018): 1562-1573. (preprint: https://arxiv.org/pdf/1710.04043.pdf)

This task takes as input 1) original image volume and 2) scribbles from user
indicating foreground and background regions. A likelihood volume is generated using histogram method.
User-scribbles are incorporated using Equation 7 on page 4 of the paper.

SimpleCRF's GraphCut layer is used to optimise Equation 5 from the paper, where unaries come from Equation 7
and pairwise is the original input volume.
"""

def __init__(
self,
dimension=3,
description="A post processing step with histogram-based GraphCut for Generic segmentation",
intensity_range=(-300, 200, 0.0, 1.0, True),
pix_dim=(2.5, 2.5, 5.0),
lamda=1.0,
sigma=0.1,
labels=None,
config=None,
):
if config:
config.update({"lamda": lamda, "sigma": sigma})
else:
config = {"lamda": lamda, "sigma": sigma}
super().__init__(
path=None,
network=None,
labels=labels,
type=InferType.SCRIBBLES,
dimension=dimension,
description=description,
config=config,
)
self.intensity_range = intensity_range
self.pix_dim = pix_dim
self.lamda = lamda
self.sigma = sigma

# set default scribbles labels
self.scribbles_bg_label = 2 if not self.labels else len(self.labels) + 1
self.scribbles_fg_label = 3 if not self.labels else len(self.labels) + 2

def pre_transforms(self, data):
return [
LoadImaged(keys=["image", "label"]),
EnsureChannelFirstd(keys=["image", "label"]),
AddBackgroundScribblesFromROId(
scribbles="label",
scribbles_bg_label=self.scribbles_bg_label,
scribbles_fg_label=self.scribbles_fg_label,
),
# at the moment optimisers are bottleneck taking a long time,
# therefore scaling non-isotropic with big spacing
Spacingd(keys=["image", "label"], pixdim=self.pix_dim, mode=["bilinear", "nearest"]),
Orientationd(keys=["image", "label"], axcodes="RAS"),
ScaleIntensityRanged(
keys="image",
a_min=self.intensity_range[0],
a_max=self.intensity_range[1],
b_min=self.intensity_range[2],
b_max=self.intensity_range[3],
clip=self.intensity_range[4],
),
MakeLikelihoodFromScribblesHistogramd(
image="image",
scribbles="label",
post_proc_label="prob",
scribbles_bg_label=self.scribbles_bg_label,
scribbles_fg_label=self.scribbles_fg_label,
normalise=True,
),
]

def inferer(self, data):
return Compose(
[
# unary term maker
MakeISegUnaryd(
image="image",
logits="prob",
scribbles="label",
unary="unary",
scribbles_bg_label=self.scribbles_bg_label,
scribbles_fg_label=self.scribbles_fg_label,
),
# optimiser
ApplyGraphCutOptimisationd(
unary="unary",
pairwise="image",
post_proc_label="pred",
lamda=self.lamda,
sigma=self.sigma,
),
]
)

def post_transforms(self, data):
return [
Restored(keys="pred", ref_image="image"),
BoundingBoxd(keys="pred", result="result", bbox="bbox"),
]
107 changes: 106 additions & 1 deletion monailabel/scribbles/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from monailabel.transform.writer import Writer

from .utils import make_iseg_unary, make_likelihood_image_histogram
from .utils import make_iseg_unary, make_likelihood_image_histogram, maxflow

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -64,6 +64,19 @@ def _copy_affine(self, d, src, dst):

return d

def _set_scribbles_idx_from_labelinfo(self, d):
label_info = d.get("label_info", [])
for lb in label_info:
if lb.get("name", None) == "background_scribbles":
id = lb.get("id", self.scribbles_bg_label)
self.scribbles_bg_label = id
logging.info("Loading background scribbles labels from: {} with index: {}".format(lb.get("name"), id))

if lb.get("name", None) == "foreground_scribbles":
id = lb.get("id", self.scribbles_fg_label)
self.scribbles_fg_label = id
logging.info("Loading foreground scribbles labels from: {} with index: {}".format(lb.get("name"), id))


#######################################
#######################################
Expand All @@ -89,6 +102,9 @@ def __init__(
def __call__(self, data):
d = dict(data)

# load scribbles idx from labels_info (if available)
self._set_scribbles_idx_from_labelinfo(d)

# read relevant terms from data
scribbles = self._fetch_data(d, self.scribbles)

Expand Down Expand Up @@ -158,6 +174,9 @@ def __init__(
def __call__(self, data):
d = dict(data)

# load scribbles idx from labels_info (if available)
self._set_scribbles_idx_from_labelinfo(d)

# copy affine meta data from image input
d = self._copy_affine(d, src=self.image, dst=self.post_proc_label)

Expand Down Expand Up @@ -277,6 +296,9 @@ def __init__(
def __call__(self, data):
d = dict(data)

# load scribbles idx from labels_info (if available)
self._set_scribbles_idx_from_labelinfo(d)

# copy affine meta data from image input
self._copy_affine(d, self.image, self.unary)

Expand Down Expand Up @@ -311,6 +333,89 @@ def __call__(self, data):
#######################
# Optimiser Transforms
#######################
class ApplyGraphCutOptimisationd(InteractiveSegmentationTransform):
"""
Generic GraphCut optimisation transform.

This can be used in conjuction with any Make*Unaryd transform
(e.g. MakeISegUnaryd from above for implementing ISeg unary term).
It optimises a typical energy function for interactive segmentation methods using SimpleCRF's GraphCut method,
e.g. Equation 5 from https://arxiv.org/pdf/1710.04043.pdf.

Usage Example::

Compose(
[
# unary term maker
MakeISegUnaryd(
image="image",
logits="logits",
scribbles="label",
unary="unary",
scribbles_bg_label=2,
scribbles_fg_label=3,
),
# optimiser
ApplyGraphCutOptimisationd(
unary="unary",
pairwise="image",
post_proc_label="pred",
lamda=10.0,
sigma=15.0,
),
]
)
"""

def __init__(
self,
unary: str,
pairwise: str,
meta_key_postfix: str = "meta_dict",
post_proc_label: str = "pred",
lamda: float = 8.0,
sigma: float = 0.1,
) -> None:
super().__init__(meta_key_postfix)
self.unary = unary
self.pairwise = pairwise
self.post_proc_label = post_proc_label
self.lamda = lamda
self.sigma = sigma

def __call__(self, data):
d = dict(data)

# attempt to fetch algorithmic parameters from app if present
self.lamda = d.get("lamda", self.lamda)
self.sigma = d.get("sigma", self.sigma)

# copy affine meta data from pairwise input
self._copy_affine(d, self.pairwise, self.post_proc_label)

# read relevant terms from data
unary_term = self._fetch_data(d, self.unary)
pairwise_term = self._fetch_data(d, self.pairwise)

# check if input unary is compatible with GraphCut opt
if unary_term.shape[0] > 2:
raise ValueError(f"GraphCut can only be applied to binary probabilities, received {unary_term.shape[0]}")

# # attempt to unfold probability term
# unary_term = self._unfold_prob(unary_term, axis=0)

# prepare data for SimpleCRF's GraphCut
unary_term = torch.from_numpy(unary_term).unsqueeze(0)
pairwise_term = torch.from_numpy(pairwise_term).unsqueeze(0)

# run GraphCut
post_proc_label = maxflow(pairwise_term, unary_term, lamda=self.lamda, sigma=self.sigma).squeeze(0).numpy()

d[self.post_proc_label] = post_proc_label

return d


class ApplyCRFOptimisationd(InteractiveSegmentationTransform):
"""
Generic MONAI CRF optimisation transform.
Expand Down
13 changes: 12 additions & 1 deletion monailabel/scribbles/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@
import logging

import numpy as np
from monai.utils import optional_import

# torch import is needed to execute torchmaxflow
optional_import("torch")
import torchmaxflow

logger = logging.getLogger(__name__)

Expand All @@ -19,6 +24,12 @@ def get_eps(data):
return np.finfo(data.dtype).eps


def maxflow(image, prob, lamda=5, sigma=0.1):
# lamda: weight of smoothing term
# sigma: std of intensity values
return torchmaxflow.maxflow(image, prob, lamda, sigma)


def make_iseg_unary(
prob,
scribbles,
Expand Down Expand Up @@ -139,7 +150,7 @@ def make_likelihood_image_histogram(image, scrib, scribbles_bg_label, scribbles_

# generate histograms for background/foreground
bg_hist, fg_hist, bin_edges = make_histograms(
image, scrib, scribbles_bg_label, scribbles_fg_label, alpha_bg=1, alpha_fg=1, bins=32
image, scrib, scribbles_bg_label, scribbles_fg_label, alpha_bg=1, alpha_fg=1, bins=64
)

# lookup values for each voxel for generating background/foreground probabilities
Expand Down
Loading