Merge pull request #640 from RasmusOrsoe/save-bestfit-model-from-early-stopping

`StandardModel.fit`: Automatically load in best-fit weights from early stopping when training is finished.
RasmusOrsoe authored Dec 3, 2023
2 parents 800ebd9 + 3c50102 commit f0d242d
Showing 5 changed files with 81 additions and 52 deletions.
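
In user code, this change collapses the manual callback setup into a single argument: `StandardModel.fit` now accepts `early_stopping_patience`, builds the default callbacks (EarlyStopping, ProgressBar and a ModelCheckpoint) itself, and reloads the best-fit weights once training finishes. A minimal before/after sketch of the call site, assuming `model`, `training_dataloader` and `validation_dataloader` are set up as in examples/04_training/01_train_dynedge.py:

# Before: early stopping was wired up manually, and the weights from the
# final epoch (not the best one) were kept after training.
from pytorch_lightning.callbacks import EarlyStopping
from graphnet.training.callbacks import ProgressBar

callbacks = [
    EarlyStopping(monitor="val_loss", patience=5),
    ProgressBar(),
]
model.fit(training_dataloader, validation_dataloader, callbacks=callbacks)

# After: pass the patience directly; `fit` creates the default callbacks,
# including a ModelCheckpoint, and loads the best-fit weights at the end.
model.fit(
    training_dataloader,
    validation_dataloader,
    early_stopping_patience=5,
)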
14 changes: 2 additions & 12 deletions examples/04_training/01_train_dynedge.py
@@ -3,7 +3,6 @@
 import os
 from typing import Any, Dict, List, Optional
 
-from pytorch_lightning.callbacks import EarlyStopping
 from pytorch_lightning.loggers import WandbLogger
 import torch
 from torch.optim.adam import Adam
@@ -14,9 +13,8 @@
 from graphnet.models.detector.prometheus import Prometheus
 from graphnet.models.gnn import DynEdge
 from graphnet.models.graphs import KNNGraph
-from graphnet.models.graphs.nodes import NodesAsPulses
 from graphnet.models.task.reconstruction import EnergyReconstruction
-from graphnet.training.callbacks import ProgressBar, PiecewiseLinearLR
+from graphnet.training.callbacks import PiecewiseLinearLR
 from graphnet.training.loss_functions import LogCoshLoss
 from graphnet.training.utils import make_train_validation_dataloader
 from graphnet.utilities.argparse import ArgumentParser
@@ -130,18 +128,10 @@ def main(
     )
 
     # Training model
-    callbacks = [
-        EarlyStopping(
-            monitor="val_loss",
-            patience=config["early_stopping_patience"],
-        ),
-        ProgressBar(),
-    ]
-
     model.fit(
         training_dataloader,
         validation_dataloader,
-        callbacks=callbacks,
+        early_stopping_patience=config["early_stopping_patience"],
         logger=wandb_logger if wandb else None,
         **config["fit"],
     )
11 changes: 1 addition & 10 deletions examples/04_training/02_train_tito_model.py
@@ -3,7 +3,6 @@
 import os
 from typing import Any, Dict, List, Optional
 
-from pytorch_lightning.callbacks import EarlyStopping
 from pytorch_lightning.loggers import WandbLogger
 from torch.optim.adam import Adam
 from torch.optim.lr_scheduler import ReduceLROnPlateau
@@ -18,7 +17,6 @@
     DirectionReconstructionWithKappa,
 )
 from graphnet.training.labels import Direction
-from graphnet.training.callbacks import ProgressBar
 from graphnet.training.loss_functions import VonMisesFisher3DLoss
 from graphnet.training.utils import make_train_validation_dataloader
 from graphnet.utilities.argparse import ArgumentParser
@@ -133,18 +131,11 @@ def main(
     )
 
     # Training model
-    callbacks = [
-        EarlyStopping(
-            monitor="val_loss",
-            patience=config["early_stopping_patience"],
-        ),
-        ProgressBar(),
-    ]
 
     model.fit(
         training_dataloader,
         validation_dataloader,
-        callbacks=callbacks,
+        early_stopping_patience=config["early_stopping_patience"],
         logger=wandb_logger if wandb else None,
         **config["fit"],
     )
12 changes: 1 addition & 11 deletions examples/04_training/03_train_dynedge_from_config.py
@@ -3,13 +3,11 @@
 from typing import List, Optional
 import os
 
-from pytorch_lightning.callbacks import EarlyStopping
 from pytorch_lightning.loggers import WandbLogger
 from pytorch_lightning.utilities import rank_zero_only
 from graphnet.constants import EXAMPLE_OUTPUT_DIR
 from graphnet.data.dataloader import DataLoader
 from graphnet.models import StandardModel
-from graphnet.training.callbacks import ProgressBar
 from graphnet.utilities.argparse import ArgumentParser
 from graphnet.utilities.config import (
     DatasetConfig,
@@ -86,18 +84,10 @@ def main(
     wandb_logger.experiment.config.update(dataset_config.as_dict())
 
     # Train model
-    callbacks = [
-        EarlyStopping(
-            monitor="val_loss",
-            patience=config.early_stopping_patience,
-        ),
-        ProgressBar(),
-    ]
-
     model.fit(
         dataloaders["train"],
         dataloaders["validation"],
-        callbacks=callbacks,
+        early_stopping_patience=config.early_stopping_patience,
         logger=wandb_logger if wandb else None,
         **config.fit,
     )
12 changes: 1 addition & 11 deletions examples/04_training/04_train_multiclassifier_from_configs.py
@@ -3,7 +3,6 @@
 import os
 from typing import List, Optional, Dict, Any
 
-from pytorch_lightning.callbacks import EarlyStopping
 from pytorch_lightning.loggers import WandbLogger
 from pytorch_lightning.utilities import rank_zero_only
 from graphnet.data.dataset.dataset import EnsembleDataset
@@ -15,7 +14,6 @@
 from graphnet.data.dataloader import DataLoader
 from graphnet.data.dataset import Dataset
 from graphnet.models import StandardModel
-from graphnet.training.callbacks import ProgressBar
 from graphnet.utilities.argparse import ArgumentParser
 from graphnet.utilities.config import (
     DatasetConfig,
@@ -112,18 +110,10 @@ def main(
     wandb_logger.experiment.config.update(dataset_config.as_dict())
 
     # Training model
-    callbacks = [
-        EarlyStopping(
-            monitor="val_loss",
-            patience=config.early_stopping_patience,
-        ),
-        ProgressBar(),
-    ]
-
     model.fit(
         train_dataloaders,
         valid_dataloaders,
-        callbacks=callbacks,
+        early_stopping_patience=config.early_stopping_patience,
         logger=wandb_logger if wandb else None,
         **config.fit,
    )
84 changes: 76 additions & 8 deletions src/graphnet/models/standard_model.py
@@ -5,7 +5,7 @@
 import numpy as np
 import torch
 from pytorch_lightning import Callback, Trainer
-from pytorch_lightning.callbacks import EarlyStopping
+from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
 from torch import Tensor
 from torch.nn import ModuleList
 from torch.optim import Adam
@@ -103,6 +103,7 @@ def fit(
         val_dataloader: Optional[DataLoader] = None,
         *,
         max_epochs: int = 10,
+        early_stopping_patience: int = 5,
         gpus: Optional[Union[List[int], int]] = None,
         callbacks: Optional[List[Callback]] = None,
         ckpt_path: Optional[str] = None,
@@ -115,12 +116,25 @@
         """Fit `StandardModel` using `pytorch_lightning.Trainer`."""
         # Checks
         if callbacks is None:
+            # We create the bare-minimum callbacks for you.
             callbacks = self._create_default_callbacks(
                 val_dataloader=val_dataloader,
+                early_stopping_patience=early_stopping_patience,
             )
-        elif val_dataloader is not None:
-            callbacks = self._add_early_stopping(
-                val_dataloader=val_dataloader, callbacks=callbacks
-            )
+            self.debug("No Callbacks specified. Default callbacks added.")
+        else:
+            # You are on your own!
+            self.debug("Initializing training with user-provided callbacks.")
+            pass
+        self._print_callbacks(callbacks)
+        has_early_stopping = self._contains_callback(callbacks, EarlyStopping)
+        has_model_checkpoint = self._contains_callback(
+            callbacks, ModelCheckpoint
+        )
+
+        if (has_early_stopping) & (has_model_checkpoint is False):
+            self.warning(
+                """No ModelCheckpoint found in callbacks. Best-fit model will not automatically be loaded after training!"""
+            )
 
         self.train(mode=True)
@@ -143,6 +157,33 @@ def fit(
             self.warning("[ctrl+c] Exiting gracefully.")
             pass
 
+        # Load weights from best-fit model after training if possible
+        if has_early_stopping & has_model_checkpoint:
+            for callback in callbacks:
+                if isinstance(callback, ModelCheckpoint):
+                    checkpoint_callback = callback
+            self.load_state_dict(
+                torch.load(checkpoint_callback.best_model_path)["state_dict"]
+            )
+            self.info("Best-fit weights from EarlyStopping loaded.")
+
+    def _print_callbacks(self, callbacks: List[Callback]) -> None:
+        callback_names = []
+        for cbck in callbacks:
+            callback_names.append(cbck.__class__.__name__)
+        self.info(
+            f"Training initiated with callbacks: {', '.join(callback_names)}"
+        )
+
+    def _contains_callback(
+        self, callbacks: List[Callback], callback: Callback
+    ) -> bool:
+        """Check if `callback` is in `callbacks`."""
+        for cbck in callbacks:
+            if isinstance(cbck, callback):
+                return True
+        return False
+
     @property
     def target_labels(self) -> List[str]:
         """Return target label."""
@@ -401,11 +442,38 @@ def predict_as_dataframe(
         )
         return results
 
-    def _create_default_callbacks(self, val_dataloader: DataLoader) -> List:
+    def _create_default_callbacks(
+        self,
+        val_dataloader: DataLoader,
+        early_stopping_patience: Optional[int] = None,
+    ) -> List:
         """Create default callbacks.
 
         Used in cases where no callbacks are specified by the user in .fit
         """
         callbacks = [ProgressBar()]
-        callbacks = self._add_early_stopping(
-            val_dataloader=val_dataloader, callbacks=callbacks
-        )
+        if val_dataloader is not None:
+            assert early_stopping_patience is not None
+            # Add Early Stopping
+            callbacks.append(
+                EarlyStopping(
+                    monitor="val_loss",
+                    patience=early_stopping_patience,
+                )
+            )
+            # Add Model Check Point
+            callbacks.append(
+                ModelCheckpoint(
+                    save_top_k=1,
+                    monitor="val_loss",
+                    mode="min",
+                    filename=f"{self._gnn.__class__.__name__}"
+                    + "-{epoch}-{val_loss:.2f}-{train_loss:.2f}",
+                )
+            )
+            self.info(
+                f"EarlyStopping has been added with a patience of {early_stopping_patience}."
+            )
        return callbacks
 
     def _add_early_stopping(
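
Note that the automatic reload is gated on both callbacks being present: with EarlyStopping but no ModelCheckpoint in the list, `fit` emits the warning above and the final-epoch weights are kept. A minimal sketch of a custom callback list that preserves the reload behaviour, again assuming `model` and the dataloaders from the examples:

from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from graphnet.training.callbacks import ProgressBar

# Both EarlyStopping and ModelCheckpoint must be present; `fit` then reads
# best_model_path from the checkpoint callback and restores its state_dict.
callbacks = [
    ProgressBar(),
    EarlyStopping(monitor="val_loss", patience=5),
    ModelCheckpoint(save_top_k=1, monitor="val_loss", mode="min"),
]
model.fit(
    training_dataloader,
    validation_dataloader,
    callbacks=callbacks,
)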
