Repository of the paper Sauron U-Net: Simple automated redundancy elimination in medical image segmentation via filter pruning
- 1. Using Sauron
- 1.1 Prerequisites and data
- 1.2 Training
- 1.3 Computing the output
- 1.4 Evaluation
- 2. Experiments
- 2.1 Section 4.1: Benchmark
- 2.2 Section 4.2: Clusterability
- 2.3 Section 4.3: Feature maps interpretation
1. Using Sauron
1.1 Prerequisites and data
Libraries: We used PyTorch 1.7.1 and TorchIO 0.18.71.
Datasets: We used two publicly available datasets, ACDC and KiTS. The exact train-val-test splits and data augmentation parameters are specified in the code (lib/data/*).
1.2 Training
By specifying the dataset name, we can train an nnUNet model from scratch:
python --data datasetName
The remaining parameters, such as the number of epochs, the dataset splits, and the seeds, are defined in the function parseArguments.
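As an illustration only, a parser in the spirit of parseArguments can be sketched with Python's argparse, reusing the flag names that appear in the commands of this README. The defaults, types, help strings and the handling of --split are assumptions, not the repository's actual code.

```python
import argparse


def parse_arguments(argv=None):
    """Hypothetical sketch of a parser like parseArguments.

    Flag names come from the README commands; defaults and types are assumptions.
    """
    parser = argparse.ArgumentParser(description="Train nnUNet (sketch)")
    parser.add_argument("--data", required=True, help="Dataset name, e.g. ACDC or KiTS")
    parser.add_argument("--epochs", type=int, default=500)  # assumed default
    parser.add_argument("--seed_train", type=int, default=42)  # assumed default
    parser.add_argument("--seed_data", type=int, default=42)  # assumed default
    parser.add_argument("--split", default="0.9:0.1", help="Train:val ratios separated by ':'")
    parser.add_argument("--epoch_start", type=int, default=0, help="Epoch to resume from")
    parser.add_argument("--model_state", default="", help="Checkpoint file to load")
    parser.add_argument("--history", default="", help="Folder of a previous run")
    args = parser.parse_args(argv)

    # Turn "0.9:0.1" into (0.9, 0.1) and check that the ratios add up to 1
    args.split = tuple(float(x) for x in args.split.split(":"))
    if abs(sum(args.split) - 1.0) > 1e-6:
        parser.error("--split ratios must sum to 1")
    return args
```

For example, parse_arguments(["--data", "ACDC", "--epoch_start", "400"]) returns a namespace with split == (0.9, 0.1) under these assumed defaults. Parsing the split into numbers early makes an invalid value fail before any data is loaded.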
To continue training a model:
python --data datasetName --epochs 500 --seed_train 42 --seed_data 42 --split 0.9:0.1 --epoch_start 400 --model_state path/model-400 --history path/to/previous/run
In particular, --history expects the path to the folder that contains the pruned filters and the other files saved during the previous training run.
1.3 Computing the output
To generate the segmentations with a Sauron-pruned nnUNet model:
python --data datasetName --output output_path/predictions --model_state path/model-500 --original /path/to/dataset --in_filters path/in_filters --out_filters path/out_filters
--original expects the path to the original files of the dataset. This guarantees that the generated segmentations have the same voxel resolution as the original images.
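To illustrate why the original files are needed, the sketch below maps a label volume back onto the voxel grid of the original image with nearest-neighbour resampling. It assumes the target shape has already been read from the original image and that volumes are nested Python lists; it ignores spacing metadata and orientation, which a library such as TorchIO or SimpleITK would handle in practice. This is not the repository's implementation.

```python
def nearest_indices(n_out, n_in):
    """Map each output voxel index to the nearest input voxel index along one axis."""
    scale = n_in / n_out
    # The centre of output voxel i lies at (i + 0.5) * scale in input coordinates
    return [min(int((i + 0.5) * scale), n_in - 1) for i in range(n_out)]


def resample_labels(labels, target_shape):
    """Nearest-neighbour resampling of a 3D label volume (nested lists) to target_shape.

    Nearest-neighbour is used so that no new, invalid label values are created.
    """
    d_in, h_in, w_in = len(labels), len(labels[0]), len(labels[0][0])
    d_out, h_out, w_out = target_shape
    zi = nearest_indices(d_out, d_in)
    yi = nearest_indices(h_out, h_in)
    xi = nearest_indices(w_out, w_in)
    return [[[labels[z][y][x] for x in xi] for y in yi] for z in zi]
```

For instance, resampling [[[0, 1], [2, 3]]] to shape (1, 4, 4) repeats each label over a 2x2 block. Linear interpolation would be a poor choice here because it averages neighbouring class indices.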
1.4 Evaluation
To evaluate the predictions against the ground truth:
python --data datasetName --pred path/predictions --gt path/ground_truth --output output_path/results.json
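The evaluation writes its results to a JSON file. As a hedged sketch of what such a step can involve, the code below computes the Dice coefficient per foreground label between two flattened label maps and saves the scores as JSON. Dice is only an example metric here; which metrics the evaluation script reports, and how its JSON is structured, are not described above, and the helper names are hypothetical.

```python
import json


def dice_per_class(pred, gt, labels):
    """Dice coefficient per foreground label for two flattened label maps."""
    scores = {}
    for c in labels:
        p = [v == c for v in pred]
        g = [v == c for v in gt]
        inter = sum(a and b for a, b in zip(p, g))
        denom = sum(p) + sum(g)
        # Convention (assumption): Dice = 1 when the label is absent from both maps
        scores[str(c)] = 1.0 if denom == 0 else 2.0 * inter / denom
    return scores


def write_results(results, path):
    """Save a dict of per-subject scores as an indented JSON file."""
    with open(path, "w") as f:
        json.dump(results, f, indent=2)
```

For example, dice_per_class([0, 1, 1, 2], [0, 1, 2, 2], [1, 2]) gives 2/3 for both labels. String keys are used so the dictionary round-trips through JSON unchanged.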
2. Experiments
2.1 Section 4.1: Benchmark
- Sauron was run following the steps described above.
- For Sauron ():
  2.1 lib/loss: Set in DS_CrossEntropyDiceLoss_Distance.
  2.2 Keep callback._end_epoch_prune.
  2.3 lib/data/XXXDataset: dist_fun = "euc_norm"; imp_fun = "euc_rand"
- For nnUNet:
  3.1 Remove callback._end_epoch_prune.
  3.2 model = nnUNet(**cfg["architecture"])
  3.3 lib/data/XXXDataset: dist_fun = ""; imp_fun = ""
  3.4 lib/data/XXXDataset: use DS_CrossEntropyDiceLoss instead of DS_CrossEntropyDiceLoss_Distance.
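The Sauron configurations select a distance function through dist_fun (e.g., "euc_norm"). As a generic illustration of a distance-based redundancy criterion, the sketch below normalizes each channel's flattened feature map to unit norm, measures its Euclidean distance to a reference channel, and flags channels closer than a threshold tau as prunable. Using channel 0 as the reference and the name tau are assumptions made for this example; this is not the repository's dist_fun or pruning code.

```python
import math


def normalize(fm):
    """Scale a flattened feature map to unit Euclidean norm (all-zero maps stay zero)."""
    norm = math.sqrt(sum(v * v for v in fm))
    return [v / norm for v in fm] if norm > 0 else list(fm)


def channels_to_prune(feature_maps, tau, ref=0):
    """Indices of channels whose normalized map lies within tau of channel `ref`."""
    ref_map = normalize(feature_maps[ref])
    prune = []
    for c, fm in enumerate(feature_maps):
        if c == ref:
            continue  # the reference channel itself is always kept
        if math.dist(normalize(fm), ref_map) < tau:
            prune.append(c)
    return prune
```

With feature maps [[1, 0], [2, 0], [0, 1], [1, 1]] and tau = 0.5, only channel 1 is flagged: after normalization it is identical to the reference, so it carries no extra information. Normalizing first means that channels differing only by a scale factor count as redundant.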
To store the feature maps, add _end_epoch_save_all_FMs to the callbacks and remove _end_epoch_prune so that no filters are pruned.
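The callback change above can be pictured with a minimal, hypothetical trainer that calls a list of end-of-epoch callbacks in order. The Trainer class and the functions end_epoch_prune and end_epoch_save_all_FMs are stand-ins invented for this sketch; they only record that they ran and do not reproduce the repository's _end_epoch_prune or _end_epoch_save_all_FMs.

```python
class Trainer:
    """Hypothetical minimal trainer that calls end-of-epoch callbacks in order."""

    def __init__(self, callbacks):
        self.callbacks = list(callbacks)
        self.log = []  # records which callbacks ran, for illustration only

    def fit(self, epochs):
        for epoch in range(1, epochs + 1):
            # ... one epoch of training would run here ...
            for callback in self.callbacks:
                callback(self, epoch)


# Hypothetical stand-ins for _end_epoch_prune and _end_epoch_save_all_FMs
def end_epoch_prune(trainer, epoch):
    trainer.log.append(("prune", epoch))


def end_epoch_save_all_FMs(trainer, epoch):
    trainer.log.append(("save_FMs", epoch))


# Default run: pruning at the end of each epoch
default_callbacks = [end_epoch_prune]

# Feature-map run: drop pruning, add feature-map saving
fm_callbacks = [cb for cb in default_callbacks if cb is not end_epoch_prune]
fm_callbacks.append(end_epoch_save_all_FMs)

trainer = Trainer(fm_callbacks)
trainer.fit(epochs=2)
```

After fit, trainer.log contains only feature-map saving events, one per epoch. Keeping per-epoch actions as a list of callables is what makes swapping one behaviour for another a one-line change.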
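Then load the trained model and compute the feature maps of each test subject:
import torch

# Sauron, params, path, log and data are defined elsewhere in the repository code
model = Sauron(**params)
model.initialize(device="cuda", model_state=path, log=log, isSauron=True)
test_data = data.get("test")
with torch.no_grad():  # inference only: no gradients needed
    for sub_i, subject in enumerate(test_data):
        # `image` is the input tensor of `subject`; how it is obtained depends on the dataset class in lib/data
        FMs = model.forward_saveFMs(image)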
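Once the feature maps are collected, one simple way to inspect how similar the channels of a layer are is to compute the pairwise Euclidean distances between their flattened feature maps and, for each channel, the distance to its closest neighbour. This is an illustrative sketch only: the analyses actually used for the clusterability and feature-map experiments are not described here, and it assumes each channel map has already been flattened into a list of floats (for a PyTorch tensor, e.g. with fm.flatten().tolist()).

```python
import math


def pairwise_distances(fms):
    """Symmetric matrix of Euclidean distances between flattened channel feature maps."""
    n = len(fms)
    dist = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1, n):
            dist[i][j] = dist[j][i] = math.dist(fms[i], fms[j])
    return dist


def nearest_neighbour_distance(dist):
    """For each channel, the distance to its closest other channel (needs >= 2 channels)."""
    n = len(dist)
    return [min(dist[i][j] for j in range(n) if j != i) for i in range(n)]
```

For the three maps [0, 0], [3, 4] and [0, 1], the nearest-neighbour distances are 1, about 4.24, and 1. Small values indicate channels whose maps are close to another channel's.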