Commit

add example for PixelwiseCrossEntropyLoss training
wolny committed Jan 3, 2024
1 parent 494677d commit c2c48b8
Showing 3 changed files with 216 additions and 4 deletions.
8 changes: 4 additions & 4 deletions README.md
@@ -111,16 +111,16 @@ CUDA_VISIBLE_DEVICES=0,1 predict3dunet --config <CONFIG>
- `DiceLoss` (standard `DiceLoss` defined as `1 - DiceCoefficient` used for binary semantic segmentation; when more than 2 classes are present in the ground truth, it computes the `DiceLoss` per channel and averages the values)
- `BCEDiceLoss` (Linear combination of BCE and Dice losses, i.e. `alpha * BCE + beta * Dice`, `alpha, beta` can be specified in the `loss` section of the config)
- `CrossEntropyLoss` (one can specify class weights via the `weight: [w_1, ..., w_k]` in the `loss` section of the config)
- `PixelWiseCrossEntropyLoss` (one can specify per-pixel weights in order to give more gradient to the important/under-represented regions in the ground truth; `weight` dataset has to be provided in the H5 files for training and validatin)
- `PixelWiseCrossEntropyLoss` (one can specify per-pixel weights in order to give more gradient to the important/under-represented regions in the ground truth; `weight` dataset has to be provided in the H5 files for training and validation; see sample config in [train_config.yml](resources/3DUnet_confocal_boundary_weighted/train_config.yml))
- `WeightedCrossEntropyLoss` (see 'Weighted cross-entropy (WCE)' in the below paper for a detailed explanation)
- `GeneralizedDiceLoss` (see 'Generalized Dice Loss (GDL)' in the below paper for a detailed explanation) Note: use this loss function only if the labels in the training dataset are very imbalanced e.g. one class having at least 3 orders of magnitude more voxels than the others. Otherwise use standard `DiceLoss`.
- `GeneralizedDiceLoss` (see 'Generalized Dice Loss (GDL)' in the below paper for a detailed explanation) Note: use this loss function only if the labels in the training dataset are very imbalanced e.g. one class having at least 3 orders of magnitude more voxels than the others. Otherwise, use standard `DiceLoss`.

For a detailed explanation of some of the supported loss functions see:
[Generalised Dice overlap as a deep learning loss function for highly unbalanced segmentations](https://arxiv.org/pdf/1707.03237.pdf).
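To make the per-pixel weighting described for `PixelWiseCrossEntropyLoss` concrete, here is a minimal pure-Python sketch of a weighted cross-entropy. It is not the library's implementation: the function name `pixelwise_cross_entropy`, the flat list layout, and the plain mean over pixels are assumptions made for illustration, and the library's own normalization may differ.

```python
import math


def pixelwise_cross_entropy(logits, targets, weights):
    """Weighted cross-entropy averaged over pixels (illustrative sketch).

    logits:  list of per-pixel lists, one raw score per class
    targets: list of integer class indices, one per pixel
    weights: list of non-negative per-pixel weights
    """
    total = 0.0
    for scores, target, weight in zip(logits, targets, weights):
        # numerically stable log-softmax: subtract the max before exponentiating
        m = max(scores)
        log_norm = m + math.log(sum(math.exp(s - m) for s in scores))
        log_prob = scores[target] - log_norm
        # a larger weight makes this pixel contribute more to the loss (and gradient)
        total += -weight * log_prob
    return total / len(targets)


if __name__ == "__main__":
    logits = [[0.0, 0.0], [0.0, 0.0]]  # two pixels, two classes, uninformative scores
    targets = [0, 1]
    print(pixelwise_cross_entropy(logits, targets, [1.0, 1.0]))  # uniform weights
    print(pixelwise_cross_entropy(logits, targets, [1.0, 3.0]))  # second pixel emphasized
```

With uniform weights this reduces to the ordinary mean cross-entropy; raising the weight of a pixel scales its term, which is how faint or under-represented regions can be emphasized.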

### Regression
- `MSELoss` (mean squared error loss)
- `L1Loss` (mean absolute errro loss)
- `L1Loss` (mean absolute error loss)
- `SmoothL1Loss` (less sensitive to outliers than MSELoss)
- `WeightedSmoothL1Loss` (extension of the `SmoothL1Loss` which allows weighting the voxel values above/below a given threshold differently)

@@ -131,7 +131,7 @@ For a detailed explanation of some of the supported loss functions see:
- `MeanIoU` (mean intersection over union)
- `DiceCoefficient` (computes per channel Dice Coefficient and returns the average)
If a 3D U-Net was trained to predict cell boundaries, one can use the following semantic instance segmentation metrics
(the metrics below are computed by running connected components on thresholded boundary map and comparing the resulted instances to the ground truth instance segmentation):
(the metrics below are computed by running connected components on the thresholded boundary map and comparing the resulting instances to the ground truth instance segmentation):
- `BoundaryAveragePrecision` (Average Precision applied to the boundary probability maps: thresholds the output from the network, runs connected components to get the segmentation and computes AP between the resulting segmentation and the ground truth)
- `AdaptedRandError` (see http://brainiac2.mit.edu/SNEMI3D/evaluation for a detailed explanation)
- `AveragePrecision` (see https://www.kaggle.com/stkbailey/step-by-step-explanation-of-scoring-metric)
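The "threshold, then connected components" step used by the boundary-based metrics above can be sketched in a few lines of pure Python. This is only an illustration under stated assumptions: it works on a 2D list rather than a 3D volume, uses 4-connectivity, and the helper name `segment_from_boundaries` is hypothetical. The default threshold of 0.4 mirrors the `eval_metric` setting in the sample training config.

```python
from collections import deque


def segment_from_boundaries(prob, threshold=0.4):
    """Label connected regions whose boundary probability is below `threshold`.

    prob: 2D list of boundary probabilities in [0, 1]
    Returns (labels, num_instances); boundary pixels keep label 0.
    """
    height, width = len(prob), len(prob[0])
    labels = [[0] * width for _ in range(height)]
    num_instances = 0
    for y in range(height):
        for x in range(width):
            # skip boundary pixels and pixels that already belong to an instance
            if prob[y][x] >= threshold or labels[y][x] != 0:
                continue
            num_instances += 1
            labels[y][x] = num_instances
            queue = deque([(y, x)])
            # breadth-first flood fill over the 4-connected neighbourhood
            while queue:
                cy, cx = queue.popleft()
                for ny, nx in ((cy - 1, cx), (cy + 1, cx), (cy, cx - 1), (cy, cx + 1)):
                    inside = 0 <= ny < height and 0 <= nx < width
                    if inside and labels[ny][nx] == 0 and prob[ny][nx] < threshold:
                        labels[ny][nx] = num_instances
                        queue.append((ny, nx))
    return labels, num_instances


if __name__ == "__main__":
    boundary_map = [
        [0.1, 0.9, 0.1],
        [0.1, 0.9, 0.1],
    ]
    labels, count = segment_from_boundaries(boundary_map)
    print(count, labels)
```

The resulting label image is what would then be compared against the ground truth instance segmentation by a metric such as the adapted Rand error.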
40 changes: 40 additions & 0 deletions resources/3DUnet_confocal_boundary_weighted/test_config.yml
@@ -0,0 +1,40 @@
# Download test data from: https://osf.io/8jz7e/
model_path: PATH_TO_BEST_CHECKPOINT
model:
  name: UNet3D
  # number of input channels to the model
  in_channels: 1
  # number of output channels (foreground and background will be predicted in separate channels)
  out_channels: 2
  # determines the order of operators in a single layer (gcr - GroupNorm+Conv3d+ReLU)
  layer_order: gcr
  # initial number of feature maps
  f_maps: 32
  # number of groups in the groupnorm
  num_groups: 8
  # apply element-wise nn.Sigmoid after the final 1x1x1 convolution, otherwise apply nn.Softmax
  final_sigmoid: true
predictor:
  name: 'StandardPredictor'
loaders:
  # save predictions to output_dir
  output_dir: PATH_TO_OUTPUT_DIR
  # batch dimension; if number of GPUs is N > 1, then a batch_size of N * batch_size will automatically be taken for DataParallel
  batch_size: 1
  # how many subprocesses to use for data loading
  num_workers: 8
  # test loaders configuration
  test:
    file_paths:
      - PATH_TO_TEST_DIR

    slice_builder:
      name: SliceBuilder
      patch_shape: [80, 170, 170]
      stride_shape: [40, 90, 90]

    transformer:
      raw:
        - name: Standardize
        - name: ToTensor
          expand_dims: true
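The `patch_shape` and `stride_shape` settings of the `slice_builder` describe a sliding window over the input volume. The sketch below shows one way such a window can be enumerated; it is not the library's `SliceBuilder`. The helper names, the example volume shape, and the choice to add a final patch aligned to the end of each axis so the whole volume is covered are all assumptions made for illustration.

```python
from itertools import product


def patch_starts(size, patch, stride):
    """Start offsets of patches along one axis."""
    if patch > size:
        raise ValueError("patch must not be larger than the volume along any axis")
    starts = list(range(0, size - patch + 1, stride))
    # assumption: add one extra patch flush with the end if the stride left a gap
    if starts[-1] + patch < size:
        starts.append(size - patch)
    return starts


def patch_slices(volume_shape, patch_shape, stride_shape):
    """All patch positions as tuples of slices, one slice per axis (z, y, x)."""
    per_axis = [
        patch_starts(size, patch, stride)
        for size, patch, stride in zip(volume_shape, patch_shape, stride_shape)
    ]
    return [
        tuple(slice(start, start + patch) for start, patch in zip(offsets, patch_shape))
        for offsets in product(*per_axis)
    ]


if __name__ == "__main__":
    # hypothetical volume shape; patch and stride values taken from the config above
    slices = patch_slices((100, 300, 300), (80, 170, 170), (40, 90, 90))
    print(len(slices))
    print(slices[0])
```

Because the stride is smaller than the patch, neighbouring patches overlap, which is the usual way to reduce border artifacts when predictions are stitched back together.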
172 changes: 172 additions & 0 deletions resources/3DUnet_confocal_boundary_weighted/train_config.yml
@@ -0,0 +1,172 @@
# Sample configuration file for training a 3D U-Net on a task of predicting the boundaries in 3D stack of the Arabidopsis
# ovules acquired with the confocal microscope.
# Training done using a PixelWiseCrossEntropyLoss with a weight map that focuses on faint boundaries.
model:
  name: UNet3D
  # number of input channels to the model
  in_channels: 1
  # number of output channels (since cross-entropy loss is used, foreground and background classes are represented in separate channels)
  out_channels: 2
  # determines the order of operators in a single layer (gcr - GroupNorm+Conv3d+ReLU)
  layer_order: gcr
  # initial number of feature maps
  f_maps: 32
  # number of groups in the groupnorm
  num_groups: 8
  # apply element-wise nn.Sigmoid after the final 1x1x1 convolution, otherwise apply nn.Softmax
  final_sigmoid: false
# loss function to be used during training
loss:
  name: PixelWiseCrossEntropyLoss
  # skip the last channel in the target (i.e. when last channel contains data not relevant for the loss)
  skip_last_target: true
  # squeeze the channel dimension in the target
  squeeze_channel: true
  # a target value that is ignored and does not contribute to the input gradient
  ignore_index: null
optimizer:
  # initial learning rate
  learning_rate: 0.0002
  # weight decay
  weight_decay: 0.00001
# evaluation metric
eval_metric:
  # use AdaptedRandError metric
  name: BoundaryAdaptedRandError
  # probability maps threshold
  threshold: 0.4
  # use the last target channel to compute the metric
  use_last_target: true
  # use only the first channel for computing the metric
  use_first_input: true
lr_scheduler:
  name: ReduceLROnPlateau
  # make sure to use the 'min' mode because lower AdaptedRandError is better
  mode: min
  factor: 0.5
  patience: 30
trainer:
  # model with lower eval score is considered better
  eval_score_higher_is_better: False
  # path to the checkpoint directory
  checkpoint_dir: CHECKPOINT_DIR
  # path to latest checkpoint; if provided the training will be resumed from that checkpoint
  resume: null
  # path to the best_checkpoint.pytorch; to be used for fine-tuning the model with additional ground truth
  pre_trained: null
  # how many iterations between validations
  validate_after_iters: 1000
  # how many iterations between tensorboard logging
  log_after_iters: 500
  # max number of epochs
  max_num_epochs: 1000
  # max number of iterations
  max_num_iterations: 150000
# Configure training and validation loaders
loaders:
  # how many subprocesses to use for data loading
  num_workers: 8
  # path to the raw data within the H5
  raw_internal_path: /raw
  # path to the label data within the H5
  label_internal_path: /label
  # path to the weight data within the H5
  weight_internal_path: /weight
  # configuration of the train loader
  train:
    # path to the training datasets
    file_paths:
      - PATH_TO_TRAIN_DIR

    # SliceBuilder configuration, i.e. how to iterate over the input volume patch-by-patch
    slice_builder:
      name: FilterSliceBuilder
      # train patch size given to the network (adapt to fit in your GPU mem, generally the bigger patch the better)
      patch_shape: [80, 170, 170]
      # train stride between patches
      stride_shape: [20, 40, 40]
      # minimum volume of the labels in the patch
      threshold: 0.6
      # probability of accepting patches which do not fulfil the threshold criterion
      slack_acceptance: 0.01

    transformer:
      raw:
        - name: Standardize
        - name: RandomFlip
        - name: RandomRotate90
        - name: RandomRotate
          # rotate only in ZY plane due to anisotropy
          axes: [[2, 1]]
          angle_spectrum: 45
          mode: reflect
        - name: ElasticDeformation
          spline_order: 3
        - name: GaussianBlur3D
          execution_probability: 0.5
        - name: AdditiveGaussianNoise
          execution_probability: 0.2
        - name: AdditivePoissonNoise
          execution_probability: 0.2
        - name: ToTensor
          expand_dims: true
      label:
        - name: RandomFlip
        - name: RandomRotate90
        - name: RandomRotate
          # rotate only in ZY plane due to anisotropy
          axes: [[2, 1]]
          angle_spectrum: 45
          mode: reflect
        - name: ElasticDeformation
          spline_order: 0
        - name: StandardLabelToBoundary
          # append original ground truth labels to the last channel (to be able to compute the eval metric)
          append_label: true
        - name: ToTensor
          expand_dims: false
      weight:
        - name: RandomFlip
        - name: RandomRotate90
        - name: RandomRotate
          # rotate only in ZY plane due to anisotropy
          axes: [[2, 1]]
          angle_spectrum: 45
          mode: reflect
        - name: ElasticDeformation
          spline_order: 3
        - name: ToTensor
          expand_dims: false

  # configuration of the val loader
  val:
    # path to the val datasets
    file_paths:
      - PATH_TO_VAL_DIR

    # SliceBuilder configuration, i.e. how to iterate over the input volume patch-by-patch
    slice_builder:
      name: FilterSliceBuilder
      # val patch size given to the network (adapt to fit in your GPU mem, generally the bigger patch the better)
      patch_shape: [80, 170, 170]
      # val stride between patches
      stride_shape: [80, 170, 170]
      # minimum volume of the labels in the patch
      threshold: 0.6
      # probability of accepting patches which do not fulfil the threshold criterion
      slack_acceptance: 0.01

    # data augmentation
    transformer:
      raw:
        - name: Standardize
        - name: ToTensor
          expand_dims: true
      label:
        - name: StandardLabelToBoundary
          append_label: true
        - name: ToTensor
          expand_dims: false
      weight:
        - name: ToTensor
          expand_dims: false
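The `threshold` and `slack_acceptance` options of the `FilterSliceBuilder` describe a patch filter: a patch is kept when enough of it is labelled, and otherwise only with a small probability. The following sketch illustrates that rule in pure Python. It is a reading of the config comments, not the library's code: the function name `accept_patch`, the nested-list patch layout, and treating label value 0 as "unlabelled" are assumptions.

```python
import random


def accept_patch(label_patch, threshold=0.6, slack_acceptance=0.01,
                 ignore_value=0, rng=random):
    """Decide whether a 3D label patch should be used for training.

    label_patch: nested list indexed as [z][y][x] with integer labels
    threshold: minimum fraction of voxels that must differ from `ignore_value`
    slack_acceptance: probability of keeping a patch that fails the threshold
    """
    voxels = [v for plane in label_patch for row in plane for v in row]
    labelled_fraction = sum(1 for v in voxels if v != ignore_value) / len(voxels)
    # keep well-labelled patches; occasionally let a sparse one through
    return labelled_fraction > threshold or rng.random() < slack_acceptance


if __name__ == "__main__":
    dense = [[[1, 2], [3, 4]], [[5, 6], [7, 0]]]   # 7 of 8 voxels labelled
    sparse = [[[0, 0], [0, 0]], [[0, 0], [0, 1]]]  # 1 of 8 voxels labelled
    print(accept_patch(dense))
    print(accept_patch(sparse, slack_acceptance=0.0))
```

Keeping a small `slack_acceptance` means the network still sees some mostly empty patches, so it does not only learn from densely labelled regions.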
