Skip to content
This repository has been archived by the owner on Jan 12, 2024. It is now read-only.

Extension of PR#105 #112

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
482 changes: 482 additions & 0 deletions notebooks/modelsGenesis_in_dMRI.ipynb

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions rising/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
* Spatial Transforms
* Tensor Transforms
* Utility Transforms
* Painting Transforms
"""

from rising.transforms.abstract import *
Expand All @@ -29,3 +30,4 @@
from rising.transforms.utility import *
from rising.transforms.tensor import *
from rising.transforms.affine import *
from rising.transforms.painting import *
1 change: 1 addition & 0 deletions rising/transforms/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,4 @@
from rising.transforms.functional.tensor import *
from rising.transforms.functional.utility import *
from rising.transforms.functional.channel import *
from rising.transforms.functional.painting import *
40 changes: 39 additions & 1 deletion rising/transforms/functional/intensity.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
from typing import Union, Sequence, Optional

from rising.utils import check_scalar
from rising.utils.torchinterp1d import Interp1d

__all__ = ["norm_range", "norm_min_max", "norm_zero_mean_unit_std", "norm_mean_std",
"add_noise", "add_value", "gamma_correction", "scale_by_value", "clamp"]
"add_noise", "add_value", "gamma_correction", "scale_by_value", "clamp",
"bezier_3rd_order", "random_inversion"]


def clamp(data: torch.Tensor, min: float, max: float,
Expand Down Expand Up @@ -227,3 +229,39 @@ def scale_by_value(data: torch.Tensor, value: float,
torch.Tensor: augmented data
"""
return torch.mul(data, value, out=out)


def bezier_3rd_order(data: torch.Tensor, maxv: float=1.0, minv: float=0.0,
out: Optional[torch.Tensor] = None) -> torch.Tensor:
p0 = torch.zeros((1,2))
p1 = torch.rand((1,2))
p2 = torch.rand((1,2))
p3 = torch.ones((1,2))

t = torch.linspace(0.0, 1.0, 1000).unsqueeze(1)

points = (1-t*t*t)*p0 + 3*(1-t)*(1-t)*t*p1 + 3*(1-t)*t*t*p2 + t*t*t*p3

# scaling according to maxv,minv
points = points*(maxv-minv) + minv

xvals = points[:,0]
yvals = points[:,1]

out_flat = Interp1d()(xvals, yvals, data.view(-1))

return out_flat.view(data.shape)


def random_inversion(data: torch.Tensor, prob_inversion: float=0.5,
maxv: float=1.0, minv: float=0.0,
out: Optional[torch.Tensor] = None) -> torch.Tensor:

if torch.rand((1)) < prob_inversion:
# Inversion of curve
out = maxv + minv - data
else:
# do nothing
out = data

return out
118 changes: 118 additions & 0 deletions rising/transforms/functional/painting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
import torch

__all__ = ["local_pixel_shuffle", "random_inpainting", "random_outpainting"]


def local_pixel_shuffle(data: torch.Tensor, n: int = -1, block_size: tuple = (0, 0, 0), rel_block_size: float = 0.1) -> torch.Tensor:

batch_size, channels, img_rows, img_cols, img_deps = data.size()

if n < 0:
n = int(1000 * channels) # changes ~ 12.5% of voxels
for b in range(batch_size):
for _ in range(n):
c = torch.randint(0, max(1, channels - 1), (1,))

(block_size_x, block_size_y, block_size_z) = (torch.tensor([size]) for size in block_size)

if rel_block_size > 0:
block_size_x = torch.randint(2, max(2, int(img_rows * rel_block_size)) + 1, (1,))
block_size_y = torch.randint(2, max(2, int(img_cols * rel_block_size)) + 1, (1,))
block_size_z = torch.randint(2, max(2, int(img_deps * rel_block_size)) + 1, (1,))

x = torch.randint(0, int(img_rows - block_size_x), (1,))
y = torch.randint(0, int(img_cols - block_size_y), (1,))
z = torch.randint(0, int(img_deps - block_size_z), (1,))

window = data[b, c, x:x + block_size_x,
y:y + block_size_y,
z:z + block_size_z,
]
idx = torch.randperm(window.numel())
window = window.view(-1)[idx].view(window.size())

data[b, c, x:x + block_size_x,
y:y + block_size_y,
z:z + block_size_z] = window

return data


def random_inpainting(data: torch.Tensor, n: int = 5, maxv: float = 1.0, minv: float = 0.0, max_size: tuple = (0, 0, 0), min_size: tuple = (0, 0, 0), rel_max_size: tuple = (0.25, 0.25, 0.25), rel_min_size: tuple = (0.1, 0.1, 0.1), min_border_distance: tuple = (3, 3, 3)) -> torch.Tensor:

