Skip to content

Commit

Permalink
[training](feat) Add IncreaseAutoRegressionLengthOnPlateau callback f…
Browse files Browse the repository at this point in the history
…or AR training
  • Loading branch information
kzajac97 committed Apr 4, 2024
1 parent 0fa6ce5 commit 8ca6d95
Showing 1 changed file with 87 additions and 6 deletions.
93 changes: 87 additions & 6 deletions pydentification/training/lightning/callbacks.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from abc import abstractmethod
from bisect import bisect_right
from collections import Counter
from typing import Any, Sequence
from typing import Any, Literal, Sequence

import lightning.pytorch as pl

Expand Down Expand Up @@ -53,20 +53,20 @@ def _get_closed_form_ar_length(self, epoch: int) -> int:

def on_train_start(self, trainer: pl.Trainer, _: Any) -> None:
if self.verbose:
print(f"StepAutoRegressionLengthScheduler: initial length = {trainer.datamodule.n_forward_time_steps}")
print(f"{self.__class__.__name__}: initial length = {trainer.datamodule.n_forward_time_steps}")

self.base_length = trainer.datamodule.n_forward_time_steps

def on_train_epoch_start(self, trainer: pl.Trainer, _: Any) -> None:
if self.base_length is None:
raise RuntimeError("StepAutoRegressionLengthScheduler: base_length is None")
raise RuntimeError("{self.__class__.__name__}: base_length is None!")

if trainer.current_epoch % self.step_size == 0:
trainer.datamodule.n_forward_time_steps = self._get_closed_form_ar_length(trainer.current_epoch)

if self.verbose:
print(
f"StepAutoRegressionLengthScheduler: new length = {trainer.datamodule.n_forward_time_steps}"
f"{self.__class__.__name__}: new length = {trainer.datamodule.n_forward_time_steps}"
f" at epoch {trainer.current_epoch}"
)

Expand Down Expand Up @@ -100,7 +100,7 @@ def _get_closed_form_ar_length(self, epoch: int) -> int:

def on_train_start(self, trainer: pl.Trainer, _: Any) -> None:
if self.verbose:
print(f"MultiStepAutoRegressionLengthScheduler: initial length = {trainer.datamodule.n_forward_time_steps}")
print(f"{self.__class__.__name__}: initial length = {trainer.datamodule.n_forward_time_steps}")

self.base_length = trainer.datamodule.n_forward_time_steps

Expand All @@ -112,6 +112,87 @@ def on_train_epoch_start(self, trainer: pl.Trainer, _: Any) -> None:

if self.verbose:
print(
f"MultiStepAutoRegressionLengthScheduler: new length = {trainer.datamodule.n_forward_time_steps}"
f"{self.__class__.__name__}: new length = {trainer.datamodule.n_forward_time_steps}"
f" at epoch {trainer.current_epoch} with milestones {list(self.milestones.keys())}"
)


class IncreaseAutoRegressionLengthOnPlateau(AbstractAutoRegressionLengthScheduler):
    """
    Increases the length of auto-regression by factor once the monitored quantity stops improving.
    Works as ReduceLROnPlateau scheduler, but increasing the length (given as int!) instead of decaying learning rate.
    Source reference: https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.ReduceLROnPlateau.html
    """

    def __init__(
        self,
        monitor: str,
        patience: int,
        factor: int,
        threshold: float = 1e-4,
        threshold_mode: Literal["abs", "rel"] = "rel",
        max_length: int | None = None,
        verbose: bool = False,
    ):
        """
        :param monitor: quantity to be monitored given as key from callback_metrics dictionary of pl.Trainer
        :param patience: number of epochs with no improvement after which auto-regression length will be increased
        :param factor: factor by which to increase auto-regression length. new_length = old_length * factor
        :param threshold: threshold for measuring the new optimum, to only focus on significant changes
        :param threshold_mode: one of {"rel", "abs"}, defaults to "rel"
        :param max_length: maximum auto-regression length, defaults to None (no upper bound)
        :param verbose: if True, prints the auto-regression length when it is changed
        """
        super().__init__()

        self.monitor = monitor
        self.patience = patience
        self.factor = factor

        self.threshold = threshold
        self.threshold_mode = threshold_mode
        self.max_length = max_length
        self.verbose = verbose

        # running state: best monitored value seen so far (lower is better) and
        # number of consecutive epochs without significant improvement
        self.best = float("inf")
        self.num_bad_epochs = 0

    def on_train_start(self, trainer: pl.Trainer, _: Any) -> None:
        if self.verbose:
            print(f"{self.__class__.__name__}: initial length = {trainer.datamodule.n_forward_time_steps}")

    def is_better(self, current: float, best: float) -> bool:
        """Return True if `current` improves over `best` by more than the configured threshold."""
        if self.threshold_mode == "rel":
            # relative mode: improvement must exceed a fraction of the best value
            return current < best * (1.0 - self.threshold)
        else:  # self.threshold_mode == "abs"
            # absolute mode: improvement must exceed a fixed margin
            return current < best - self.threshold

    def on_train_epoch_start(self, trainer: pl.Trainer, _: Any) -> None:
        current = trainer.callback_metrics.get(self.monitor)
        if current is None:
            raise RuntimeError(f"{self.__class__.__name__}: metric {self.monitor} not found in callback_metrics!")

        if self.is_better(current, self.best):
            self.best = current
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1

        if self.num_bad_epochs >= self.patience:
            new_length = trainer.datamodule.n_forward_time_steps * self.factor

            # BUG FIX: guard against the default max_length=None, which previously raised
            # TypeError ("'>' not supported between instances of 'int' and 'NoneType'")
            # on every plateau when no upper bound was configured
            if self.max_length is not None and new_length > self.max_length:
                if self.verbose:
                    print(f"{self.__class__.__name__}: maximum length reached, not increasing")
                return  # exit early if new length would exceed the maximum length

            trainer.datamodule.n_forward_time_steps = new_length
            self.num_bad_epochs = 0

            if self.verbose:
                print(
                    f"{self.__class__.__name__}: new length = {trainer.datamodule.n_forward_time_steps}"
                    f" at epoch {trainer.current_epoch}"
                )

0 comments on commit 8ca6d95

Please sign in to comment.