PyTorch and TensorFlow-Keras Training - 🧠🛠️ Utilizes CIFAR-10 dataset for PyTorch and MNIST for TensorFlow-Keras. Implements early-stopping to prevent overfitting during training. Provides code snippets for early-stopping implementation in both PyTorch and TensorFlow-Keras.

deBUGger404/Earlystoping-torch-keras

Earlystoping Torch/Tensorflow-Keras

Introduction

  • For PyTorch training, I used the CIFAR-10 dataset, which consists of 60,000 32x32 colour images in 10 classes, with 6,000 images per class. There are 50,000 training images and 10,000 test images.
  • For TensorFlow-Keras training, I used the MNIST dataset, which consists of 60,000 28x28 grayscale training images of the 10 digits, along with a test set of 10,000 images.

Prerequisites

  • Python >= 3.6
  • PyTorch >= 1.4
  • TensorFlow >= 2.0
  • The required libraries are listed in requirenments.txt

Training

  • If we train a model for too many epochs, it starts to overfit the training dataset and performs worse on the test dataset. Conversely, too few epochs can leave the model underfitted on the training set.
  • Early stopping addresses this by halting training once the monitored performance has shown no significant improvement for a given number of epochs.
  • Early stopping works by monitoring performance during training. TensorFlow-Keras provides it through its callbacks API, while in PyTorch it is implemented in ./pytorch/utils.py.
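The stopping rule described above can be sketched in plain Python, independent of either framework. This is a minimal illustration rather than the repository's code: the function name `find_stopping_epoch`, the `min_delta` parameter, and the sample losses are assumptions made for the example.

```python
def find_stopping_epoch(val_losses, patience=2, min_delta=0.0):
    """Return (stop_epoch, best_epoch) for a sequence of per-epoch validation losses.

    Training stops once `patience` consecutive epochs fail to improve the
    best validation loss by more than `min_delta`.
    """
    best_loss = float("inf")
    best_epoch = -1
    epochs_without_improvement = 0

    for epoch, loss in enumerate(val_losses):
        if loss < best_loss - min_delta:
            # Significant improvement: remember it and reset the counter.
            best_loss = loss
            best_epoch = epoch
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch, best_epoch

    # The losses never stalled long enough; training ran to the last epoch.
    return len(val_losses) - 1, best_epoch


# Made-up validation losses: the loss bottoms out at epoch 2, then stalls.
losses = [0.90, 0.70, 0.60, 0.62, 0.61, 0.65, 0.70]
stop_epoch, best_epoch = find_stopping_epoch(losses, patience=2)
print(stop_epoch, best_epoch)
```

A larger `patience` tolerates longer plateaus at the cost of extra epochs; the best epoch is tracked separately so that the corresponding weights can be restored after stopping.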

    PyTorch:

    The code below is used for early stopping in PyTorch to limit model overfitting.
    # taken from ./pytorch/utils.py
    # this branch monitors performance during training: when the score has not improved by more than delta, the patience counter is incremented
        elif score < self.best_score + self.delta:
          self.counter += 1
          self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
          if self.counter >= self.patience:
              self.early_stop = True
    # the method below saves a model checkpoint whenever the validation loss decreases
        def save_checkpoint(self, val_loss, model):
          if self.verbose:
              self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
          torch.save(model, self.path)
          self.val_loss_min = val_loss
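To show where the two fragments above might sit, here is a hedged sketch of a complete early-stopping helper. It reuses the attribute names visible in the snippets (`best_score`, `delta`, `counter`, `patience`, `trace_func`, `early_stop`, `verbose`, `val_loss_min`, `path`), but the constructor defaults, the `score = -val_loss` convention, and the `save_fn` parameter are assumptions, not the contents of ./pytorch/utils.py. `save_fn` is a hypothetical hook that stands in for `torch.save` so the sketch runs without PyTorch installed; passing `save_fn=torch.save` would save the model to `path` as in the snippet above.

```python
class EarlyStopping:
    """Stop training when the validation loss stops improving (illustrative sketch)."""

    def __init__(self, patience=7, verbose=False, delta=0.0,
                 path="checkpoint.pt", trace_func=print, save_fn=None):
        self.patience = patience      # epochs to wait after the last improvement
        self.verbose = verbose
        self.delta = delta            # minimum change that counts as an improvement
        self.path = path
        self.trace_func = trace_func
        self.save_fn = save_fn        # hypothetical hook, e.g. torch.save
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = float("inf")

    def __call__(self, val_loss, model):
        score = -val_loss  # assumed convention: higher score is better
        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            # No significant improvement: count the epoch against patience.
            self.counter += 1
            self.trace_func(f"EarlyStopping counter: {self.counter} out of {self.patience}")
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            # Improvement: keep the new best, checkpoint, and reset the counter.
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        if self.verbose:
            self.trace_func(f"Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...")
        if self.save_fn is not None:
            self.save_fn(model, self.path)
        self.val_loss_min = val_loss


# Usage with made-up validation losses; `model=None` stands in for a real model.
saved = []
stopper = EarlyStopping(patience=2, trace_func=lambda msg: None,
                        save_fn=lambda model, path: saved.append(path))
for epoch, val_loss in enumerate([0.9, 0.7, 0.75, 0.8, 0.6]):
    stopper(val_loss, model=None)
    if stopper.early_stop:
        break
```

In a real training loop the call `stopper(val_loss, model)` would follow each validation pass, and the loop would break as soon as `stopper.early_stop` becomes true.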

    TensorFlow-Keras:

    The code below sets up early stopping in Keras: training stops after 2 epochs without improvement in val_loss, and the best weights are restored.
    import tensorflow as tf
    es = tf.keras.callbacks.EarlyStopping( monitor="val_loss", patience=2, verbose=1, restore_best_weights=True)

Python-Script

 # Start training with:
 # pytorch
 python earlystoping_pytorch.py

 # tensorflow-keras
 python keras_early_stoping.py

Give a ⭐ to this Repository!
