Refactor and improvements #121

Merged: 14 commits, Jan 6, 2025
1 change: 1 addition & 0 deletions conda-recipe/meta.yaml
@@ -17,6 +17,7 @@ requirements:
build:
- python >=3.9
- pip
- setuptools

run:
- python >=3.9
41 changes: 34 additions & 7 deletions pytorch3dunet/augment/transforms.py
@@ -4,7 +4,7 @@
import numpy as np
import torch
from scipy.ndimage import rotate, map_coordinates, gaussian_filter, convolve
from skimage import measure
from skimage import measure, exposure
from skimage.filters import gaussian
from skimage.segmentation import find_boundaries

@@ -133,6 +133,27 @@ def __call__(self, m):
return m


class RandomGammaCorrection:
"""
Adjust contrast by scaling each voxel to `v ** gamma`.
"""

def __init__(self, random_state, gamma=(0.5, 1.5), execution_probability=0.1, **kwargs):
self.random_state = random_state
assert len(gamma) == 2
self.gamma = gamma
self.execution_probability = execution_probability

def __call__(self, m):
if self.random_state.uniform() < self.execution_probability:
# rescale intensity values to [0, 1]
m = exposure.rescale_intensity(m, out_range=(0, 1))
gamma = self.random_state.uniform(self.gamma[0], self.gamma[1])
return exposure.adjust_gamma(m, gamma)

return m
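For illustration, a minimal standalone sketch of how the new transform behaves when applied to a NumPy volume (the parameter values and explicit RandomState below are made up; in practice the transformer pipeline builds the transform from the config):

import numpy as np

rs = np.random.RandomState(47)
# execution_probability=1.0 so the augmentation always fires in this example
transform = RandomGammaCorrection(random_state=rs, gamma=(0.5, 1.5), execution_probability=1.0)

volume = np.random.rand(64, 200, 200).astype(np.float32)  # DxHxW raw patch
augmented = transform(volume)  # rescaled to [0, 1], then raised to a random power gamma in [0.5, 1.5]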


# it's relatively slow, i.e. ~1s per patch of size 64x200x200, so use multiple workers in the DataLoader
# remember to use spline_order=0 when transforming the labels
class ElasticDeformation:
@@ -576,12 +597,12 @@ def __call__(self, m):
# check if non None in self.min_value/self.max_value
# if present and if so copy value to min_value
if self.min_value is not None:
for i,v in enumerate(self.min_value):
for i, v in enumerate(self.min_value):
if v != 'None':
min_value[i] = v

if self.max_value is not None:
for i,v in enumerate(self.max_value):
for i, v in enumerate(self.max_value):
if v != 'None':
max_value[i] = v
else:
@@ -600,9 +621,9 @@ def __call__(self, m):
norm_0_1 = (m - min_value) / (max_value - min_value + self.eps)

if self.norm01 is True:
return np.clip(norm_0_1, 0, 1)
return np.clip(norm_0_1, 0, 1)
else:
return np.clip(2 * norm_0_1 - 1, -1, 1)
return np.clip(2 * norm_0_1 - 1, -1, 1)
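For reference, the two return branches map the same min-max normalized patch to different output ranges; a tiny worked example with made-up values:

m = np.array([2.0, 5.0, 8.0])          # min_value = 2.0, max_value = 8.0
norm_0_1 = (m - 2.0) / (8.0 - 2.0)     # -> [0.0, 0.5, 1.0] (eps omitted for clarity)
np.clip(norm_0_1, 0, 1)                # norm01=True  -> [0.0, 0.5, 1.0]
np.clip(2 * norm_0_1 - 1, -1, 1)       # norm01=False -> [-1.0, 0.0, 1.0]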


class AdditiveGaussianNoise:
@@ -640,18 +661,24 @@ class ToTensor:
Args:
expand_dims (bool): if True, adds a channel dimension to the input data
dtype (np.dtype): the desired output data type
normalize (bool): zero-one normalization of the input data
"""

def __init__(self, expand_dims, dtype=np.float32, **kwargs):
def __init__(self, expand_dims, dtype=np.float32, normalize=False, **kwargs):
self.expand_dims = expand_dims
self.dtype = dtype
self.normalize = normalize

def __call__(self, m):
assert m.ndim in [3, 4], 'Supports only 3D (DxHxW) or 4D (CxDxHxW) images'
# add channel dimension
if self.expand_dims and m.ndim == 3:
m = np.expand_dims(m, axis=0)

if self.normalize:
# avoid division by zero
m = (m - np.min(m)) / (np.max(m) - np.min(m) + 1e-10)

return torch.from_numpy(m.astype(dtype=self.dtype))
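A short sketch of the new `normalize` flag (input shape and values are illustrative); when enabled, the patch is min-max scaled to [0, 1] before conversion to a tensor:

import numpy as np

to_tensor = ToTensor(expand_dims=True, dtype=np.float32, normalize=True)
raw = np.random.uniform(-500, 3000, size=(32, 64, 64)).astype(np.float32)  # DxHxW
t = to_tensor(raw)  # torch.float32 tensor of shape (1, 32, 64, 64), values in [0, 1]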


@@ -706,7 +733,7 @@ def __call__(self, m):


class GaussianBlur3D:
def __init__(self, sigma=[.1, 2.], execution_probability=0.5, **kwargs):
def __init__(self, sigma=(.1, 2.), execution_probability=0.5, **kwargs):
self.sigma = sigma
self.execution_probability = execution_probability

97 changes: 64 additions & 33 deletions pytorch3dunet/datasets/hdf5.py
@@ -1,12 +1,13 @@
import glob
import os
from abc import abstractmethod
from concurrent.futures.process import ProcessPoolExecutor
from itertools import chain

import h5py

import pytorch3dunet.augment.transforms as transforms
from pytorch3dunet.datasets.utils import get_slice_builder, ConfigDataset, calculate_stats, mirror_pad
from pytorch3dunet.datasets.utils import get_slice_builder, ConfigDataset, calculate_stats, mirror_pad, RandomScaler
from pytorch3dunet.unet3d.utils import get_logger

logger = get_logger('HDF5Dataset')
@@ -44,10 +45,14 @@ class AbstractHDF5Dataset(ConfigDataset):
label_internal_path (str or list): H5 internal path to the label dataset
weight_internal_path (str or list): H5 internal path to the per pixel weights (optional)
global_normalization (bool): if True, the mean and std of the raw data will be calculated over the whole dataset
random_scale (int): if not None, the raw data will be randomly shifted by a value in the range
[-random_scale, random_scale] in each dimension and then scaled to the original patch shape
random_scale_probability (float): probability of executing the random scale on a patch
"""

def __init__(self, file_path, phase, slice_builder_config, transformer_config, raw_internal_path='raw',
label_internal_path='label', weight_internal_path=None, global_normalization=True):
label_internal_path='label', weight_internal_path=None, global_normalization=True,
random_scale=None, random_scale_probability=0.5):
assert phase in ['train', 'val', 'test']

self.phase = phase
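Both new keyword arguments are read from the dataset config in create_datasets below; a hedged sketch of the corresponding config entries (keys taken from the .get() calls, values illustrative):

dataset_config = {
    'raw_internal_path': 'raw',
    'label_internal_path': 'label',
    'random_scale': 10,               # max per-dimension offset in voxels; must be smaller than each stride (asserted below)
    'random_scale_probability': 0.3,  # chance of applying the random rescaling to a given patch
}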
@@ -94,6 +99,10 @@ def __init__(self, file_path, phase, slice_builder_config, transformer_config, r

with h5py.File(file_path, 'r') as f:
raw = f[raw_internal_path]
if raw.ndim == 3:
self.volume_shape = raw.shape
else:
self.volume_shape = raw.shape[1:]
label = f[label_internal_path] if phase != 'test' else None
weight_map = f[weight_internal_path] if weight_internal_path is not None else None
# build slice indices for raw and label data sets
@@ -102,8 +111,18 @@ def __init__(self, file_path, phase, slice_builder_config, transformer_config, r
self.label_slices = slice_builder.label_slices
self.weight_slices = slice_builder.weight_slices

if random_scale is not None:
assert isinstance(random_scale, int), 'random_scale must be an integer'
stride_shape = slice_builder_config.get('stride_shape')
assert all(random_scale < stride for stride in stride_shape), \
f"random_scale {random_scale} must be smaller than each of the strides {stride_shape}"
patch_shape = slice_builder_config.get('patch_shape')
self.random_scaler = RandomScaler(random_scale, patch_shape, self.volume_shape, random_scale_probability)
logger.info(f"Using RandomScaler with offset range {random_scale}")
else:
self.random_scaler = None

self.patch_count = len(self.raw_slices)
logger.info(f'Number of patches: {self.patch_count}')

@abstractmethod
def get_raw_patch(self, idx):
@@ -121,14 +140,6 @@ def get_weight_patch(self, idx):
def get_raw_padded_patch(self, idx):
raise NotImplementedError

def volume_shape(self):
with h5py.File(self.file_path, 'r') as f:
raw = f[self.raw_internal_path]
if raw.ndim == 3:
return raw.shape
else:
return raw.shape[1:]

def __getitem__(self, idx):
if idx >= len(self):
raise StopIteration
@@ -146,15 +157,24 @@ def __getitem__(self, idx):
raw_patch_transformed = self.raw_transform(self.get_raw_padded_patch(raw_idx_padded))
return raw_patch_transformed, raw_idx
else:
raw_patch_transformed = self.raw_transform(self.get_raw_patch(raw_idx))

# get the slice for a given index 'idx'
label_idx = self.label_slices[idx]

if self.random_scaler is not None:
# randomize the indices
raw_idx, label_idx = self.random_scaler.randomize_indices(raw_idx, label_idx)

raw_patch_transformed = self.raw_transform(self.get_raw_patch(raw_idx))
label_patch_transformed = self.label_transform(self.get_label_patch(label_idx))
if self.weight_internal_path is not None:
weight_idx = self.weight_slices[idx]
weight_patch_transformed = self.weight_transform(self.get_weight_patch(weight_idx))
return raw_patch_transformed, label_patch_transformed, weight_patch_transformed

if self.random_scaler is not None:
# scale patches back to the original patch size
raw_patch_transformed, label_patch_transformed = self.random_scaler.rescale_patches(
raw_patch_transformed, label_patch_transformed
)
# return the transformed raw and label patches
return raw_patch_transformed, label_patch_transformed
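RandomScaler itself lives in pytorch3dunet.datasets.utils and is not shown in this diff; the following is only a rough sketch, inferred from the constructor call and the two call sites above (randomize_indices and rescale_patches), of what such a helper could look like. The real implementation may differ, e.g. it may operate on torch tensors rather than NumPy arrays:

import numpy as np
from scipy.ndimage import zoom

class RandomScalerSketch:
    """Hypothetical stand-in for pytorch3dunet.datasets.utils.RandomScaler."""

    def __init__(self, random_scale, patch_shape, volume_shape, execution_probability):
        self.random_scale = random_scale
        self.patch_shape = patch_shape
        self.volume_shape = volume_shape
        self.execution_probability = execution_probability

    def randomize_indices(self, raw_idx, label_idx):
        # with probability `execution_probability`, jitter the start/stop of each spatial
        # slice by up to +/- random_scale voxels, clamped to the volume bounds
        if np.random.uniform() >= self.execution_probability:
            return raw_idx, label_idx
        jittered = []
        for sl, dim in zip(raw_idx[-3:], self.volume_shape):
            lo = np.random.randint(-self.random_scale, self.random_scale + 1)
            hi = np.random.randint(-self.random_scale, self.random_scale + 1)
            start = int(np.clip(sl.start + lo, 0, dim - 1))
            stop = int(np.clip(sl.stop + hi, start + 1, dim))
            jittered.append(slice(start, stop))
        raw_idx = raw_idx[:-3] + tuple(jittered)
        label_idx = label_idx[:-3] + tuple(jittered)
        return raw_idx, label_idx

    def rescale_patches(self, raw_patch, label_patch):
        # interpolate the (possibly resized) patches back to the original patch_shape:
        # linear interpolation for the raw data, nearest-neighbour for the labels
        raw, label = np.asarray(raw_patch), np.asarray(label_patch)
        raw_zoom = [1.0] * (raw.ndim - 3) + [t / s for t, s in zip(self.patch_shape, raw.shape[-3:])]
        label_zoom = [1.0] * (label.ndim - 3) + [t / s for t, s in zip(self.patch_shape, label.shape[-3:])]
        return zoom(raw, raw_zoom, order=1), zoom(label, label_zoom, order=0)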

@@ -192,22 +212,31 @@ def create_datasets(cls, dataset_config, phase):
# are going to be included in the final file_paths
file_paths = traverse_h5_paths(file_paths)

datasets = []
for file_path in file_paths:
try:
# create datasets concurrently
with ProcessPoolExecutor() as executor:
futures = []
for file_path in file_paths:
logger.info(f'Loading {phase} set from: {file_path}...')
dataset = cls(file_path=file_path,
phase=phase,
slice_builder_config=slice_builder_config,
transformer_config=transformer_config,
raw_internal_path=dataset_config.get('raw_internal_path', 'raw'),
label_internal_path=dataset_config.get('label_internal_path', 'label'),
weight_internal_path=dataset_config.get('weight_internal_path', None),
global_normalization=dataset_config.get('global_normalization', None))
datasets.append(dataset)
except Exception:
logger.error(f'Skipping {phase} set: {file_path}', exc_info=True)
return datasets
future = executor.submit(cls, file_path=file_path,
phase=phase,
slice_builder_config=slice_builder_config,
transformer_config=transformer_config,
raw_internal_path=dataset_config.get('raw_internal_path', 'raw'),
label_internal_path=dataset_config.get('label_internal_path', 'label'),
weight_internal_path=dataset_config.get('weight_internal_path', None),
global_normalization=dataset_config.get('global_normalization', None),
random_scale=dataset_config.get('random_scale', None),
random_scale_probability=dataset_config.get('random_scale_probability', 0.5))
futures.append(future)

datasets = []
for future in futures:
try:
dataset = future.result()
datasets.append(dataset)
except Exception as e:
logger.error(f'Failed to load dataset: {e}')
return datasets


class StandardHDF5Dataset(AbstractHDF5Dataset):
@@ -218,11 +247,12 @@ class StandardHDF5Dataset(AbstractHDF5Dataset):

def __init__(self, file_path, phase, slice_builder_config, transformer_config,
raw_internal_path='raw', label_internal_path='label', weight_internal_path=None,
global_normalization=True):
global_normalization=True, random_scale=None, random_scale_probability=0.5):
super().__init__(file_path=file_path, phase=phase, slice_builder_config=slice_builder_config,
transformer_config=transformer_config, raw_internal_path=raw_internal_path,
label_internal_path=label_internal_path, weight_internal_path=weight_internal_path,
global_normalization=global_normalization)
global_normalization=global_normalization, random_scale=random_scale,
random_scale_probability=random_scale_probability)
self._raw = None
self._raw_padded = None
self._label = None
@@ -262,11 +292,12 @@ class LazyHDF5Dataset(AbstractHDF5Dataset):

def __init__(self, file_path, phase, slice_builder_config, transformer_config,
raw_internal_path='raw', label_internal_path='label', weight_internal_path=None,
global_normalization=False):
global_normalization=False, random_scale=None, random_scale_probability=0.5):
super().__init__(file_path=file_path, phase=phase, slice_builder_config=slice_builder_config,
transformer_config=transformer_config, raw_internal_path=raw_internal_path,
label_internal_path=label_internal_path, weight_internal_path=weight_internal_path,
global_normalization=global_normalization)
global_normalization=global_normalization, random_scale=random_scale,
random_scale_probability=random_scale_probability)

logger.info("Using LazyHDF5Dataset")
