diff --git a/examples/04_training/01_train_dynedge.py b/examples/04_training/01_train_dynedge.py
index 297159dfd..80b0a8438 100644
--- a/examples/04_training/01_train_dynedge.py
+++ b/examples/04_training/01_train_dynedge.py
@@ -3,7 +3,6 @@
 import os
 from typing import Any, Dict, List, Optional
 
-from pytorch_lightning.callbacks import EarlyStopping
 from pytorch_lightning.loggers import WandbLogger
 import torch
 from torch.optim.adam import Adam
@@ -14,9 +13,8 @@
 from graphnet.models.detector.prometheus import Prometheus
 from graphnet.models.gnn import DynEdge
 from graphnet.models.graphs import KNNGraph
-from graphnet.models.graphs.nodes import NodesAsPulses
 from graphnet.models.task.reconstruction import EnergyReconstruction
-from graphnet.training.callbacks import ProgressBar, PiecewiseLinearLR
+from graphnet.training.callbacks import PiecewiseLinearLR
 from graphnet.training.loss_functions import LogCoshLoss
 from graphnet.training.utils import make_train_validation_dataloader
 from graphnet.utilities.argparse import ArgumentParser
@@ -130,18 +128,10 @@ def main(
     )
 
     # Training model
-    callbacks = [
-        EarlyStopping(
-            monitor="val_loss",
-            patience=config["early_stopping_patience"],
-        ),
-        ProgressBar(),
-    ]
-
     model.fit(
         training_dataloader,
         validation_dataloader,
-        callbacks=callbacks,
+        early_stopping_patience=config["early_stopping_patience"],
         logger=wandb_logger if wandb else None,
         **config["fit"],
     )
diff --git a/examples/04_training/02_train_tito_model.py b/examples/04_training/02_train_tito_model.py
index 9aab97762..cfe87ffff 100644
--- a/examples/04_training/02_train_tito_model.py
+++ b/examples/04_training/02_train_tito_model.py
@@ -3,7 +3,6 @@
 import os
 from typing import Any, Dict, List, Optional
 
-from pytorch_lightning.callbacks import EarlyStopping
 from pytorch_lightning.loggers import WandbLogger
 from torch.optim.adam import Adam
 from torch.optim.lr_scheduler import ReduceLROnPlateau
@@ -18,7 +17,6 @@
     DirectionReconstructionWithKappa,
 )
 from graphnet.training.labels import Direction
-from graphnet.training.callbacks import ProgressBar
 from graphnet.training.loss_functions import VonMisesFisher3DLoss
 from graphnet.training.utils import make_train_validation_dataloader
 from graphnet.utilities.argparse import ArgumentParser
@@ -133,18 +131,11 @@ def main(
     )
 
     # Training model
-    callbacks = [
-        EarlyStopping(
-            monitor="val_loss",
-            patience=config["early_stopping_patience"],
-        ),
-        ProgressBar(),
-    ]
     model.fit(
         training_dataloader,
         validation_dataloader,
-        callbacks=callbacks,
+        early_stopping_patience=config["early_stopping_patience"],
         logger=wandb_logger if wandb else None,
         **config["fit"],
     )
diff --git a/examples/04_training/03_train_dynedge_from_config.py b/examples/04_training/03_train_dynedge_from_config.py
index fb22daf30..e7603de62 100644
--- a/examples/04_training/03_train_dynedge_from_config.py
+++ b/examples/04_training/03_train_dynedge_from_config.py
@@ -3,13 +3,11 @@
 from typing import List, Optional
 import os
 
-from pytorch_lightning.callbacks import EarlyStopping
 from pytorch_lightning.loggers import WandbLogger
 from pytorch_lightning.utilities import rank_zero_only
 
 from graphnet.constants import EXAMPLE_OUTPUT_DIR
 from graphnet.data.dataloader import DataLoader
 from graphnet.models import StandardModel
-from graphnet.training.callbacks import ProgressBar
 from graphnet.utilities.argparse import ArgumentParser
 from graphnet.utilities.config import (
     DatasetConfig,
@@ -86,18 +84,10 @@ def main(
     wandb_logger.experiment.config.update(dataset_config.as_dict())
 
     # Train model
-    callbacks = [
-        EarlyStopping(
-            monitor="val_loss",
-            patience=config.early_stopping_patience,
-        ),
-        ProgressBar(),
-    ]
-
     model.fit(
         dataloaders["train"],
         dataloaders["validation"],
-        callbacks=callbacks,
+        early_stopping_patience=config.early_stopping_patience,
         logger=wandb_logger if wandb else None,
         **config.fit,
     )
diff --git a/examples/04_training/04_train_multiclassifier_from_configs.py b/examples/04_training/04_train_multiclassifier_from_configs.py
index d9c44e1ff..a7e34e7e4 100644
--- a/examples/04_training/04_train_multiclassifier_from_configs.py
+++ b/examples/04_training/04_train_multiclassifier_from_configs.py
@@ -3,7 +3,6 @@
 import os
 from typing import List, Optional, Dict, Any
 
-from pytorch_lightning.callbacks import EarlyStopping
 from pytorch_lightning.loggers import WandbLogger
 from pytorch_lightning.utilities import rank_zero_only
 from graphnet.data.dataset.dataset import EnsembleDataset
@@ -15,7 +14,6 @@
 from graphnet.data.dataloader import DataLoader
 from graphnet.data.dataset import Dataset
 from graphnet.models import StandardModel
-from graphnet.training.callbacks import ProgressBar
 from graphnet.utilities.argparse import ArgumentParser
 from graphnet.utilities.config import (
     DatasetConfig,
@@ -112,18 +110,10 @@ def main(
     wandb_logger.experiment.config.update(dataset_config.as_dict())
 
     # Training model
-    callbacks = [
-        EarlyStopping(
-            monitor="val_loss",
-            patience=config.early_stopping_patience,
-        ),
-        ProgressBar(),
-    ]
-
     model.fit(
         train_dataloaders,
         valid_dataloaders,
-        callbacks=callbacks,
+        early_stopping_patience=config.early_stopping_patience,
         logger=wandb_logger if wandb else None,
         **config.fit,
     )
diff --git a/src/graphnet/models/standard_model.py b/src/graphnet/models/standard_model.py
index b31cce443..40a9b7ce0 100644
--- a/src/graphnet/models/standard_model.py
+++ b/src/graphnet/models/standard_model.py
@@ -5,7 +5,7 @@
 import numpy as np
 import torch
 from pytorch_lightning import Callback, Trainer
-from pytorch_lightning.callbacks import EarlyStopping
+from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
 from torch import Tensor
 from torch.nn import ModuleList
 from torch.optim import Adam
@@ -103,6 +103,7 @@ def fit(
         val_dataloader: Optional[DataLoader] = None,
         *,
         max_epochs: int = 10,
+        early_stopping_patience: int = 5,
         gpus: Optional[Union[List[int], int]] = None,
         callbacks: Optional[List[Callback]] = None,
         ckpt_path: Optional[str] = None,
@@ -115,12 +116,25 @@
         """Fit `StandardModel` using `pytorch_lightning.Trainer`."""
         # Checks
         if callbacks is None:
+            # We create the bare-minimum callbacks for you.
             callbacks = self._create_default_callbacks(
                 val_dataloader=val_dataloader,
+                early_stopping_patience=early_stopping_patience,
             )
-        elif val_dataloader is not None:
-            callbacks = self._add_early_stopping(
-                val_dataloader=val_dataloader, callbacks=callbacks
+            self.debug("No callbacks specified. Default callbacks added.")
+        else:
+            # You are on your own!
+            self.debug("Initializing training with user-provided callbacks.")
+
+        self._print_callbacks(callbacks)
+        has_early_stopping = self._contains_callback(callbacks, EarlyStopping)
+        has_model_checkpoint = self._contains_callback(
+            callbacks, ModelCheckpoint
+        )
+
+        if has_early_stopping and not has_model_checkpoint:
+            self.warning(
+                "No ModelCheckpoint found in callbacks. Best-fit model "
+                "will not automatically be loaded after training!"
             )
 
         self.train(mode=True)
@@ -143,6 +157,33 @@
             self.warning("[ctrl+c] Exiting gracefully.")
             pass
 
+        # Load weights from best-fit model after training if possible
+        if has_early_stopping and has_model_checkpoint:
+            for callback in callbacks:
+                if isinstance(callback, ModelCheckpoint):
+                    checkpoint_callback = callback
+            self.load_state_dict(
+                torch.load(checkpoint_callback.best_model_path)["state_dict"]
+            )
+            self.info("Best-fit weights from ModelCheckpoint loaded.")
+
+    def _print_callbacks(self, callbacks: List[Callback]) -> None:
+        """Log the names of the callbacks used during training."""
+        callback_names = [cbck.__class__.__name__ for cbck in callbacks]
+        self.info(
+            f"Training initiated with callbacks: {', '.join(callback_names)}"
+        )
+
+    def _contains_callback(
+        self, callbacks: List[Callback], callback: Callback
+    ) -> bool:
+        """Check if an instance of `callback` is in `callbacks`."""
+        for cbck in callbacks:
+            if isinstance(cbck, callback):
+                return True
+        return False
+
     @property
     def target_labels(self) -> List[str]:
         """Return target label."""
@@ -401,11 +442,38 @@ def predict_as_dataframe(
         )
         return results
 
-    def _create_default_callbacks(self, val_dataloader: DataLoader) -> List:
+    def _create_default_callbacks(
+        self,
+        val_dataloader: DataLoader,
+        early_stopping_patience: Optional[int] = None,
+    ) -> List:
+        """Create default callbacks.
+
+        Used in cases where no callbacks are specified by the user in `.fit`.
+        """
         callbacks = [ProgressBar()]
-        callbacks = self._add_early_stopping(
-            val_dataloader=val_dataloader, callbacks=callbacks
-        )
+        if val_dataloader is not None:
+            assert early_stopping_patience is not None
+            # Add EarlyStopping
+            callbacks.append(
+                EarlyStopping(
+                    monitor="val_loss",
+                    patience=early_stopping_patience,
+                )
+            )
+            # Add ModelCheckpoint
+            callbacks.append(
+                ModelCheckpoint(
+                    save_top_k=1,
+                    monitor="val_loss",
+                    mode="min",
+                    filename=f"{self._gnn.__class__.__name__}"
+                    + "-{epoch}-{val_loss:.2f}-{train_loss:.2f}",
+                )
+            )
+            self.info(
+                "EarlyStopping and ModelCheckpoint added with patience "
+                f"{early_stopping_patience}."
+            )
         return callbacks
 
     def _add_early_stopping(
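
For reference, a minimal usage sketch of the `fit` API after this change. The names `model`, `train_dataloader`, and `val_dataloader` are hypothetical stand-ins for a configured `StandardModel` and its dataloaders; only the keyword arguments and callback classes shown are taken from the diff above.

```python
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

# Default behaviour: with no `callbacks` argument, `fit` now builds
# ProgressBar, EarlyStopping (using `early_stopping_patience`) and
# ModelCheckpoint itself, and reloads the best-fit weights afterwards.
model.fit(
    train_dataloader,
    val_dataloader,
    early_stopping_patience=5,
)

# User-provided callbacks: include a ModelCheckpoint alongside
# EarlyStopping, otherwise `fit` warns and the best-fit model is
# not loaded automatically after training.
model.fit(
    train_dataloader,
    val_dataloader,
    callbacks=[
        EarlyStopping(monitor="val_loss", patience=5),
        ModelCheckpoint(save_top_k=1, monitor="val_loss", mode="min"),
    ],
)
```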