batch_size, channels, img_rows, img_cols, img_deps = data.size() # N,C,Z,X,Y

if all((rel_max >= rel_min > 0 for rel_min, rel_max in zip(rel_min_size, rel_max_size))):
min_x = int(rel_min_size[0] * img_rows)
max_x = min(img_rows - 2 * min_border_distance[0] - 1, int(rel_max_size[0] * img_rows))
min_y = int(rel_min_size[1] * img_cols)
max_y = min(img_cols - 2 * min_border_distance[1] - 1, int(rel_max_size[1] * img_cols))
min_z = int(rel_min_size[2] * img_deps)
max_z = min(img_deps - 2 * min_border_distance[2] - 1, int(rel_max_size[2] * img_deps))
elif all((max >= min > 0 for min, max in zip(min_size, max_size))):
min_x, max_x = min_size[0], max_size[0]
min_y, max_y = min_size[1], max_size[1]
min_z, max_z = min_size[2], max_size[2]
else:
raise ValueError(
f'random_inpainting was called with neither a valid absolut nor a valid relative min/max patch size combination. Received absolut min_size {min_size}, max_size {max_size}, and relative rel_min_size {rel_min_size}, rel_max_size {rel_max_size}')

while n > 0 and torch.rand((1)) < 0.95:
for b in range(batch_size):
block_size_x = torch.randint(min_x, max_x + 1, (1,))
block_size_y = torch.randint(min_y, max_y + 1, (1,))
block_size_z = torch.randint(min_z, max_z + 1, (1,))
x = torch.randint(min_border_distance[0], int(img_rows - block_size_x - min_border_distance[0]), (1,))
y = torch.randint(min_border_distance[1], int(img_cols - block_size_y - min_border_distance[1]), (1,))
z = torch.randint(min_border_distance[2], int(img_deps - block_size_z - min_border_distance[2]), (1,))

block = torch.rand((1, channels, block_size_x, block_size_y, block_size_z)) \
* (maxv - minv) + minv

data[b, :, x:x + block_size_x,
y:y + block_size_y,
z:z + block_size_z] = block

n = n - 1

return data


def random_outpainting(data: torch.Tensor, maxv: float = 1.0, minv: float = 0.0, max_size: tuple = (0, 0, 0), min_size: tuple = (0, 0, 0), rel_max_size: tuple = (6 / 7, 6 / 7, 6 / 7), rel_min_size: tuple = (5 / 7, 5 / 7, 5 / 7), min_border_distance=(3, 3, 3)) -> torch.Tensor:

batch_size, channels, img_rows, img_cols, img_deps = data.size()

if all((rel_max >= rel_min > 0 for rel_min, rel_max in zip(rel_min_size, rel_max_size))):
min_x = int(rel_min_size[0] * img_rows)
# min() is necessary to have guarantee y > x for torch.randint(x,y) calls
# lowest possible index for block start is min_border_distance[i], highest possible is img_rows - block_size - min_border_distance[i]. -> block_size < img_rows - 2 * min_border_distance
max_x = min(img_rows - 2 * min_border_distance[0] - 1, int(rel_max_size[0] * img_rows))
min_y = int(rel_min_size[1] * img_cols)
max_y = min(img_cols - 2 * min_border_distance[1] - 1, int(rel_max_size[1] * img_cols))
min_z = int(rel_min_size[2] * img_deps)
max_z = min(img_deps - 2 * min_border_distance[2] - 1, int(rel_max_size[2] * img_deps))

elif all((max >= min > 0 for min, max in zip(min_size, max_size))):
min_x, max_x = min_size[0], max_size[0]
min_y, max_y = min_size[1], max_size[1]
min_z, max_z = min_size[2], max_size[2]
else:
raise ValueError(
f'random_inpainting was called with neither a valid absolut nor a valid relative min/max patch size combination. Received absolut min_size {min_size}, max_size {max_size}, and relative rel_min_size {rel_min_size}, rel_max_size {rel_max_size}')

out = torch.rand(data.size()) * (maxv - minv) + minv

block_size_x = torch.randint(min_x, max_x + 1, (1,))
block_size_y = torch.randint(min_y, max_y + 1, (1,))
block_size_z = torch.randint(min_z, max_z + 1, (1,))
x = torch.randint(min_border_distance[0], int(img_rows - block_size_x - min_border_distance[0]), (1,))
y = torch.randint(min_border_distance[1], int(img_cols - block_size_y - min_border_distance[1]), (1,))
z = torch.randint(min_border_distance[2], int(img_deps - block_size_z - min_border_distance[2]), (1,))

