add RandomScaler augmentation
wolny committed Jan 5, 2025
1 parent f2fc1fc commit e9454c9
Showing 5 changed files with 169 additions and 29 deletions.
57 changes: 39 additions & 18 deletions pytorch3dunet/datasets/hdf5.py
@@ -7,7 +7,7 @@
import h5py

import pytorch3dunet.augment.transforms as transforms
from pytorch3dunet.datasets.utils import get_slice_builder, ConfigDataset, calculate_stats, mirror_pad
from pytorch3dunet.datasets.utils import get_slice_builder, ConfigDataset, calculate_stats, mirror_pad, RandomScaler
from pytorch3dunet.unet3d.utils import get_logger

logger = get_logger('HDF5Dataset')
@@ -45,10 +45,14 @@ class AbstractHDF5Dataset(ConfigDataset):
label_internal_path (str or list): H5 internal path to the label dataset
weight_internal_path (str or list): H5 internal path to the per pixel weights (optional)
global_normalization (bool): if True, the mean and std of the raw data will be calculated over the whole dataset
random_scale (int): if not None, the patch boundaries will be randomly offset by a value in the range
    [-random_scale, random_scale] in each dimension and the resulting patch rescaled back to the original patch shape
"""

def __init__(self, file_path, phase, slice_builder_config, transformer_config, raw_internal_path='raw',
label_internal_path='label', weight_internal_path=None, global_normalization=True):
label_internal_path='label', weight_internal_path=None, global_normalization=True,
random_scale=None):
assert phase in ['train', 'val', 'test']

self.phase = phase
Expand Down Expand Up @@ -95,6 +99,10 @@ def __init__(self, file_path, phase, slice_builder_config, transformer_config, r

with h5py.File(file_path, 'r') as f:
raw = f[raw_internal_path]
if raw.ndim == 3:
self.volume_shape = raw.shape
else:
self.volume_shape = raw.shape[1:]
label = f[label_internal_path] if phase != 'test' else None
weight_map = f[weight_internal_path] if weight_internal_path is not None else None
# build slice indices for raw and label data sets
@@ -103,6 +111,17 @@ def __init__(self, file_path, phase, slice_builder_config, transformer_config, r
self.label_slices = slice_builder.label_slices
self.weight_slices = slice_builder.weight_slices

if random_scale is not None:
assert isinstance(random_scale, int), 'random_scale must be an integer'
stride_shape = slice_builder_config.get('stride_shape')
assert all(random_scale < stride for stride in stride_shape), \
f"random_scale {random_scale} must be smaller than each of the strides {stride_shape}"
patch_shape = slice_builder_config.get('patch_shape')
self.random_scaler = RandomScaler(random_scale, patch_shape, self.volume_shape)
logger.info(f"Using RandomScaler with offset range {random_scale}")
else:
self.random_scaler = None

self.patch_count = len(self.raw_slices)

@abstractmethod
@@ -121,14 +140,6 @@ def get_weight_patch(self, idx):
def get_raw_padded_patch(self, idx):
raise NotImplementedError

def volume_shape(self):
with h5py.File(self.file_path, 'r') as f:
raw = f[self.raw_internal_path]
if raw.ndim == 3:
return raw.shape
else:
return raw.shape[1:]

def __getitem__(self, idx):
if idx >= len(self):
raise StopIteration
@@ -146,15 +157,24 @@ def __getitem__(self, idx):
raw_patch_transformed = self.raw_transform(self.get_raw_padded_patch(raw_idx_padded))
return raw_patch_transformed, raw_idx
else:
raw_patch_transformed = self.raw_transform(self.get_raw_patch(raw_idx))

# get the slice for a given index 'idx'
label_idx = self.label_slices[idx]

if self.random_scaler is not None:
# randomize the indices
raw_idx, label_idx = self.random_scaler.randomize_indices(raw_idx, label_idx)

raw_patch_transformed = self.raw_transform(self.get_raw_patch(raw_idx))
label_patch_transformed = self.label_transform(self.get_label_patch(label_idx))
if self.weight_internal_path is not None:
weight_idx = self.weight_slices[idx]
weight_patch_transformed = self.weight_transform(self.get_weight_patch(weight_idx))
return raw_patch_transformed, label_patch_transformed, weight_patch_transformed

if self.random_scaler is not None:
# scale patches back to the original patch size
raw_patch_transformed, label_patch_transformed = self.random_scaler.rescale_patches(
raw_patch_transformed, label_patch_transformed
)
# return the transformed raw and label patches
return raw_patch_transformed, label_patch_transformed

@@ -204,7 +224,8 @@ def create_datasets(cls, dataset_config, phase):
raw_internal_path=dataset_config.get('raw_internal_path', 'raw'),
label_internal_path=dataset_config.get('label_internal_path', 'label'),
weight_internal_path=dataset_config.get('weight_internal_path', None),
global_normalization=dataset_config.get('global_normalization', None))
global_normalization=dataset_config.get('global_normalization', None),
random_scale=dataset_config.get('random_scale', None))
futures.append(future)

datasets = []
@@ -225,11 +246,11 @@ class StandardHDF5Dataset(AbstractHDF5Dataset):

def __init__(self, file_path, phase, slice_builder_config, transformer_config,
raw_internal_path='raw', label_internal_path='label', weight_internal_path=None,
global_normalization=True):
global_normalization=True, random_scale=None):
super().__init__(file_path=file_path, phase=phase, slice_builder_config=slice_builder_config,
transformer_config=transformer_config, raw_internal_path=raw_internal_path,
label_internal_path=label_internal_path, weight_internal_path=weight_internal_path,
global_normalization=global_normalization)
global_normalization=global_normalization, random_scale=random_scale)
self._raw = None
self._raw_padded = None
self._label = None
@@ -269,11 +290,11 @@ class LazyHDF5Dataset(AbstractHDF5Dataset):

def __init__(self, file_path, phase, slice_builder_config, transformer_config,
raw_internal_path='raw', label_internal_path='label', weight_internal_path=None,
global_normalization=False):
global_normalization=False, random_scale=None):
super().__init__(file_path=file_path, phase=phase, slice_builder_config=slice_builder_config,
transformer_config=transformer_config, raw_internal_path=raw_internal_path,
label_internal_path=label_internal_path, weight_internal_path=weight_internal_path,
global_normalization=global_normalization)
global_normalization=global_normalization, random_scale=random_scale)

logger.info("Using LazyHDF5Dataset")

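The new random_scale option is wired through AbstractHDF5Dataset.create_datasets above via dataset_config.get('random_scale', None). As a minimal sketch, the corresponding configuration fragment might look as follows; only the keys actually read in this diff are certain, and the 'slice_builder' nesting is an assumption inferred from slice_builder_config.get('patch_shape') / get('stride_shape'):

# hypothetical dataset_config fragment (Python form of the loader config)
dataset_config = {
    'raw_internal_path': 'raw',
    'label_internal_path': 'label',
    'global_normalization': True,
    'random_scale': 20,                 # max boundary offset in voxels; must be smaller than every stride
    'slice_builder': {                  # assumed nesting for slice_builder_config
        'patch_shape': (172, 172, 172),
        'stride_shape': (28, 28, 28),
    },
}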
102 changes: 102 additions & 0 deletions pytorch3dunet/datasets/utils.py
@@ -3,14 +3,116 @@

import numpy as np
import torch
from torch.nn.functional import interpolate
from torch.utils.data import DataLoader, ConcatDataset, Dataset

from pytorch3dunet.unet3d.utils import get_logger, get_class

logger = get_logger('Dataset')


class RandomScaler:
"""
Randomly scales the raw and label patches.
"""

def __init__(self, scale_range: int, patch_shape: tuple, volume_shape: tuple, seed: int = 47):
self.scale_range = scale_range
self.patch_shape = patch_shape
self.volume_shape = volume_shape
self.rs = np.random.RandomState(seed)

def randomize_indices(self, raw_idx: tuple, label_idx: tuple) -> tuple[tuple, tuple]:
# select random offsets for scaling
offsets = [self.rs.randint(self.scale_range) for _ in range(3)]
# change offset sign at random
if self.rs.rand() > 0.5:
offsets = [-o for o in offsets]
# apply offsets to the start or end of the slice at random
is_start = self.rs.rand() > 0.5
raw_idx = self._apply_offsets(raw_idx, offsets, is_start)
label_idx = self._apply_offsets(label_idx, offsets, is_start)
return raw_idx, label_idx

def rescale_patches(self, raw_patch: torch.Tensor, label_patch: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
# compute zoom factors
if raw_patch.ndim == 4:
raw_shape = raw_patch.shape[1:]
else:
raw_shape = raw_patch.shape

# if raw_shape equals self.patch_shape, just return the patches
if raw_shape == self.patch_shape:
return raw_patch, label_patch

# rescale patches back to the original shape
if raw_patch.ndim == 4:
# add batch dimension
raw_patch = raw_patch.unsqueeze(0)
remove_dims = 1
else:
# add batch and channels dimensions
raw_patch = raw_patch.unsqueeze(0).unsqueeze(0)
remove_dims = 2

raw_patch = interpolate(raw_patch, self.patch_shape, mode='trilinear')
# remove additional dimensions
for _ in range(remove_dims):
raw_patch = raw_patch.squeeze(0)

if label_patch.ndim == 4:
label_patch = label_patch.unsqueeze(0)
remove_dims = 1
else:
label_patch = label_patch.unsqueeze(0).unsqueeze(0)
remove_dims = 2

label_dtype = label_patch.dtype
# check if label patch is of torch int type
if label_dtype in [torch.int, torch.int8, torch.int16, torch.int32, torch.int64]:
# convert to float for interpolation
label_patch = label_patch.float()

label_patch = interpolate(label_patch, self.patch_shape, mode='nearest')

# remove additional dimensions
for _ in range(remove_dims):
label_patch = label_patch.squeeze(0)

# convert back to int if necessary
if label_dtype in [torch.int, torch.int8, torch.int16, torch.int32, torch.int64]:
if label_dtype == torch.int64:
label_patch = label_patch.long()
else:
label_patch = label_patch.int()

return raw_patch, label_patch

def _apply_offsets(self, idx: tuple, offsets: list, is_start: bool) -> tuple:
if len(idx) == 4:
offsets = [0] + offsets
volume_shape = (idx[0].stop,) + self.volume_shape
else:
volume_shape = self.volume_shape

new_idx = []
for i, o, s in zip(idx, offsets, volume_shape):
if is_start:
start = max(0, i.start + o)
stop = i.stop
else:
start = i.start
stop = min(s, i.stop + o)

new_idx.append(slice(start, stop))
return tuple(new_idx)


class ConfigDataset(Dataset):
"""
Abstract class for datasets that are configured via a dictionary.
"""

def __getitem__(self, index):
raise NotImplementedError

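Taken in isolation, the RandomScaler added above can be exercised outside of the dataset classes. The sketch below uses dummy tensors (shapes and values are illustrative only): randomize_indices grows or shrinks each slice by the same random offsets, and rescale_patches interpolates the loaded patches back to the requested patch shape (trilinear for raw, nearest for labels).

import torch
from pytorch3dunet.datasets.utils import RandomScaler

patch_shape = (64, 64, 64)
volume_shape = (128, 128, 128)
scaler = RandomScaler(scale_range=10, patch_shape=patch_shape, volume_shape=volume_shape)

# slices as the slice builder would produce them for one patch
raw_idx = tuple(slice(32, 32 + s) for s in patch_shape)
label_idx = raw_idx

raw_idx, label_idx = scaler.randomize_indices(raw_idx, label_idx)
# each randomized slice now spans between 55 and 73 voxels per axis

# stand-ins for the loaded (and transformed) patches of the new, randomized size
new_shape = tuple(s.stop - s.start for s in raw_idx)
raw_patch = torch.randn(new_shape)
label_patch = torch.randint(0, 3, new_shape)

raw_patch, label_patch = scaler.rescale_patches(raw_patch, label_patch)
assert tuple(raw_patch.shape) == patch_shape and tuple(label_patch.shape) == patch_shape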
2 changes: 1 addition & 1 deletion pytorch3dunet/unet3d/predictor.py
@@ -97,7 +97,7 @@ def __call__(self, test_loader):

logger.info(f'Running inference on {len(test_loader)} batches')
# dimensionality of the output predictions
volume_shape = test_loader.dataset.volume_shape()
volume_shape = test_loader.dataset.volume_shape

if self.save_segmentation:
# single channel segmentation map
13 changes: 3 additions & 10 deletions tests/resources/transformer_config.yml
@@ -5,28 +5,20 @@ train:
- name: Standardize
- name: RandomFlip
- name: RandomRotate90
- name: RandomRotate
axes: [[2, 1]]
angle_spectrum: 30
mode: reflect
- name: ElasticDeformation
execution_probability: 1.0
spline_order: 0
- name: ToTensor
expand_dims: true
label:
- name: Standardize
- name: RandomFlip
- name: RandomRotate90
- name: RandomRotate
axes: [[2, 1]]
angle_spectrum: 30
mode: reflect
- name: ElasticDeformation
execution_probability: 1.0
spline_order: 0
- name: ToTensor
expand_dims: true
dtype: int64

test:
transformer:
@@ -36,4 +28,5 @@ test:
expand_dims: true
label:
- name: ToTensor
expand_dims: true
expand_dims: true
dtype: int64
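The dtype: int64 added to the label ToTensor transforms fits with the integer handling in RandomScaler.rescale_patches above: torch.nn.functional.interpolate does not operate on integer tensors, so integer label patches are cast to float, resampled with nearest-neighbour interpolation (which preserves the discrete label values), and cast back afterwards. A minimal stand-alone illustration of that round-trip:

import torch
from torch.nn.functional import interpolate

label = torch.randint(0, 5, (1, 1, 40, 40, 40))        # int64 labels with batch and channel dims
# interpolate() raises on integer tensors, so cast to float for resampling
resized = interpolate(label.float(), size=(48, 48, 48), mode='nearest')
label = resized.long()                                  # cast back; 'nearest' keeps exact label values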
24 changes: 24 additions & 0 deletions tests/test_dataset.py
@@ -144,6 +144,30 @@ def test_halo(self):
input_ = remove_padding(input_, halo_shape)
assert np.allclose(input_[0], raw[indices])

def test_random_scale(self, transformer_config):
path = create_random_dataset((200, 200, 172))

patch_shapes = [(172, 172, 172)]
stride_shapes = [(28, 28, 28)]

phase = 'train'

for patch_shape, stride_shape in zip(patch_shapes, stride_shapes):
dataset = StandardHDF5Dataset(path, phase=phase,
slice_builder_config=_slice_builder_conf(patch_shape, stride_shape),
transformer_config=transformer_config[phase]['transformer'],
raw_internal_path='raw',
label_internal_path='label',
random_scale=20)

for raw, label in dataset:
if raw.ndim == 3:
assert raw.shape == patch_shape
assert label.shape == patch_shape
else:
assert raw.shape[1:] == patch_shape
assert label.shape[1:] == patch_shape


def create_random_dataset(shape, ignore_index=False, raw_datasets=None, label_datasets=None):
if label_datasets is None:
