This is the repository for our domain generalization project we conducted in winter term 2024/25 with the chair for Explainable Artificial Intelligence at Otto-Friedrich-University Bamberg. The goal was developing a plug-and-play domain generalization pipeline to test different data augmentation strategies.
We provide four modules:
domgen.augment
: Augmentation strategies and handler classes.domgen.model_training
: Model definitions, utility functions and a trainer class tying it all together.domgen.tuning
: Classes for hyperparameter and augmentation tuning.domgen.data
: Definition of a dataset class to handle specific domain datasets.
1. Setup
From root
run:
python -m pip install -e .
⚡ Currently there is an issue with git lfs
where cloning fails for some users (git-lfs/git-lfs#5749). This can be resolved by setting GIT_CLONE_PROTECTION_ACTIVE=false
. This is but a temporary fix. ⚡
⚡ Due to the medmnistc
dependency, we also require wand
to be installed ⚡
2. Getting the Data
Run download_data.py
with either the --download_pacs
, --download_camelyon17
, or --download_all
flag.
python domgen\data\download_data.py --download_pacs
3. Runnig Experiments
For a single configuration, run
python .\train.py --config "<path-to-config.yaml>"
or run the following to start an entire suite (e.g., training the mixstyle models):
.\assets\scripts\run_training_suite.ps1 "PACS\5-mixstyle"
You find all configurations we used for our experiments in assets\config
.
If you happen to find bugs - as we're sure there are unspotted ones - please do let us know!:)
If you want to use the trainer independently of train.py
you can do so by simply instatiating an instance of the DomGenTrainer
, and passing a configuration by either:
- defining a
config.yaml
file and passing it via the--config
flag when runningtrain.py
or - defining a
dataclass
or similar and passing it to the trainer.
from domgen.model_training import DomGenTrainer
args = ... # args must be a namespace object / dataclass
trainer = DomGenTrainer(args)
trainer.fit()
After training, the results of the experiments can be saved by calling
trainer.save_metrics(trainer.metrics, experiment_path)
trainer.save_config(f'{experiment_path}/trainer_config.json')
which save the results of the training to file.
⚡ Note, that saving the trainer config requires it to be serializable. This needs to be handled individually. ⚡
Additionally, we provide functions to plot the results:
from domgen.utils import plot_accuracies, plot_training_curves
# experiment_path is the path to the directory where
# the results of the run were saved to
experiment_path = f"{args.log_dir}/{args.experiment}"
plot_accuracies(root_path=experiment_path, save=True, show=False)
plot_training_curves(base_dir=experiment_path, show=False)
We provide a range of augmentation strategies:
no_augment
: This is the identity transformation.mixup
: Using the implementation of PyTorch here.cutmix
: Using the implementation of PyTorch here.augmix
: Using the implementation of PyTorch here.randaugment
: Using the implementation of PyTorch here.mixstyle
: Using our own implementation based on Zhou et al. 2021.pacs_custom
: A set of handcrafted image-level augmentations.medmnistc
: We provide a wrapper for di Salvo et al. 2024.
For both pacs_custom
and medmnistc
you need to specify an additional field in the config aug_dict
, that specifies which of the augmentations you want to use.
You can extend this by adding new augmentation strategies.
We provide configuration files for our models with the hyperparameters that worked best. Alternatively, you can run your own trials. To run a trial run for one model-config simply run:
.\assets\config\scripts\run_trials.ps1 "<base-model-dir>" "<tune-config.yaml>"
Additionally, we provide scripts for tuning all models of a certain class as well as running a full sweep on all configs.
We ran our experiments on two common domain generalization benchmarks: PACS
and Camelyon17
.
For convenience, we provide a utility script, that downloads both datasets and prepares the necessary directory structure.
To use the script, run download_data.py
:
python download_data.py --all
This loads both datasets into the datasets
directory. To get only of the sets, simply specify either --download_pacs
or --download_camelyon17
.
The download directory can be changed via the --datadir
flag.
Alternatively, you can manually load the datasets.
The PACS dataset is a benchmark for domain generalization tasks. It consists of images from four domains: Photo, Art painting, Cartoon, and Sketch. The dataset includes seven object classes: Dog, Elephant, Giraffe, Guitar, Horse, House, and Person, making it suitable for evaluating models across diverse styles and distributions.
Fig.2 - Snippet from the PACS dataset. The images are 224x224 pixels and span seven classes over four domains.@misc{li2017deeperbroaderartierdomain,
title={Deeper, Broader and Artier Domain Generalization},
author={Da Li and Yongxin Yang and Yi-Zhe Song and Timothy M. Hospedales},
year={2017},
eprint={1710.03077},
archivePrefix={arXiv},
primaryClass={cs.CV},
url={https://arxiv.org/abs/1710.03077},
}
The Camelyon17 dataset is a benchmark for evaluating domain generalization in histopathology image analysis. It contains images from five medical centers in the Netherlands, capturing inter-center variability (i.e. introduced by different scanners). The dataset is designed for tumor classification tasks and assessing the robustness of machine learning models in healthcare applications.
We use the patch-based variant of the dataset (Bandhi 2018). The data can be accessed via the
WILDS
ecosystem or downloaded directly from
here.
@article{bandi2018detection,
title={From detection of individual metastases to classification of lymph node status at the patient level: the CAMELYON17 challenge},
author={Bandi, Peter and Geessink, Oscar and Manson, Quirine and Van Dijk, Marcory and Balkenhol, Maschenka and Hermsen, Meyke and Bejnordi, Babak Ehteshami and Lee, Byungjae and Paeng, Kyunghyun and Zhong, Aoxiao and others},
journal={IEEE Transactions on Medical Imaging},
year={2018},
publisher={IEEE}
}
You can easily extend this repository by adding additional datasets. For that, you need to make sure that the dataset is placed in the datasets
directory
and follows the structure as shown (subdirectories for each domain containing the class directories):
├── dataset_root
│ ├── domain1
│ │ ├── cls1
│ │ ├── cls2
│ │ ├── ...
│ ├── domain2
│ │ ├── cls1
│ │ ├── cls2
│ │ ├── ...
...
To use this repository, you need the following libraries:
- jupyter
- torch~=2.5.0
- torchvision~=0.20.0
- tqdm~=4.67.0
- tensorboard
- pandas~=2.2.3
- wilds~=2.0.0
- gdown~=5.2.0
- ray~=2.39.0
- matplotlib~=3.9.2
- numpy<2
- PyYAML~=6.0.2
- albumentations~=1.4.22
- scikit-learn~=1.5.2
- pillow~=11.0.0
- scipy~=1.14.1
- typing_extensions~=4.12.2
- opencv-python~=4.10.0.84
- medmnistc
- wand > 0.6.10