diff --git a/pytorch3dunet/datasets/hdf5.py b/pytorch3dunet/datasets/hdf5.py
index 995e96ba..ccc8ca8f 100644
--- a/pytorch3dunet/datasets/hdf5.py
+++ b/pytorch3dunet/datasets/hdf5.py
@@ -7,7 +7,7 @@
 import h5py
 
 import pytorch3dunet.augment.transforms as transforms
-from pytorch3dunet.datasets.utils import get_slice_builder, ConfigDataset, calculate_stats, mirror_pad
+from pytorch3dunet.datasets.utils import get_slice_builder, ConfigDataset, calculate_stats, mirror_pad, RandomScaler
 from pytorch3dunet.unet3d.utils import get_logger
 
 logger = get_logger('HDF5Dataset')
@@ -45,10 +45,14 @@ class AbstractHDF5Dataset(ConfigDataset):
         label_internal_path (str or list): H5 internal path to the label dataset
         weight_internal_path (str or list): H5 internal path to the per pixel weights (optional)
         global_normalization (bool): if True, the mean and std of the raw data will be calculated over the whole dataset
+        random_scale (int): if not None, the sampled patch is randomly grown or shrunk by up to `random_scale`
+            voxels along each spatial dimension and then rescaled back to the original patch shape
+
     """
 
     def __init__(self, file_path, phase, slice_builder_config, transformer_config, raw_internal_path='raw',
-                 label_internal_path='label', weight_internal_path=None, global_normalization=True):
+                 label_internal_path='label', weight_internal_path=None, global_normalization=True,
+                 random_scale=None):
         assert phase in ['train', 'val', 'test']
         self.phase = phase
@@ -95,6 +99,10 @@ def __init__(self, file_path, phase, slice_builder_config, transformer_config, raw_internal_path='raw',
 
         with h5py.File(file_path, 'r') as f:
             raw = f[raw_internal_path]
+            if raw.ndim == 3:
+                self.volume_shape = raw.shape
+            else:
+                self.volume_shape = raw.shape[1:]
             label = f[label_internal_path] if phase != 'test' else None
             weight_map = f[weight_internal_path] if weight_internal_path is not None else None
             # build slice indices for raw and label data sets
@@ -103,6 +111,17 @@ def __init__(self, file_path, phase, slice_builder_config, transformer_config, raw_internal_path='raw',
         self.label_slices = slice_builder.label_slices
         self.weight_slices = slice_builder.weight_slices
 
+        if random_scale is not None:
+            assert isinstance(random_scale, int), 'random_scale must be an integer'
+            stride_shape = slice_builder_config.get('stride_shape')
+            assert all(random_scale < stride for stride in stride_shape), \
+                f"random_scale {random_scale} must be smaller than each of the strides {stride_shape}"
+            patch_shape = slice_builder_config.get('patch_shape')
+            self.random_scaler = RandomScaler(random_scale, patch_shape, self.volume_shape)
+            logger.info(f"Using RandomScaler with offset range {random_scale}")
+        else:
+            self.random_scaler = None
+
         self.patch_count = len(self.raw_slices)
 
     @abstractmethod
@@ -121,14 +140,6 @@ def get_weight_patch(self, idx):
         raise NotImplementedError
 
     def get_raw_padded_patch(self, idx):
         raise NotImplementedError
 
-    def volume_shape(self):
-        with h5py.File(self.file_path, 'r') as f:
-            raw = f[self.raw_internal_path]
-            if raw.ndim == 3:
-                return raw.shape
-            else:
-                return raw.shape[1:]
-
     def __getitem__(self, idx):
         if idx >= len(self):
             raise StopIteration
@@ -146,15 +157,27 @@ def __getitem__(self, idx):
             raw_patch_transformed = self.raw_transform(self.get_raw_padded_patch(raw_idx_padded))
             return raw_patch_transformed, raw_idx
         else:
-            raw_patch_transformed = self.raw_transform(self.get_raw_patch(raw_idx))
-
-            # get the slice for a given index 'idx'
             label_idx = self.label_slices[idx]
+
+            if self.random_scaler is not None:
+                # randomize the indices
+                raw_idx, label_idx = self.random_scaler.randomize_indices(raw_idx, label_idx)
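+                # e.g. with random_scale=20 an offset of 13 may turn slice(0, 172)
+                # into slice(13, 172) or slice(0, 159); rescale_patches() below
+                # interpolates the patch back to the configured patch_shape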
+
+            raw_patch_transformed = self.raw_transform(self.get_raw_patch(raw_idx))
             label_patch_transformed = self.label_transform(self.get_label_patch(label_idx))
+
+            if self.random_scaler is not None:
+                # scale patches back to the original patch size before the weight-map
+                # branch below, so all returned patches share the same shape
+                raw_patch_transformed, label_patch_transformed = self.random_scaler.rescale_patches(
+                    raw_patch_transformed, label_patch_transformed
+                )
             if self.weight_internal_path is not None:
                 weight_idx = self.weight_slices[idx]
                 weight_patch_transformed = self.weight_transform(self.get_weight_patch(weight_idx))
                 return raw_patch_transformed, label_patch_transformed, weight_patch_transformed
             # return the transformed raw and label patches
             return raw_patch_transformed, label_patch_transformed
 
@@ -204,7 +224,8 @@ def create_datasets(cls, dataset_config, phase):
                                   raw_internal_path=dataset_config.get('raw_internal_path', 'raw'),
                                   label_internal_path=dataset_config.get('label_internal_path', 'label'),
                                   weight_internal_path=dataset_config.get('weight_internal_path', None),
-                                  global_normalization=dataset_config.get('global_normalization', None))
+                                  global_normalization=dataset_config.get('global_normalization', None),
+                                  random_scale=dataset_config.get('random_scale', None))
             futures.append(future)
 
         datasets = []
@@ -225,11 +246,11 @@ class StandardHDF5Dataset(AbstractHDF5Dataset):
 
     def __init__(self, file_path, phase, slice_builder_config, transformer_config,
                  raw_internal_path='raw', label_internal_path='label', weight_internal_path=None,
-                 global_normalization=True):
+                 global_normalization=True, random_scale=None):
         super().__init__(file_path=file_path, phase=phase, slice_builder_config=slice_builder_config,
                          transformer_config=transformer_config, raw_internal_path=raw_internal_path,
                          label_internal_path=label_internal_path, weight_internal_path=weight_internal_path,
-                         global_normalization=global_normalization)
+                         global_normalization=global_normalization, random_scale=random_scale)
         self._raw = None
         self._raw_padded = None
         self._label = None
@@ -269,11 +290,11 @@ class LazyHDF5Dataset(AbstractHDF5Dataset):
 
     def __init__(self, file_path, phase, slice_builder_config, transformer_config,
                  raw_internal_path='raw', label_internal_path='label', weight_internal_path=None,
-                 global_normalization=False):
+                 global_normalization=False, random_scale=None):
         super().__init__(file_path=file_path, phase=phase, slice_builder_config=slice_builder_config,
                          transformer_config=transformer_config, raw_internal_path=raw_internal_path,
                          label_internal_path=label_internal_path, weight_internal_path=weight_internal_path,
-                         global_normalization=global_normalization)
+                         global_normalization=global_normalization, random_scale=random_scale)
 
         logger.info("Using LazyHDF5Dataset")
diff --git a/pytorch3dunet/datasets/utils.py b/pytorch3dunet/datasets/utils.py
index ad6ddc30..addbf646 100644
--- a/pytorch3dunet/datasets/utils.py
+++ b/pytorch3dunet/datasets/utils.py
@@ -3,6 +3,7 @@
 import numpy as np
 import torch
+from torch.nn.functional import interpolate
 from torch.utils.data import DataLoader, ConcatDataset, Dataset
 
 from pytorch3dunet.unet3d.utils import get_logger, get_class
 
@@ -10,7 +11,119 @@
 logger = get_logger('Dataset')
 
 
+class RandomScaler:
+    """
+    Randomly scales the raw and label patches.
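+
+    Given matching raw and label slice indices, ``randomize_indices`` grows or shrinks
+    the patch by a random per-dimension offset drawn from ``[0, scale_range)`` (with a
+    random global sign, applied to either the start or the end of each slice and
+    clipped to the volume bounds), and ``rescale_patches`` interpolates the resulting
+    patches back to ``patch_shape`` (trilinear for the raw patch, nearest-neighbor
+    for the label patch).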
+ """ + + def __init__(self, scale_range: int, patch_shape: tuple, volume_shape: tuple, seed: int = 47): + self.scale_range = scale_range + self.patch_shape = patch_shape + self.volume_shape = volume_shape + self.rs = np.random.RandomState(seed) + + def randomize_indices(self, raw_idx: tuple, label_idx: tuple) -> tuple[tuple, tuple]: + # select random offsets for scaling + offsets = [self.rs.randint(self.scale_range) for _ in range(3)] + # change offset sign at random + if self.rs.rand() > 0.5: + offsets = [-o for o in offsets] + # apply offsets to the start or end of the slice at random + is_start = self.rs.rand() > 0.5 + raw_idx = self._apply_offsets(raw_idx, offsets, is_start) + label_idx = self._apply_offsets(label_idx, offsets, is_start) + return raw_idx, label_idx + + def rescale_patches(self, raw_patch: torch.Tensor, label_patch: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + # compute zoom factors + if raw_patch.ndim == 4: + raw_shape = raw_patch.shape[1:] + else: + raw_shape = raw_patch.shape + + # if raw_shape equal to self.patch_shape just return the patches + if raw_shape == self.patch_shape: + return raw_patch, label_patch + + # rescale patches back to the original shape + if raw_patch.ndim == 4: + # add batch dimension + raw_patch = raw_patch.unsqueeze(0) + remove_dims = 1 + else: + # add batch and channels dimensions + raw_patch = raw_patch.unsqueeze(0).unsqueeze(0) + remove_dims = 2 + + raw_patch = interpolate(raw_patch, self.patch_shape, mode='trilinear') + # remove additional dimensions + for _ in range(remove_dims): + raw_patch = raw_patch.squeeze(0) + + if label_patch.ndim == 4: + label_patch = label_patch.unsqueeze(0) + remove_dims = 1 + else: + label_patch = label_patch.unsqueeze(0).unsqueeze(0) + remove_dims = 2 + + label_dtype = label_patch.dtype + # check if label patch is of torch int type + if label_dtype in [torch.int, torch.int8, torch.int16, torch.int32, torch.int64]: + # convert to float for interpolation + label_patch = label_patch.float() + + label_patch = interpolate(label_patch, self.patch_shape, mode='nearest') + + # remove additional dimensions + for _ in range(remove_dims): + label_patch = label_patch.squeeze(0) + + # conver back to int if necessary + if label_dtype in [torch.int, torch.int8, torch.int16, torch.int32, torch.int64]: + if label_dtype == torch.int64: + label_patch = label_patch.long() + else: + label_patch = label_patch.int() + + return raw_patch, label_patch + + def _apply_offsets(self, idx: tuple, offsets: list, is_start: bool) -> tuple: + if len(idx) == 4: + offsets = [0] + offsets + volume_shape = (idx[0].stop,) + self.volume_shape + else: + volume_shape = self.volume_shape + + new_idx = [] + for i, o, s in zip(idx, offsets, volume_shape): + if is_start: + start = max(0, i.start + o) + stop = i.stop + else: + start = i.start + stop = min(s, i.stop + o) + + new_idx.append(slice(start, stop)) + return tuple(new_idx) + + class ConfigDataset(Dataset): + """ + Abstract class for datasets that are configured via a dictionary. 
+ """ + def __getitem__(self, index): raise NotImplementedError diff --git a/pytorch3dunet/unet3d/predictor.py b/pytorch3dunet/unet3d/predictor.py index 710d4ca9..d2cb2025 100644 --- a/pytorch3dunet/unet3d/predictor.py +++ b/pytorch3dunet/unet3d/predictor.py @@ -97,7 +97,7 @@ def __call__(self, test_loader): logger.info(f'Running inference on {len(test_loader)} batches') # dimensionality of the output predictions - volume_shape = test_loader.dataset.volume_shape() + volume_shape = test_loader.dataset.volume_shape if self.save_segmentation: # single channel segmentation map diff --git a/tests/resources/transformer_config.yml b/tests/resources/transformer_config.yml index f48c6e6d..86c7b919 100644 --- a/tests/resources/transformer_config.yml +++ b/tests/resources/transformer_config.yml @@ -5,28 +5,20 @@ train: - name: Standardize - name: RandomFlip - name: RandomRotate90 - - name: RandomRotate - axes: [[2, 1]] - angle_spectrum: 30 - mode: reflect - name: ElasticDeformation execution_probability: 1.0 spline_order: 0 - name: ToTensor expand_dims: true label: - - name: Standardize - name: RandomFlip - name: RandomRotate90 - - name: RandomRotate - axes: [[2, 1]] - angle_spectrum: 30 - mode: reflect - name: ElasticDeformation execution_probability: 1.0 spline_order: 0 - name: ToTensor expand_dims: true + dtype: int64 test: transformer: @@ -36,4 +28,5 @@ test: expand_dims: true label: - name: ToTensor - expand_dims: true \ No newline at end of file + expand_dims: true + dtype: int64 \ No newline at end of file diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 14ffbd73..e0107771 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -144,6 +144,30 @@ def test_halo(self): input_ = remove_padding(input_, halo_shape) assert np.allclose(input_[0], raw[indices]) + def test_random_scale(self, transformer_config): + path = create_random_dataset((200, 200, 172)) + + patch_shapes = [(172, 172, 172)] + stride_shapes = [(28, 28, 28)] + + phase = 'train' + + for patch_shape, stride_shape in zip(patch_shapes, stride_shapes): + dataset = StandardHDF5Dataset(path, phase=phase, + slice_builder_config=_slice_builder_conf(patch_shape, stride_shape), + transformer_config=transformer_config[phase]['transformer'], + raw_internal_path='raw', + label_internal_path='label', + random_scale=20) + + for raw, label in dataset: + if raw.ndim == 3: + assert raw.shape == patch_shape + assert label.shape == patch_shape + else: + assert raw.shape[1:] == patch_shape + assert label.shape[1:] == patch_shape + def create_random_dataset(shape, ignore_index=False, raw_datasets=None, label_datasets=None): if label_datasets is None: