Skip to content

Commit

Permalink
Merge pull request #122 from wolny/report-test-scores
Browse files Browse the repository at this point in the history
Add optional metrics computation in the predictor
  • Loading branch information
wolny authored Jan 18, 2025
2 parents dcd4dec + ef99754 commit dd8248d
Show file tree
Hide file tree
Showing 17 changed files with 434 additions and 504 deletions.
99 changes: 74 additions & 25 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@ PyTorch implementation of 3D U-Net and its variants:
The code allows for training the U-Net for both: **semantic segmentation** (binary and multi-class) and **regression** problems (e.g. de-noising, learning deconvolutions).

## 2D U-Net

2D U-Net is also supported, see [2DUnet_confocal](resources/2DUnet_confocal_boundary) or [2DUnet_dsb2018](resources/2DUnet_dsb2018/train_config.yml) for example configuration.
Just make sure to keep the singleton z-dimension in your H5 dataset (i.e. `(1, Y, X)` instead of `(Y, X)`) , because data loading / data augmentation requires tensors of rank 3.
The 2D U-Net itself uses the standard 2D convolutional layers instead of 3D convolutions with kernel size `(1, 3, 3)` for performance reasons.

## Input Data Format

The input data should be stored in HDF5 files. The HDF5 files for training should contain two datasets: `raw` and `label`. Optionally, when training with `PixelWiseCrossEntropyLoss` one should provide `weight` dataset.
The `raw` dataset should contain the input data, while the `label` dataset the ground truth labels. The optional `weight` dataset should contain the values for weighting the loss function in different regions of the input and should be of the same size as `label` dataset.
The format of the `raw`/`label` datasets depends on whether the problem is 2D or 3D and whether the data is single-channel or multi-channel, see the table below:
Expand All @@ -32,26 +34,36 @@ The format of the `raw`/`label` datasets depends on whether the problem is 2D or
| single-channel | (1, Y, X) | (Z, Y, X) |
| multi-channel | (C, 1, Y, X) | (C, Z, Y, X) |


## Prerequisites

- NVIDIA GPU
- CUDA CuDNN

### Running on Windows/OSX
`pytorch-3dunet` is a cross-platform package and runs on Windows and OS X as well.

`pytorch-3dunet` is a cross-platform package and runs on Windows and OS X as well.

## Installation

- The easiest way to install `pytorch-3dunet` package is via conda/mamba:

```
conda install -c conda-forge mamba
mamba create -n pytorch-3dunet -c pytorch -c nvidia -c conda-forge pytorch pytorch-cuda=12.1 pytorch-3dunet
conda activate pytorch-3dunet
```

After installation the following commands are accessible within the conda environment:
`train3dunet` for training the network and `predict3dunet` for prediction (see below).

- One can also install directly from source:
- One can also install directly from source, i.e. go to the checkout directory and run:

```
pip install -e .
```

or

```
python setup.py install
```
Expand All @@ -64,7 +76,8 @@ Given that `pytorch-3dunet` package was installed via conda as described above,
```
train3dunet --config <CONFIG>
```
where `CONFIG` is the path to a YAML configuration file, which specifies all aspects of the training procedure.

where `CONFIG` is the path to a YAML configuration file, which specifies all aspects of the training procedure.

In order to train on your own data just provide the paths to your HDF5 training and validation datasets in the config.

Expand All @@ -75,82 +88,111 @@ In order to train on your own data just provide the paths to your HDF5 training
One can monitor the training progress with Tensorboard `tensorboard --logdir <checkpoint_dir>/logs/` (you need `tensorflow` installed in your conda env), where `checkpoint_dir` is the path to the checkpoint directory specified in the config.

### Training tips

1. When training with binary-based losses, i.e.: `BCEWithLogitsLoss`, `DiceLoss`, `BCEDiceLoss`, `GeneralizedDiceLoss`:
The target data has to be 4D (one target binary mask per channel).
When training with `WeightedCrossEntropyLoss`, `CrossEntropyLoss`, `PixelWiseCrossEntropyLoss` the target dataset has to be 3D, see also pytorch documentation for CE loss: https://pytorch.org/docs/master/generated/torch.nn.CrossEntropyLoss.html
2. `final_sigmoid` in the `model` config section applies only to the inference time (validation, test):
* When training with `BCEWithLogitsLoss`, `DiceLoss`, `BCEDiceLoss`, `GeneralizedDiceLoss` set `final_sigmoid=True`
* When training with cross entropy based losses (`WeightedCrossEntropyLoss`, `CrossEntropyLoss`, `PixelWiseCrossEntropyLoss`) set `final_sigmoid=False` so that `Softmax` normalization is applied to the output.
The target data has to be 4D (one target binary mask per channel).
When training with `WeightedCrossEntropyLoss`, `CrossEntropyLoss` the target dataset has to be 3D, see also pytorch
documentation for CE loss: https://pytorch.org/docs/master/generated/torch.nn.CrossEntropyLoss.html
2. When training with `BCEWithLogitsLoss`, `DiceLoss`, `BCEDiceLoss`, `GeneralizedDiceLoss` set `final_sigmoid=True` in
the `model` part of the config so that the sigmoid is applied to the logits.
3. When training with cross entropy based losses (`WeightedCrossEntropyLoss`, `CrossEntropyLoss`) set
`final_sigmoid=False` so that `Softmax` normalization is applied to the logits.

## Prediction

Given that `pytorch-3dunet` package was installed via conda as described above, one can run the prediction via:

```
predict3dunet --config <CONFIG>
```

In order to predict on your own data, just provide the path to your model as well as paths to HDF5 test files (see example [test_config_segmentation.yaml](resources/3DUnet_confocal_boundary/test_config.yml)).

### Prediction tips
1. If you're running prediction for a large dataset, consider using `LazyHDF5Dataset` and `LazyPredictor` in the config. This will save memory by loading data on the fly at the cost of slower prediction time. See [test_config_lazy](resources/3DUnet_confocal_boundary/test_config_lazy.yml) for an example config.
2. If your model predicts multiple classes (see e.g. [train_config_multiclass](resources/3DUnet_multiclass/train_config.yaml)), consider saving only the final segmentation instead of the probability maps which can be time and space consuming.
To do so, set `save_segmentation: true` in the `predictor` section of the config (see [test_config_multiclass](resources/3DUnet_multiclass/test_config.yaml)).

1. If you're running prediction for a large dataset, consider using `LazyHDF5Dataset` and `LazyPredictor` in the config.
This will save memory by loading data on the fly at the cost of slower prediction time.
See [test_config_lazy](resources/3DUnet_confocal_boundary/test_config_lazy.yml) for an example config.
2. If your model predicts multiple classes (see
e.g. [train_config_multiclass](resources/3DUnet_multiclass/train_config.yaml)), consider saving only the final
segmentation instead of the probability maps which can be time and space consuming.
To do so, set `save_segmentation: true` in the `predictor` section of the config (
see [test_config_multiclass](resources/3DUnet_multiclass/test_config.yaml)).