out[:, :, x:x + block_size_x,
y:y + block_size_y,
z:z + block_size_z] = data[:, :, x:x + block_size_x,
y:y + block_size_y,
z:z + block_size_z]

return out
2 changes: 1 addition & 1 deletion rising/transforms/functional/spatial.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def mirror(data: torch.Tensor, dims: Union[int, Sequence[int]]) -> torch.Tensor:
"""
if check_scalar(dims):
dims = (dims,)
# batch and channel dims
# batch and channel dims
dims = [d + 2 for d in dims]
return data.flip(dims)

Expand Down
30 changes: 28 additions & 2 deletions rising/transforms/intensity.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,18 @@
gamma_correction,
add_value,
scale_by_value,
clamp)
clamp,
bezier_3rd_order,
random_inversion,
)

from rising.random import AbstractParameter

__all__ = ["Clamp", "NormRange", "NormMinMax",
"NormZeroMeanUnitStd", "NormMeanStd", "Noise",
"GaussianNoise", "ExponentialNoise", "GammaCorrection",
"RandomValuePerChannel", "RandomAddValue", "RandomScaleValue"]
"RandomValuePerChannel", "RandomAddValue", "RandomScaleValue",
"RandomBezierTransform", "InvertAmplitude"]


class Clamp(BaseTransform):
Expand Down Expand Up @@ -303,3 +308,24 @@ def __init__(self, random_sampler: AbstractParameter,
"""
super().__init__(augment_fn=scale_by_value, random_sampler=random_sampler,
per_channel=per_channel, keys=keys, grad=grad, **kwargs)


class RandomBezierTransform(BaseTransform):
""" Apply a random 3rd order bezier spline to the intensity values,
as proposed in Models Genesis """

def __init__(self, maxv: float = 1.0, minv: float = 0.0, keys: Sequence = ('data',), **kwargs):

super().__init__(augment_fn=bezier_3rd_order, maxv=maxv, minv=minv, keys=keys, grad=False, **kwargs)


class InvertAmplitude(BaseTransform):
""" Inverts the amplitude with probability p according to the following formula:
out = maxv + minv - data
"""

def __init__(self, prob: float = 0.5, maxv: float = 1.0, minv: float = 0.0,
keys: Sequence = ('data',), **kwargs):

super().__init__(augment_fn=random_inversion, prob_inversion=prob, maxv=maxv, minv=minv,
keys=keys, grad=False, **kwargs)
140 changes: 140 additions & 0 deletions rising/transforms/painting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
import torch
from typing import Sequence

from rising.transforms.abstract import AbstractTransform, BaseTransform
from rising.transforms.functional.painting import (
local_pixel_shuffle, random_inpainting, random_outpainting
)


__all__ = ["RandomInpainting", "RandomOutpainting", "RandomInOrOutpainting", "LocalPixelShuffle"]


class LocalPixelShuffle(BaseTransform):
""" Shuffels Pixels locally in n patches,
as proposed in Models Genesis """

def __init__(self, n: int = -1, block_size: tuple = (0, 0, 0), rel_block_size: float = 0.1,
keys: Sequence = ('data',), grad: bool = False, **kwargs):
"""
Args:
n: number of local patches to shuffle, default = 1000*channels
block_size: size of local patches in pixel
rel_block_size: size of local patches in relation to image size, e.g. image_size=(32,192,192) and rel_block_size=0.25 will result in patches of size (8, 48, 48). If rel_block_size > 0, it will overwrite block_size.
keys: the keys corresponding to the values to distort
grad: enable gradient computation inside transformation
**kwargs: keyword arguments passed to augment_fn
"""
super().__init__(augment_fn=local_pixel_shuffle, n=n, block_size=block_size, rel_block_size=rel_block_size,
keys=keys, grad=grad, **kwargs)


class RandomInpainting(BaseTransform):
""" In n local areas, the image is replaced by uniform noise in range (minv, maxv),
as proposed in Models Genesis """

def __init__(self, n: int = 5,
maxv: float = 1.0, minv: float = 0.0,
max_size: tuple = (0, 0, 0), min_size: tuple = (0, 0, 0), rel_max_size: tuple = (0.25, 0.25, 0.25), rel_min_size: tuple = (0.1, 0.1, 0.1), min_border_distance: tuple = (3, 3, 3),
keys: Sequence = ('data',), grad: bool = False, **kwargs):
"""
Args:
minv, maxv: range of uniform noise
n: number of local patches to randomize
max_size: absolute upper bound for the patch size
min_size: absolute lower bound for the patch size
rel_max_size: relative upper bound for the patch size, relative to image_size. Overwrites max_size.
rel_min_size: relative lower bound for the patch size, relative to image_size. Overwrites min_size.
min_border_distance: the minimum distance of patches to the border in pixel for each dimension.
keys: the keys corresponding to the values to distort
grad: enable gradient computation inside transformation
**kwargs: keyword arguments passed to augment_fn
"""
super().__init__(augment_fn=random_inpainting, n=n, maxv=maxv, minv=minv, max_size=max_size, min_size=min_size, rel_max_size=rel_max_size, rel_min_size=rel_min_size,
keys=keys, grad=grad, **kwargs)


