diff --git a/README.md b/README.md
index c6f6ff77..032357db 100644
--- a/README.md
+++ b/README.md
@@ -111,16 +111,16 @@ CUDA_VISIBLE_DEVICES=0,1 predict3dunet --config
 - `DiceLoss` (standard `DiceLoss` defined as `1 - DiceCoefficient` used for binary semantic segmentation; when more than 2 classes are present in the ground truth, it computes the `DiceLoss` per channel and averages the values)
 - `BCEDiceLoss` (Linear combination of BCE and Dice losses, i.e. `alpha * BCE + beta * Dice`, `alpha, beta` can be specified in the `loss` section of the config)
 - `CrossEntropyLoss` (one can specify class weights via the `weight: [w_1, ..., w_k]` in the `loss` section of the config)
-- `PixelWiseCrossEntropyLoss` (one can specify per-pixel weights in order to give more gradient to the important/under-represented regions in the ground truth; `weight` dataset has to be provided in the H5 files for training and validatin)
+- `PixelWiseCrossEntropyLoss` (one can specify per-pixel weights in order to give more gradient to the important/under-represented regions in the ground truth; a `weight` dataset has to be provided in the H5 files for training and validation; see the sample config in [train_config.yml](resources/3DUnet_confocal_boundary_weighted/train_config.yml))
 - `WeightedCrossEntropyLoss` (see 'Weighted cross-entropy (WCE)' in the below paper for a detailed explanation)
-- `GeneralizedDiceLoss` (see 'Generalized Dice Loss (GDL)' in the below paper for a detailed explanation) Note: use this loss function only if the labels in the training dataset are very imbalanced e.g. one class having at least 3 orders of magnitude more voxels than the others. Otherwise use standard `DiceLoss`.
+- `GeneralizedDiceLoss` (see 'Generalized Dice Loss (GDL)' in the below paper for a detailed explanation) Note: use this loss function only if the labels in the training dataset are very imbalanced, e.g. one class having at least 3 orders of magnitude more voxels than the others. Otherwise, use standard `DiceLoss`.
 
 For a detailed explanation of some of the supported loss functions see:
 [Generalised Dice overlap as a deep learning loss function for highly unbalanced segmentations](https://arxiv.org/pdf/1707.03237.pdf).
 
 ### Regression
 - `MSELoss` (mean squared error loss)
-- `L1Loss` (mean absolute errro loss)
+- `L1Loss` (mean absolute error loss)
 - `SmoothL1Loss` (less sensitive to outliers than MSELoss)
 - `WeightedSmoothL1Loss` (extension of the `SmoothL1Loss` which allows to weight the voxel values above/below a given threshold differently)
 
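Since `PixelWiseCrossEntropyLoss` expects the per-pixel weights as an extra dataset in each training/validation H5 file, a minimal sketch of preparing such a file with `h5py` could look as follows. The 5x foreground emphasis is an arbitrary illustrative scheme (not a scheme prescribed by the package), and the dataset names follow the `raw_internal_path`/`label_internal_path`/`weight_internal_path` settings used in the sample train config below:

```python
import h5py
import numpy as np

# Illustrative volumes; in practice these come from your microscopy data.
raw = np.random.rand(64, 128, 128).astype(np.float32)  # (z, y, x) intensity stack
label = (raw > 0.7).astype(np.uint8)                   # dummy ground-truth labels

# Hypothetical weighting scheme: give labeled voxels 5x more gradient.
# Replace with whatever map highlights your important/under-represented regions.
weight = np.where(label > 0, 5.0, 1.0).astype(np.float32)

with h5py.File('train_sample.h5', 'w') as f:
    f.create_dataset('raw', data=raw, compression='gzip')
    f.create_dataset('label', data=label, compression='gzip')
    f.create_dataset('weight', data=weight, compression='gzip')
```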
@@ -131,7 +131,7 @@ For a detailed explanation of some of the supported loss functions see:
 - `MeanIoU` (mean intersection over union)
 - `DiceCoefficient` (computes per channel Dice Coefficient and returns the average)
 If a 3D U-Net was trained to predict cell boundaries, one can use the following semantic instance segmentation metrics
-(the metrics below are computed by running connected components on thresholded boundary map and comparing the resulted instances to the ground truth instance segmentation):
+(the metrics below are computed by running connected components on the thresholded boundary map and comparing the resulting instances to the ground truth instance segmentation):
 - `BoundaryAveragePrecision` (Average Precision applied to the boundary probability maps: thresholds the output from the network, runs connected components to get the segmentation and computes AP between the resulting segmentation and the ground truth)
 - `AdaptedRandError` (see http://brainiac2.mit.edu/SNEMI3D/evaluation for a detailed explanation)
 - `AveragePrecision` (see https://www.kaggle.com/stkbailey/step-by-step-explanation-of-scoring-metric)
diff --git a/resources/3DUnet_confocal_boundary_weighted/test_config.yml b/resources/3DUnet_confocal_boundary_weighted/test_config.yml
new file mode 100644
index 00000000..3dc29441
--- /dev/null
+++ b/resources/3DUnet_confocal_boundary_weighted/test_config.yml
@@ -0,0 +1,40 @@
+# Download test data from: https://osf.io/8jz7e/
+model_path: PATH_TO_BEST_CHECKPOINT
+model:
+  name: UNet3D
+  # number of input channels to the model
+  in_channels: 1
+  # number of output channels (foreground and background will be predicted in separate channels)
+  out_channels: 2
+  # determines the order of operators in a single layer (gcr - GroupNorm+Conv3d+ReLU)
+  layer_order: gcr
+  # initial number of feature maps
+  f_maps: 32
+  # number of groups in the groupnorm
+  num_groups: 8
+  # apply element-wise nn.Sigmoid after the final 1x1x1 convolution, otherwise apply nn.Softmax
+  final_sigmoid: true
+predictor:
+  name: 'StandardPredictor'
+loaders:
+  # save predictions to output_dir
+  output_dir: PATH_TO_OUTPUT_DIR
+  # batch dimension; if the number of GPUs is N > 1, an effective batch size of N * batch_size will automatically be used for DataParallel
+  batch_size: 1
+  # how many subprocesses to use for data loading
+  num_workers: 8
+  # test loaders configuration
+  test:
+    file_paths:
+      - PATH_TO_TEST_DIR
+
+    slice_builder:
+      name: SliceBuilder
+      patch_shape: [80, 170, 170]
+      stride_shape: [40, 90, 90]
+
+    transformer:
+      raw:
+        - name: Standardize
+        - name: ToTensor
+          expand_dims: true
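As described in the metrics section above, the instance-level metrics first threshold the boundary probability map and then run connected components to recover instances. A rough sketch of that step, assuming `scipy` and a hypothetical helper name (the package's actual metric implementations may differ in details):

```python
import numpy as np
from scipy import ndimage

def instances_from_boundary_pmap(boundary_pmap: np.ndarray,
                                 threshold: float = 0.4) -> np.ndarray:
    """Recover an instance segmentation from a boundary probability map:
    voxels below the threshold are treated as cell interior, and each
    connected interior region becomes one instance."""
    interior = boundary_pmap < threshold
    instances, num_instances = ndimage.label(interior)
    return instances
```

The resulting instance volume is then compared against the ground-truth instances, e.g. via the adapted Rand error or Average Precision.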
diff --git a/resources/3DUnet_confocal_boundary_weighted/train_config.yml b/resources/3DUnet_confocal_boundary_weighted/train_config.yml
new file mode 100644
index 00000000..71321c39
--- /dev/null
+++ b/resources/3DUnet_confocal_boundary_weighted/train_config.yml
@@ -0,0 +1,172 @@
+# Sample configuration file for training a 3D U-Net on the task of predicting boundaries in a 3D stack of Arabidopsis
+# ovules acquired with a confocal microscope.
+# Training is done with PixelWiseCrossEntropyLoss using a weight map that emphasizes faint boundaries.
+model:
+  name: UNet3D
+  # number of input channels to the model
+  in_channels: 1
+  # number of output channels (since cross-entropy loss is used, foreground and background classes are represented in separate channels)
+  out_channels: 2
+  # determines the order of operators in a single layer (gcr - GroupNorm+Conv3d+ReLU)
+  layer_order: gcr
+  # initial number of feature maps
+  f_maps: 32
+  # number of groups in the groupnorm
+  num_groups: 8
+  # apply element-wise nn.Sigmoid after the final 1x1x1 convolution, otherwise apply nn.Softmax
+  final_sigmoid: false
+# loss function to be used during training
+loss:
+  name: PixelWiseCrossEntropyLoss
+  # skip the last channel in the target (i.e. when the last channel contains data not relevant for the loss)
+  skip_last_target: true
+  # squeeze the channel dimension in the target
+  squeeze_channel: true
+  # a target value that is ignored and does not contribute to the input gradient
+  ignore_index: null
+optimizer:
+  # initial learning rate
+  learning_rate: 0.0002
+  # weight decay
+  weight_decay: 0.00001
+# evaluation metric
+eval_metric:
+  # use AdaptedRandError metric
+  name: BoundaryAdaptedRandError
+  # probability maps threshold
+  threshold: 0.4
+  # use the last target channel to compute the metric
+  use_last_target: true
+  # use only the first channel for computing the metric
+  use_first_input: true
+lr_scheduler:
+  name: ReduceLROnPlateau
+  # make sure to use the 'min' mode because a lower AdaptedRandError is better
+  mode: min
+  factor: 0.5
+  patience: 30
+trainer:
+  # model with lower eval score is considered better
+  eval_score_higher_is_better: false
+  # path to the checkpoint directory
+  checkpoint_dir: CHECKPOINT_DIR
+  # path to the latest checkpoint; if provided the training will be resumed from that checkpoint
+  resume: null
+  # path to the best_checkpoint.pytorch; to be used for fine-tuning the model with additional ground truth
+  pre_trained: null
+  # how many iterations between validations
+  validate_after_iters: 1000
+  # how many iterations between tensorboard logging
+  log_after_iters: 500
+  # max number of epochs
+  max_num_epochs: 1000
+  # max number of iterations
+  max_num_iterations: 150000
+# Configure training and validation loaders
+loaders:
+  # how many subprocesses to use for data loading
+  num_workers: 8
+  # path to the raw data within the H5
+  raw_internal_path: /raw
+  # path to the label data within the H5
+  label_internal_path: /label
+  # path to the weight data within the H5
+  weight_internal_path: /weight
+  # configuration of the train loader
+  train:
+    # path to the training datasets
+    file_paths:
+      - PATH_TO_TRAIN_DIR
+
+    # SliceBuilder configuration, i.e. how to iterate over the input volume patch-by-patch
+    slice_builder:
+      name: FilterSliceBuilder
+      # train patch size given to the network (adapt to fit in your GPU memory; generally, the bigger the patch the better)
+      patch_shape: [80, 170, 170]
+      # train stride between patches
+      stride_shape: [20, 40, 40]
+      # minimum fraction of the label volume in the patch
+      threshold: 0.6
+      # probability of accepting patches which do not fulfil the threshold criterion
+      slack_acceptance: 0.01
+
+    transformer:
+      raw:
+        - name: Standardize
+        - name: RandomFlip
+        - name: RandomRotate90
+        - name: RandomRotate
+          # rotate only in ZY plane due to anisotropy
+          axes: [[2, 1]]
+          angle_spectrum: 45
+          mode: reflect
+        - name: ElasticDeformation
+          spline_order: 3
+        - name: GaussianBlur3D
+          execution_probability: 0.5
+        - name: AdditiveGaussianNoise
+          execution_probability: 0.2
+        - name: AdditivePoissonNoise
+          execution_probability: 0.2
+        - name: ToTensor
+          expand_dims: true
+      label:
+        - name: RandomFlip
+        - name: RandomRotate90
+        - name: RandomRotate
+          # rotate only in ZY plane due to anisotropy
+          axes: [[2, 1]]
+          angle_spectrum: 45
+          mode: reflect
+        - name: ElasticDeformation
+          spline_order: 0
+        - name: StandardLabelToBoundary
+          # append original ground truth labels to the last channel (to be able to compute the eval metric)
+          append_label: true
+        - name: ToTensor
+          expand_dims: false
+      weight:
+        - name: RandomFlip
+        - name: RandomRotate90
+        - name: RandomRotate
+          # rotate only in ZY plane due to anisotropy
+          axes: [[2, 1]]
+          angle_spectrum: 45
+          mode: reflect
+        - name: ElasticDeformation
+          spline_order: 3
+        - name: ToTensor
+          expand_dims: false
+
+  # configuration of the val loader
+  val:
+    # path to the val datasets
+    file_paths:
+      - PATH_TO_VAL_DIR
+
+    # SliceBuilder configuration, i.e. how to iterate over the input volume patch-by-patch
+    slice_builder:
+      name: FilterSliceBuilder
+      # validation patch size (same as the train patch size)
+      patch_shape: [80, 170, 170]
+      # stride between patches (equal to patch_shape, i.e. no overlap between validation patches)
+      stride_shape: [80, 170, 170]
+      # minimum fraction of the label volume in the patch
+      threshold: 0.6
+      # probability of accepting patches which do not fulfil the threshold criterion
+      slack_acceptance: 0.01
+
+    # data augmentation
+    transformer:
+      raw:
+        - name: Standardize
+        - name: ToTensor
+          expand_dims: true
+      label:
+        - name: StandardLabelToBoundary
+          append_label: true
+        - name: ToTensor
+          expand_dims: false
+      weight:
+        - name: ToTensor
+          expand_dims: false
\ No newline at end of file
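To make the training objective of this config concrete, here is a minimal PyTorch sketch of a per-voxel weighted cross-entropy in the spirit of `PixelWiseCrossEntropyLoss`. This is an illustration only, not the package's implementation; the real loss additionally honours the `skip_last_target`, `squeeze_channel` and `ignore_index` options from the config above:

```python
import torch
import torch.nn.functional as F

def pixel_wise_cross_entropy(logits, target, weight):
    """Cross-entropy where every voxel carries its own weight.

    logits: (N, C, D, H, W) raw network output
    target: (N, D, H, W) integer class labels
    weight: (N, D, H, W) per-voxel weight map (the `/weight` dataset)
    """
    # standard per-voxel cross-entropy, kept unreduced so it can be re-weighted
    per_voxel = F.cross_entropy(logits, target, reduction='none')
    return (weight * per_voxel).mean()

# toy shapes matching the sample config: 2 output channels, 1 input channel
logits = torch.randn(1, 2, 8, 16, 16)
target = torch.randint(0, 2, (1, 8, 16, 16))
weight = torch.ones(1, 8, 16, 16)
print(pixel_wise_cross_entropy(logits, target, weight))
```

A weight map that emphasizes faint boundaries simply assigns values above 1.0 to those voxels, so their misclassification contributes proportionally more gradient.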