## Data Parallelism
By default, if multiple GPUs are available training/prediction will be run on all the GPUs using [DataParallel](https://pytorch.org/tutorials/beginner/blitz/data_parallel_tutorial.html).
If training/prediction on all available GPUs is not desirable, restrict the number of GPUs using `CUDA_VISIBLE_DEVICES`, e.g.

By default, if multiple GPUs are available training/prediction will be run on all the GPUs
using [DataParallel](https://pytorch.org/tutorials/beginner/blitz/data_parallel_tutorial.html).
If training/prediction on all available GPUs is not desirable, restrict the number of GPUs using `CUDA_VISIBLE_DEVICES`,
e.g.

```bash
CUDA_VISIBLE_DEVICES=0,1 train3dunet --config <CONFIG>
```

or

```bash
CUDA_VISIBLE_DEVICES=0,1 predict3dunet --config <CONFIG>
```

## Supported Loss Functions

### Semantic Segmentation

- `BCEWithLogitsLoss` (binary cross-entropy)
- `DiceLoss` (standard `DiceLoss` defined as `1 - DiceCoefficient` used for binary semantic segmentation; when more than 2 classes are present in the ground truth, it computes the `DiceLoss` per channel and averages the values)
- `BCEDiceLoss` (Linear combination of BCE and Dice losses, i.e. `alpha * BCE + beta * Dice`, `alpha, beta` can be specified in the `loss` section of the config)
- `CrossEntropyLoss` (one can specify class weights via the `weight: [w_1, ..., w_k]` in the `loss` section of the config)
- `PixelWiseCrossEntropyLoss` (one can specify per-pixel weights in order to give more gradient to the important/under-represented regions in the ground truth; `weight` dataset has to be provided in the H5 files for training and validation; see sample config in [train_config.yml](resources/3DUnet_confocal_boundary_weighted/train_config.yml)
- `DiceLoss` (standard `DiceLoss` defined as `1 - DiceCoefficient` used for binary semantic segmentation; when more than
2 classes are present in the ground truth, it computes the `DiceLoss` per channel and averages the values)
- `BCEDiceLoss` (Linear combination of BCE and Dice losses, i.e. `alpha * BCE + beta * Dice`, `alpha, beta` can be
specified in the `loss` section of the config)
- `CrossEntropyLoss` (one can specify class weights via the `weight: [w_1, ..., w_k]` in the `loss` section of the
config)
- `WeightedCrossEntropyLoss` (see 'Weighted cross-entropy (WCE)' in the below paper for a detailed explanation)
- `GeneralizedDiceLoss` (see 'Generalized Dice Loss (GDL)' in the below paper for a detailed explanation) Note: use this loss function only if the labels in the training dataset are very imbalanced e.g. one class having at least 3 orders of magnitude more voxels than the others. Otherwise, use standard `DiceLoss`.
- `GeneralizedDiceLoss` (see 'Generalized Dice Loss (GDL)' in the below paper for a detailed explanation) Note: use this
loss function only if the labels in the training dataset are very imbalanced e.g. one class having at least 3 orders
of magnitude more voxels than the others. Otherwise, use standard `DiceLoss`.

For a detailed explanation of some of the supported loss functions see:
[Generalised Dice overlap as a deep learning loss function for highly unbalanced segmentations](https://arxiv.org/pdf/1707.03237.pdf).

### Regression

- `MSELoss` (mean squared error loss)
- `L1Loss` (mean absolute error loss)
- `SmoothL1Loss` (less sensitive to outliers than MSELoss)
- `WeightedSmoothL1Loss` (extension of the `SmoothL1Loss` which allows to weight the voxel values above/below a given threshold differently)

- `WeightedSmoothL1Loss` (extension of the `SmoothL1Loss` which allows to weight the voxel values above/below a given
threshold differently)

## Supported Evaluation Metrics

### Semantic Segmentation

- `MeanIoU` (mean intersection over union)
- `DiceCoefficient` (computes per channel Dice Coefficient and returns the average)
If a 3D U-Net was trained to predict cell boundaries, one can use the following semantic instance segmentation metrics
(the metrics below are computed by running connected components on threshold boundary map and comparing the resulted instances to the ground truth instance segmentation):
- `BoundaryAveragePrecision` (Average Precision applied to the boundary probability maps: thresholds the output from the network, runs connected components to get the segmentation and computes AP between the resulting segmentation and the ground truth)
If a 3D U-Net was trained to predict cell boundaries, one can use the following semantic instance segmentation metrics
(the metrics below are computed by running connected components on threshold boundary map and comparing the resulted
instances to the ground truth instance segmentation):
- `BoundaryAveragePrecision` (Average Precision applied to the boundary probability maps: thresholds the output from the
network, runs connected components to get the segmentation and computes AP between the resulting segmentation and the
ground truth)
- `AdaptedRandError` (see http://brainiac2.mit.edu/SNEMI3D/evaluation for a detailed explanation)
- `AveragePrecision` (see https://www.kaggle.com/stkbailey/step-by-step-explanation-of-scoring-metric)

If not specified `MeanIoU` will be used by default.


### Regression

- `PSNR` (peak signal to noise ratio)
- `MSE` (mean squared error)

## Examples

### Cell boundary predictions for lightsheet images of Arabidopsis thaliana lateral root

Training/predictions configs can be found in [3DUnet_lightsheet_boundary](resources/3DUnet_lightsheet_boundary).
Pre-trained model weights available [here](https://oc.embl.de/index.php/s/61s67Mg5VQy7dh9/download?path=%2FLateral-Root-Primordia%2Funet_bce_dice_ds1x&files=best_checkpoint.pytorch).
In order to use the pre-trained model on your own data:

* download the `best_checkpoint.pytorch` from the above link
* add the path to the downloaded model and the path to your data in [test_config.yml](resources/3DUnet_lightsheet_boundary/test_config.yml)
* run `predict3dunet --config test_config.yml`
Expand All @@ -167,9 +209,11 @@ Sample z-slice predictions on the test set (top: raw input , bottom: boundary pr
<img src="https://github.com/wolny/pytorch-3dunet/blob/master/resources/3DUnet_lightsheet_boundary/root_movie1_t45_pred.png" width="400">

### Cell boundary predictions for confocal images of Arabidopsis thaliana ovules

Training/predictions configs can be found in [3DUnet_confocal_boundary](resources/3DUnet_confocal_boundary).
Pre-trained model weights available [here](https://oc.embl.de/index.php/s/61s67Mg5VQy7dh9/download?path=%2FArabidopsis-Ovules%2Funet_bce_dice_ds2x&files=best_checkpoint.pytorch).
In order to use the pre-trained model on your own data:

* download the `best_checkpoint.pytorch` from the above link
* add the path to the downloaded model and the path to your data in [test_config.yml](resources/3DUnet_confocal_boundary/test_config.yml)
* run `predict3dunet --config test_config.yml`
Expand All @@ -186,6 +230,7 @@ Sample z-slice predictions on the test set (top: raw input , bottom: boundary pr
<img src="https://github.com/wolny/pytorch-3dunet/blob/master/resources/3DUnet_confocal_boundary/ovules_pred.png" width="400">

### Nuclei predictions for lightsheet images of Arabidopsis thaliana lateral root

Training/predictions configs can be found in [3DUnet_lightsheet_nuclei](resources/3DUnet_lightsheet_nuclei).
Pre-trained model weights available [here](https://oc.embl.de/index.php/s/61s67Mg5VQy7dh9/download?path=%2FLateral-Root-Primordia%2Funet_bce_dice_nuclei_ds1x&files=best_checkpoint.pytorch).
In order to use the pre-trained model on your own data:
Expand All @@ -203,6 +248,7 @@ Sample z-slice predictions on the test set (top: raw input, bottom: nuclei predi


### 2D nuclei predictions for Kaggle DSB2018

The data can be downloaded from: https://www.kaggle.com/c/data-science-bowl-2018/data

Training/predictions configs can be found in [2DUnet_dsb2018](resources/2DUnet_dsb2018).
Expand All @@ -213,10 +259,13 @@ Sample predictions on the test image (top: raw input, bottom: nuclei predictions
<img src="https://github.com/wolny/pytorch-3dunet/blob/master/resources/2DUnet_dsb2018/5f9d29d6388c700f35a3c29fa1b1ce0c1cba6667d05fdb70bd1e89004dcf71ed_predictions.png" width="400">

## Contribute

If you want to contribute back, please make a pull request.

## Cite

If you use this code for your research, please cite as:

```
@article {10.7554/eLife.57613,
article_type = {journal},
Expand Down
3 changes: 0 additions & 3 deletions pytorch3dunet/augment/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -757,9 +757,6 @@ def raw_transform(self):
def label_transform(self):
return self._create_transform('label')

def weight_transform(self):
return self._create_transform('weight')

@staticmethod
def _transformer_class(class_name):
m = importlib.import_module('pytorch3dunet.augment.transforms')
Expand Down
Loading

0 comments on commit dd8248d

Please sign in to comment.