class RandomOutpainting(AbstractTransform):
""" The border of the images will be replaced by uniform noise,
as proposed in Models Genesis. (Replaces a patch in an equally sized noise image with the corresponding input image content) """

def __init__(self, prob: float = 0.5, maxv: float = 1.0, minv: float = 0.0,
max_size: tuple = (0, 0, 0), min_size: tuple = (0, 0, 0),
rel_max_size: tuple = (6 / 7, 6 / 7, 6 / 7), rel_min_size: tuple = (5 / 7, 5 / 7, 5 / 7), min_border_distance: tuple = (3, 3, 3),
keys: Sequence = ('data',), grad: bool = False, **kwargs):
"""
Args:
minv, maxv: range of uniform noise
prob: probability of outpainting. For prob<1.0, not all images will be augmented
max_size: absolute upper bound for the patch size. Here the patch is the remaining image
min_size: absolute lower bound for the patch size. Here the patch is the remaining image
rel_max_size: relative upper bound for the patch size, relative to image_size. Overwrites max_size.
rel_min_size: relative lower bound for the patch size, relative to image_size. Overwrites min_size.
min_border_distance: the minimum thickness of the border in pixel for each dimension.
keys: the keys corresponding to the values to distort
grad: enable gradient computation inside transformation
**kwargs: keyword arguments passed to augment_fn
"""
super().__init__(grad=grad, **kwargs)
self.prob = prob
self.maxv = maxv
self.minv = minv
self.keys = keys
self.max_size = max_size
self.min_size = min_size
self.rel_min_size = rel_min_size
self.rel_max_size = rel_max_size
self.min_border_distance = min_border_distance

def forward(self, **data) -> dict:
if torch.rand(1) < self.prob:
for key in self.keys:
data[key] = random_outpainting(data[key], maxv=self.maxv, minv=self.minv, max_size=self.max_size, min_size=self.min_size, rel_max_size=self.rel_max_size, rel_min_size=self.rel_min_size,
min_border_distance=self.min_border_distance)
return data


class RandomInOrOutpainting(AbstractTransform):
"""Applies either random inpainting or random outpainting to the image,
as proposed in Models Genesis """

def __init__(self, prob: float = 0.5, n: int = 5,
maxv: float = 1.0, minv: float = 0.0,
max_size_in: tuple = (0, 0, 0), min_size_in: tuple = (0, 0, 0), rel_max_size_in: tuple = (0.25, 0.25, 0.25), rel_min_size_in: tuple = (0.1, 0.1, 0.1),
max_size_out: tuple = (0, 0, 0), min_size_out: tuple = (0, 0, 0),
rel_max_size_out: tuple = (6 / 7, 6 / 7, 6 / 7), rel_min_size_out: tuple = (5 / 7, 5 / 7, 5 / 7),
keys: Sequence = ('data',), grad: bool = False, **kwargs):
"""
Args:
minv, maxv: range of uniform noise
prob: probability of outpainting, probability of inpainting is 1-prob.
n: number of local patches to randomize in case of inpainting
keys: the keys corresponding to the values to distort
grad: enable gradient computation inside transformation
**kwargs: keyword arguments passed to augment_fn
"""
super().__init__(grad=grad, **kwargs)
self.prob = prob
self.maxv = maxv
self.minv = minv
self.keys = keys
self.n = n
self.max_size_in = max_size_in
self.min_size_in = min_size_in
self.rel_min_size_in = rel_min_size_in
self.rel_max_size_in = rel_max_size_in
self.max_size_out = max_size_out
self.min_size_out = min_size_out
self.rel_min_size_out = rel_min_size_out
self.rel_max_size_out = rel_max_size_out

def forward(self, **data) -> dict:
if torch.rand(1) < self.prob:
for key in self.keys:
data[key] = random_outpainting(data[key], maxv=self.maxv, minv=self.minv, max_size=self.max_size_out,
min_size=self.min_size_out, rel_max_size=self.rel_max_size_out, rel_min_size=self.rel_min_size_out)
else:
for key in self.keys:
data[key] = random_inpainting(data[key], n=self.n, maxv=self.maxv, minv=self.minv, max_size=self.max_size_in,
min_size=self.min_size_in, rel_max_size=self.rel_max_size_in, rel_min_size=self.rel_min_size_in)
return data
Loading