This repository contains a class to perform data augmentation on 3D objects (e.g. 3D medical images). It is a 3D (well, 4D with the number of channels included 😅) version of the 2D "tf.keras.preprocessing.image.ImageDataGenerator". We also provide two examples as a simple guideline.
Data augmentation is a regularization technique that has proven extremely useful when training CNNs. It prevents the model from seeing the original training and validation data during training: instead, it applies random transformations to the original training data (or batches) and lets the model see those. Data augmentation is a means to reduce overfitting and build a more robust model.
- "Stefanos Karageorgiou"
- "Ani Ajdini"
While working on my thesis this summer (Stefanos), I realized that TensorFlow resources for training models on 3D images are limited (almost no pretrained models, few regularization techniques, etc.). 3D data, especially medical data, are sometimes hard to find, and datasets are often small. Hence, we believe that data augmentation can have a huge impact on preventing overfitting.
This repository is a quarantine project (yes, we were bored 😌) created to help the few other crazies working on similar projects.
Despite its name, the DataGenerator class in Image3DGenerator.py is actually a "tf.keras.utils.Sequence" object, i.e. a base object for fitting to a sequence of data such as a dataset. Sequences are a safer way to do multiprocessing, as this structure guarantees that the network trains only once on each sample per epoch, which is not the case with generators.
This class applies random transformations to the original training and validation data, and these transformations are drawn anew at every epoch.
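The Sequence idea can be sketched as follows. This is not the repository's DataGenerator: ToySequence3D is a hypothetical, in-memory class that only illustrates the protocol Keras relies on (__len__ for the number of batches, __getitem__ for one batch, on_epoch_end for reshuffling) and why a new random transformation is drawn every time a batch is requested. It assumes NumPy, and subclasses tf.keras.utils.Sequence only when TensorFlow is installed.

```python
import math

import numpy as np

try:
    # In a real project, subclass the Keras base class
    from tensorflow.keras.utils import Sequence
except ImportError:
    # Fallback so this sketch also runs without TensorFlow installed
    Sequence = object


class ToySequence3D(Sequence):
    """Hypothetical minimal Sequence serving in-memory 3D volumes in batches."""

    def __init__(self, volumes, labels, batch_size=2, shuffle=True, seed=0):
        super().__init__()
        self.volumes = volumes          # array of shape (N, L, H, W, C)
        self.labels = labels            # array of shape (N,)
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.rng = np.random.default_rng(seed)
        self.indices = np.arange(len(volumes))
        self.on_epoch_end()

    def __len__(self):
        # Number of batches per epoch (the last batch may be smaller)
        return math.ceil(len(self.indices) / self.batch_size)

    def __getitem__(self, idx):
        # Each batch index maps to a disjoint slice of the (shuffled) order,
        # so every sample is used exactly once per epoch
        start = idx * self.batch_size
        batch_idx = self.indices[start:start + self.batch_size]
        X = self.volumes[batch_idx].astype(np.float32)
        # A fresh random transformation on every call (here: small Gaussian noise)
        X = X + self.rng.normal(0.0, 0.01, size=X.shape).astype(np.float32)
        return X, self.labels[batch_idx]

    def on_epoch_end(self):
        # Keras calls this after each epoch: reshuffle the sample order
        if self.shuffle:
            self.rng.shuffle(self.indices)
```

Because batches are addressed by index rather than pulled from an open-ended generator, several workers can request different indices in parallel without duplicating samples within an epoch.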
The options we currently provide are the following:
- Generation of batches without any transformation
- Rotation: Randomly rotates the whole object by an angle drawn from a normal distribution with zero mean and a standard deviation specified by the user ('rotate_std').
- Gaussian noise: Adds random noise to the 3D objects, drawn from a normal distribution with mean and standard deviation specified by the user ('noise_mean', 'noise_std').
- Normalization: Applies a min-max scaling to the objects, which bounds the voxel values between 0 and 1 (using 'min_bound' and 'max_bound').
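The three transformations above can be sketched with NumPy and SciPy as below. These are illustrative functions, not the code in Image3DGenerator.py: rotating in the (height, width) plane, linear interpolation with mode='nearest' padding, and clipping after min-max scaling are all assumptions of this sketch.

```python
import numpy as np
from scipy import ndimage


def random_rotation(volume, rotate_std, rng):
    """Rotate a (L, H, W, C) volume by an angle ~ N(0, rotate_std**2) degrees."""
    angle = rng.normal(0.0, rotate_std)
    # Assumption: rotate in the (height, width) plane, keeping the original shape
    return ndimage.rotate(volume, angle, axes=(1, 2), reshape=False,
                          order=1, mode='nearest')


def gaussian_noise(volume, noise_mean, noise_std, rng):
    """Add voxel-wise noise drawn from N(noise_mean, noise_std**2)."""
    return volume + rng.normal(noise_mean, noise_std, size=volume.shape)


def min_max_normalise(volume, min_bound, max_bound):
    """Scale voxel values into [0, 1] using fixed bounds, clipping outliers."""
    scaled = (volume - min_bound) / (max_bound - min_bound)
    return np.clip(scaled, 0.0, 1.0)
```

Since the angle and the noise are sampled on every call, applying these functions each time a batch is built means every epoch sees a slightly different version of the same object.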
To use this class, your data and folders should be structured as follows:
- data-folder/ (containing the objects, e.g. data-folder/data.npy)
- Image3DGenerator.py
- your_python_script
Notes:
⚡ The data folder should contain each 3D object separately, each saved as a NumPy array (.npy).
⚡ Each 3D object should have the following dimension order: (object_length, object_height, object_width, number_of_channels). For grayscale objects the channel dimension can be omitted.
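A minimal sketch of preparing such a folder with NumPy: each object is saved as its own .npy file in the (length, height, width, channels) order. The random volumes, the temporary folder standing in for ./data, and the id-<n> naming (file name without .npy used as the object's ID) are hypothetical; check how Image3DGenerator.py builds file paths from IDs before relying on this naming.

```python
import os
import tempfile

import numpy as np

# Hypothetical folder standing in for ./data
data_dir = os.path.join(tempfile.mkdtemp(), 'data')
os.makedirs(data_dir, exist_ok=True)

rng = np.random.default_rng(0)
for i in range(2):
    # One grayscale 3D object: (object_length, object_height, object_width)
    volume = rng.random((16, 32, 32), dtype=np.float32)
    # Optional for grayscale: add an explicit channel axis -> (16, 32, 32, 1)
    volume = volume[..., np.newaxis]
    # One .npy file per object; here the file name doubles as the object's ID
    np.save(os.path.join(data_dir, f'id-{i}.npy'), volume)
```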
To use the class, you need to follow these steps:
- Create a dictionary containing the IDs of the training (and, if applicable, validation) examples.
- Create a dictionary mapping every training (and validation) ID to its class. Classes should be integers starting from 0.
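Both dictionaries can be built in a few lines. In this sketch, make_partition is a hypothetical helper, the id-<n> IDs and their classes are made up, and the random 80/20 train/validation split is only one possible choice; the 'train' and 'validation' keys match the ones passed to DataGenerator in the usage example.

```python
import random


def make_partition(class_of, val_fraction=0.2, seed=0):
    """Hypothetical helper: split object IDs into train/validation at random.

    class_of maps each object ID to an integer class starting at 0.
    """
    ids = sorted(class_of)                  # deterministic starting order
    random.Random(seed).shuffle(ids)
    n_val = round(len(ids) * val_fraction)
    dictionary = {'train': ids[n_val:], 'validation': ids[:n_val]}
    labels = dict(class_of)                 # every ID (train and validation) -> class
    return dictionary, labels


# Hypothetical IDs (e.g. .npy file names without the extension) and classes
class_of = {f'id-{i}': i % 2 for i in range(10)}
dictionary, labels = make_partition(class_of)
```

Keeping a single labels dictionary for all IDs lets the training and validation generators share it, since each generator only looks up the IDs it was given.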
Once these prerequisites are ready, you simply type the following:
```python
from Image3DGenerator import DataGenerator

# Values marked "example" are placeholders: replace them with your own
params = {
    'dim': (64, 128, 128),    # your object's dimensions (example)
    'batch_size': 4,          # chosen batch size (example)
    'n_classes': 2,           # number of your classes (example)
    'n_channels': 1,          # 1 if grayscale, 3 if RGB
    'rotation': True,         # True to apply random rotation during training
    'normalisation': True,
    'min_bound': 0,           # if normalisation is True: minimum voxel value of your objects (example)
    'max_bound': 255,         # if normalisation is True: maximum voxel value of your objects (example)
    'gaussian_noise': True,
    'noise_mean': 0,
    'noise_std': 0.01,
    'shuffle': True,
    'rotate_std': 45,
    'path': './data',         # path of the folder containing the data
    'display_ID': False,
}

# Generators: `dictionary` and `labels` are the two dictionaries described above
training_generator = DataGenerator(dictionary['train'], labels, **params)
validation_generator = DataGenerator(dictionary['validation'], labels, **params)

# After creating and compiling your tf.keras model (`model`) and choosing `no_epochs`
model.fit(x=training_generator,
          epochs=no_epochs,
          validation_data=validation_generator)
```
Two examples, with code and outputs, are available in the examples folder. Below you will find visual examples intended to help you understand how the class transforms the data during training.
1) Grayscale CT scan (original vs. transformed, shown as GIF and 3D views)
Transformations applied: Rotation, Noise, Normalisation
2) RGB GIF (original vs. transformed)
Transformations applied: Rotation, Noise
3) RGB GIF (original vs. transformed)
Transformations applied: Rotation
4) RGB GIF (original vs. transformed)
Transformations applied: Noise
This is a quarantine project developed by two recent data science graduates, so there is undeniably room for improvement. Pull requests are more than welcome, and we would be glad to hear your feedback and have a chat. For major changes, please open an issue first to discuss what you would like to change.