From 973349f0afe21a6cf17f02b3bd8821426068b38f Mon Sep 17 00:00:00 2001 From: Harrison Cook Date: Wed, 18 Dec 2024 15:29:45 +0000 Subject: [PATCH 1/6] Rollout Schedulers --- .../training/diagnostics/mlflow/logger.py | 6 +- .../training/schedulers/rollout/__init__.py | 167 ++++++++ .../training/schedulers/rollout/indexed.py | 172 +++++++++ .../training/schedulers/rollout/randomise.py | 364 ++++++++++++++++++ .../training/schedulers/rollout/stepped.py | 155 ++++++++ src/anemoi/training/train/forecaster.py | 27 +- src/anemoi/training/train/train.py | 2 +- tests/schedulers/__init__.py | 8 + tests/schedulers/rollout/__init__.py | 8 + 9 files changed, 891 insertions(+), 18 deletions(-) create mode 100644 src/anemoi/training/schedulers/rollout/__init__.py create mode 100644 src/anemoi/training/schedulers/rollout/indexed.py create mode 100644 src/anemoi/training/schedulers/rollout/randomise.py create mode 100644 src/anemoi/training/schedulers/rollout/stepped.py create mode 100644 tests/schedulers/__init__.py create mode 100644 tests/schedulers/rollout/__init__.py diff --git a/src/anemoi/training/diagnostics/mlflow/logger.py b/src/anemoi/training/diagnostics/mlflow/logger.py index 03a4b6de..0f6deeb8 100644 --- a/src/anemoi/training/diagnostics/mlflow/logger.py +++ b/src/anemoi/training/diagnostics/mlflow/logger.py @@ -196,12 +196,12 @@ def _log_collector(self) -> None: log_capture_time_counter = 0 def _store_buffered_logs(self) -> None: - _buffer_size = self._io_buffer.tell() - if not _buffer_size: + buffer_size = self._io_buffer.tell() + if not buffer_size: return self._io_buffer.seek(0) # read and reset the buffer - data = self._io_buffer.read(_buffer_size) + data = self._io_buffer.read(buffer_size) self._io_buffer.seek(0) # handle the buffered data and store # split lines and keep \n at the end of each line diff --git a/src/anemoi/training/schedulers/rollout/__init__.py b/src/anemoi/training/schedulers/rollout/__init__.py new file mode 100644 index 00000000..9da65b2e --- /dev/null +++ b/src/anemoi/training/schedulers/rollout/__init__.py @@ -0,0 +1,167 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +from __future__ import annotations + +from abc import ABC +from abc import abstractmethod +from typing import Literal + + +class RolloutScheduler(ABC): + """ + `RolloutScheduler` is an abstract base class for rollout schedulers. + + A rollout scheduler is an object that manages the rollout of a training loop. + + ```python + RollSched = RolloutScheduler() + + for epoch in range(20): + for step in range(100): + y = model(x, rollout = RollSched.rollout) + + RollSched.step() + RollSched.step_epoch() + ``` + """ + + _epoch: int = 0 + _step: int = 0 + + @property + @abstractmethod + def rollout(self) -> int: + """Get the current rollout value.""" + error_msg = "`rollout` property not implemented by parent class." + raise NotImplementedError(error_msg) + + @property + @abstractmethod + def maximum_rollout(self) -> int: + """Get maximum rollout possible.""" + error_msg = "`maximum_rollout` property not implemented by parent class." + raise NotImplementedError(error_msg) + + @property + def current_maximum(self) -> int: + """Get the current maximum rollout value.""" + return self.rollout + + def __int__(self) -> int: + return int(self.rollout) + + def rollout_at(self, step: int | None = None, epoch: int | None = None) -> int: + """ + Get the rollout at a specific step and epoch. + + Parameters + ---------- + step : int, optional + Step value to override with, by default None + epoch : int, optional + Epoch value to override with, by default None + + Returns + ------- + int + Rollout value at the specified step and epoch. + """ + step_ = self._step + epoch_ = self._epoch + + self._step = step if step is not None else step_ + self._epoch = epoch if epoch is not None else epoch_ + + rollout = self.rollout + + self._step = step_ + self._epoch = epoch_ + + return rollout + + def step(self, count: int = 1, /) -> None: + """Step the scheduler by a count.""" + self._step += count + + def step_epoch(self, count: int = 1, /) -> None: + """Step the scheduler by a count of epochs.""" + self._epoch += count + + def count(self, every_n: int, step_type: Literal["step", "epoch"]) -> int: + """ + Get the count of steps or epochs. + + Parameters + ---------- + every_n : int + Every n steps or epochs. + step_type : _type_, optional + Which to count, by default Literal['step', 'epoch'] + + Returns + ------- + int + Count of steps or epochs. + + Raises + ------ + ValueError + If the step_type is not 'step' or 'epoch'. + """ + if step_type == "epoch": + return (self._epoch - 1) // every_n + if step_type == "step": + return self._step // every_n + + error_msg = "Invalid `step_type`. Must be 'epoch' or 'step'." + raise ValueError(error_msg) + + @abstractmethod + def description(self) -> str: + """Description of the rollout scheduler.""" + error_msg = "`description` method not implemented by parent class." + raise NotImplementedError(error_msg) + + +class Static(RolloutScheduler): + """`Static` is a rollout scheduler that always returns the same rollout value.""" + + def __init__(self, rollout_value: int): + """ + `Static` is a rollout scheduler that always returns the same rollout value. + + Parameters + ---------- + rollout_value : int + Rollout value to return. + + Example + ------- + ```python + from anemoi.training.schedulers.rollout import Static + RollSched = Static(rollout_value = 5) + RollSched.rollout_at(epoch = 1) + # 5 + RollSched.rollout_at(epoch = 5) + # 5 + ``` + """ + self._rollout_value = rollout_value + + @property + def rollout(self) -> int: + return self._rollout_value + + @property + def maximum_rollout(self) -> int: + return self._rollout_value + + def description(self) -> str: + return f"Static rollout value of {self._rollout_value}." diff --git a/src/anemoi/training/schedulers/rollout/indexed.py b/src/anemoi/training/schedulers/rollout/indexed.py new file mode 100644 index 00000000..307d76f8 --- /dev/null +++ b/src/anemoi/training/schedulers/rollout/indexed.py @@ -0,0 +1,172 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +from typing import Any +from typing import Literal + +from anemoi.training.schedulers.rollout import RolloutScheduler + + +def get_closest_key(dictionary: dict[int, Any], key: int) -> int: + """ + Get the closest int key in a dictionary to a given key. + + Where the closest key is the one with the smallest absolute difference + and the key is less than or equal to the given key. + + Parameters + ---------- + dictionary : dict[int, Any] + Dictionary to search. + key : int + Key to search for. + + Returns + ------- + int + Closest key in the dictionary. + """ + return min(dictionary.keys(), key=lambda x: abs(x - key) if x <= key else float("inf")) + + +class PositionalIndexed(RolloutScheduler): + """ + `PositionalIndexed` retrieves the rollout value from a list of rollouts based on the current epoch or step. + + Once the list is exhausted, the rollout will remain at the last value. + """ + + def __init__( + self, + rollouts: list[int], + num_times_per_element: int = 1, + step_type: Literal["step", "epoch"] = "epoch", + ): + """ + `PositionalIndexed` retrieves the rollout value from a list of rollouts based on the current epoch or step. + + Once the list is exhausted, the rollout will remain at the last value. + + Parameters + ---------- + rollouts : list[int] + List of rollout values. + num_times_per_element: int, optional + Number of times to remain at a element, by default 1 + step_type : Literal['step', 'epoch'], optional + Type of step, either 'epoch' or 'step'. + by default 'epoch'. + + Example + ------- + ```python + from anemoi.training.schedulers.rollout.indexed import PositionalIndexed + + RollSched = PositionalIndexed(rollouts = [1, 2, 3, 4], num_times_per_element = 2, step_type = 'epoch') + RollSched.at_epoch(1) + # 1 + RollSched.at_epoch(2) + # 1 + RollSched.at_epoch(3) + # 2 + ``` + """ + super().__init__() + self._rollouts = rollouts + self._num_times_per_element = num_times_per_element + self._step_type = step_type + + @property + def rollout(self) -> int: + count = self.count(self._num_times_per_element, self._step_type) + return self._rollouts[min(len(self._rollouts), count)] + + @property + def maximum_rollout(self) -> int: + return max(self._rollouts) + + +class EpochPositionalIndexed(PositionalIndexed): + """Epoch based PositionalIndexed.""" + + def __init__(self, rollouts: list[int]): + super().__init__(rollouts, step_type="epoch") + + +class StepPositionalIndexed(PositionalIndexed): + """Step based PositionalIndexed.""" + + def __init__(self, rollouts: list[int]): + super().__init__(rollouts, step_type="step") + + +class Lookup(RolloutScheduler): + """ + `Lookup` retrieves the rollout value from a dictionary of rollouts based on the current epoch or step. + + It will return the closest key that is less than or equal to the current epoch or step. + """ + + def __init__(self, rollouts: dict[int, int], step_type: Literal["step", "epoch"] = "epoch"): + """ + `Lookup` retrieves the rollout value from a dictionary of rollouts based on the current epoch or step. + + It will return the closest key that is less than or equal to the current epoch or step. + + Parameters + ---------- + rollouts : dict[int, int] + Dictionary of rollouts. + step_type : Literal['step', 'epoch'], optional + Type of step, either 'epoch' or 'step'. + by default 'epoch' + + Example + ------- + ```python + from anemoi.training.schedulers.rollout.indexed import Lookup + + RollSched = Lookup(rollouts = {0: 1, 5: 2, 10: 3}, step_type = 'epoch') + RollSched.at_epoch(1) + # 1 + RollSched.at_epoch(5) + # 2 + ``` + """ + super().__init__() + self._rollouts = rollouts + self._step_type = step_type + + @property + def rollout(self) -> int: + if self._step_type == "epoch": + return self._rollouts.get(get_closest_key(self._rollouts, self._epoch), 1) + if self._step_type == "step": + return self._rollouts.get(get_closest_key(self._rollouts, self._step), 1) + + error_msg = "Invalid step_type. Must be 'epoch' or 'step'." + raise ValueError(error_msg) + + @property + def maximum_rollout(self) -> int: + return max(self._rollouts.values()) + + +class EpochLookup(Lookup): + """Epoch based Lookup.""" + + def __init__(self, rollouts: dict[int, int]): + super().__init__(rollouts, step_type="epoch") + + +class StepLookup(Lookup): + """Step based Lookup.""" + + def __init__(self, rollouts: dict[int, int]): + super().__init__(rollouts, step_type="step") diff --git a/src/anemoi/training/schedulers/rollout/randomise.py b/src/anemoi/training/schedulers/rollout/randomise.py new file mode 100644 index 00000000..ec0eef71 --- /dev/null +++ b/src/anemoi/training/schedulers/rollout/randomise.py @@ -0,0 +1,364 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +# ruff: noqa: S608 + +from __future__ import annotations + +from typing import Literal + +import numpy as np +import pytorch_lightning as pl + +from anemoi.training.schedulers.rollout import RolloutScheduler +from anemoi.training.schedulers.rollout.indexed import get_closest_key +from anemoi.training.utils.seeding import get_base_seed + + +class BaseRandom(RolloutScheduler): + """BaseRandom Scheduler.""" + + def __init__(self): + """ + Initialise the base random rollout scheduler. + + Set the seed with the environment variable `ANEMOI_BASE_SEED` if it exists, + """ + super().__init__() + + try: + seed = get_base_seed() + except AssertionError: + seed = 42 + + rnd_seed = pl.seed_everything(seed, workers=True) + self.rng = np.random.default_rng(rnd_seed) + + def broadcast(self, value: int) -> None: + """ + Broadcast the rollout value to all processes. + + Parameters + ---------- + value : int + Value to broadcast. + """ + # TODO(Harrison Cook): Need to broadcast the rollout to all processes + + def _randomly_pick(self, rollouts: list[int]) -> int: + """ + Randomly pick from a list of rollouts. + + Parameters + ---------- + rollouts : list[int] + s to choose from. + + Returns + ------- + int + Randomly selected rollout. + """ + rollout = self.rng.choice(rollouts) + self.broadcast(rollout) + return rollout + + +class RandomList(BaseRandom): + """`RandomList` is a rollout scheduler that randomly selects a rollout from a list of values.""" + + def __init__(self, rollouts: list[int]): + """ + RandomList is a rollout scheduler that randomly selects a rollout from a list of values. + + Parameters + ---------- + rollouts : list[int] + List of rollouts to choose from. + + Example + ------- + ```python + from anemoi.training.schedulers.rollout import RandomList + + RollSched = RandomList(rollouts = [1, 2, 3, 4, 5]) + RollSched.rollout_at(epoch = 1) + # any value in the list + RollSched.rollout_at(epoch = 2) + # any value in the list + ``` + """ + super().__init__() + self._rollouts = rollouts + + @property + def rollout(self) -> int: + return self._randomly_pick(self._rollouts) + + @property + def maximum_rollout(self) -> int: + return max(self._rollouts) + + @property + def current_maximum(self) -> int: + return self.maximum_rollout + + def description(self) -> str: + return f"Randomly select a rollout from {self._rollouts}" + + +class RandomRange(RandomList): + """`RandomRange` is a rollout scheduler that randomly selects a rollout from a range of values.""" + + def __init__(self, minimum: int = 1, maximum: int = 1, step: int = 1): + """ + RandomRange is a rollout scheduler that randomly selects a rollout from a range of values. + + Parameters + ---------- + minimum : int, optional + Minimum rollout to choose from, by default 1 + maximum : int, optional + Maximum rollout to choose from, by default 1 + step : int, optional + Step size for the range, by default 1 + + Example + ------- + ```python + from anemoi.training.schedulers.rollout import RandomRange + + RollSched = RandomRange(minimum = 1, maximum = 5) + RollSched.rollout_at(epoch = 1) + # any value between 1 and 5 + RollSched.rollout_at(epoch = 2) + # any value between 1 and 5 + ``` + """ + super().__init__(list(range(minimum, maximum + 1, step))) + + def description(self) -> str: + return ( + "Randomly select a rollout from the " + f"{range(min(self._rollouts), max(self._rollouts) + 1, np.diff(self._rollouts)[0])}" + ) + + +class IncreasingRandom(BaseRandom): + """IncreasingRandom is a rollout scheduler that randomly selects a rollout from an increasing range of values.""" + + def __init__( + self, + minimum: int = 1, + maximum: int = 1, + range_step: int = 1, + every_n: int = 1, + increment: int | dict[int, int] = 1, + step_type: Literal["step", "epoch"] = "epoch", + ): + """ + `IncreasingRandom` is a rollout scheduler that randomly selects a rollout from an increasing range of values. + + Parameters + ---------- + minimum : int, optional + Minimum rollout to choose from, by default 1 + maximum : int, optional + Maximum rollout to choose from, can be -1 for no maximum, + by default 1. + range_step : int, optional + Step size for the range, by default 1 + every_n : int, optional + Number of steps or epochs to step the rollout value. + If `every_n` is 0, the rollout will stay at `minimum`. + increment : int | dict[int, int], optional + Value to increment the rollout by `every_n_epochs`, by default 1 + step_type : Literal['step', 'epoch'], optional + Type of step, either 'epoch' or 'batch'. + by default 'epoch'. + + Example + ------- + ```python + from anemoi.training.schedulers.rollout import IncreasingRandom + + RollSched = IncreasingRandom(minimum = 1, maximum = 10, step = 1, every_n_epochs = 1) + RollSched.rollout_at(epoch = 1) + # any value between 1 and 1 + RollSched.rollout_at(epoch = 2) + # any value between 1 and 2 + ``` + """ + super().__init__() + + if maximum <= -1: + maximum = float("inf") + + self._minimum = minimum + self._maximum = maximum + self._range_step = range_step + self._every_n = every_n + self._increment = increment + self._step_type = step_type + + @property + def rollout(self) -> int: + if self._every_n == 0: + return self._minimum + + count_of_n = self.count(self._every_n, self._step_type) + + if isinstance(self._increment, int): + maximum_value = self._minimum + self._increment * count_of_n + else: + sum_of_increments = [ + self._increment.get(get_closest_key(self._increment, i + 1)) for i in range(count_of_n) + ] + maximum_value = self._minimum + sum(sum_of_increments) + + rollouts = range(self._minimum, maximum_value + 1, self._range_step) + + return self._randomly_pick(rollouts) + + @property + def maximum_rollout(self) -> int: + return self._maximum + + @property + def current_maximum(self) -> int: + return self._minimum + ((self._epoch // self._every_n_epochs) * self._step) + + def description(self) -> str: + return ( + f"Randomly select a rollout from the increasing range " + f"{range(self._minimum, self._maximum, self._step)}" + f"with the upper bound increasing by {self._step} every {self._every_n} {self._step_type}" + ) + + +class EpochIncreasingRandom(IncreasingRandom): + """ + `EpochIncreasingRandom` is a rollout scheduler that randomly selects a rollout from an increasing range of values. + + The maximum is incremented every n epochs. + """ + + def __init__( + self, + minimum: int = 1, + maximum: int = 1, + range_step: int = 1, + every_n_epochs: int = 1, + increment: int | dict[int, int] = 1, + ): + """ + EpochIncreasingRandom is a rollout scheduler that randomly selects a rollout from an increasing range of values. + + The maximum is incremented every n epochs. + + Parameters + ---------- + minimum : int, optional + Minimum rollout to choose from, by default 1 + maximum : int, optional + Maximum rollout to choose from, can be -1 for no maximum, + by default 1. + range_step : int, optional + Step size for the range, by default 1 + every_n_epochs : int, optional + Number of epochs to step the rollout value. + If `every_n_epochs` is 0, the rollout will stay at `minimum`. + increment : int | dict[int, int], optional + Value to increment the rollout by `every_n_epochs`, by default 1 + + Example + ------- + ```python + from anemoi.training.schedulers.rollout import EpochIncreasingRandom + + RollSched = EpochIncreasingRandom(minimum = 1, maximum = 10, range_step = 1, every_n_epochs = 1, increment = 1) + RollSched.rollout_at(epoch = 1) + # any value between 1 and 1 + RollSched.rollout_at(epoch = 2) + # any value between 1 and 2 + + RollSched = EpochIncreasingRandom( + minimum = 1, maximum = 10, range_step = 1, + every_n_epochs = 1, increment = {0: 0, 10: 1} + ) + RollSched.rollout_at(epoch = 1) + # any value between 1 and 1 + RollSched.rollout_at(epoch = 9) + # any value between 1 and 1 + RollSched.rollout_at(epoch = 10) + # any value between 1 and 2, and then increments of 1 + ``` + """ + super().__init__(minimum, maximum, range_step, every_n_epochs, increment, step_type="epoch") + + +class StepIncreasingRandom(IncreasingRandom): + """ + `StepIncreasingRandom` is a rollout scheduler that randomly selects a rollout from an increasing range of values. + + The maximum is incremented every n steps. + """ + + def __init__( + self, + minimum: int = 1, + maximum: int = 1, + range_step: int = 1, + every_n_steps: int = 1, + increment: int | dict[int, int] = 1, + ): + """ + StepIncreasingRandom` is a rollout scheduler that randomly selects a rollout from an increasing range of values. + + The maximum is incremented every n steps. + + Parameters + ---------- + minimum : int, optional + Minimum rollout to choose from, by default 1 + maximum : int, optional + Maximum rollout to choose from, can be -1 for no maximum, + by default 1. + range_step : int, optional + Step size for the range, by default 1 + every_n_steps : int, optional + Number of steps to step the rollout value. + If `every_n_steps` is 0, the rollout will stay at `minimum`. + increment : int | dict[int, int], optional + Value to increment the rollout by `every_n_epochs`, by default 1 + + Example + ------- + ```python + from anemoi.training.schedulers.rollout import StepIncreasingRandom + + RollSched = StepIncreasingRandom(minimum = 1, maximum = 10, range_step = 1, every_n_steps = 1, increment = 1) + RollSched.rollout_at(step = 1) + # any value between 1 and 1 + RollSched.rollout_at(step = 2) + # any value between 1 and 2 + + RollSched = StepIncreasingRandom( + minimum = 1, maximum = 10, range_step = 1, + every_n_steps = 1, increment = {0: 0, 10: 1} + ) + RollSched.rollout_at(step = 1) + # any value between 1 and 1 + RollSched.rollout_at(step = 9) + # any value between 1 and 1 + RollSched.rollout_at(step = 10) + # any value between 1 and 2, and then increments of 1 + ``` + """ + super().__init__(minimum, maximum, range_step, every_n_steps, increment, step_type="step") diff --git a/src/anemoi/training/schedulers/rollout/stepped.py b/src/anemoi/training/schedulers/rollout/stepped.py new file mode 100644 index 00000000..f07425e1 --- /dev/null +++ b/src/anemoi/training/schedulers/rollout/stepped.py @@ -0,0 +1,155 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +from __future__ import annotations + +from typing import Literal + +from anemoi.training.schedulers.rollout import RolloutScheduler +from anemoi.training.schedulers.rollout.indexed import get_closest_key + + +class Stepped(RolloutScheduler): + """`Stepped` is a base rollout scheduler that steps the rollout value at the end of each n steps or epochs.""" + + def __init__( + self, + minimum: int, + maximum: int, + every_n: int, + increment: int | dict[int, int], + step_type: Literal["step", "epoch"] = "epoch", + ): + """ + `SteppedRollout` is a base rollout scheduler that steps the rollout value at the end of each n steps or epochs. + + Parameters + ---------- + minimum : int + Minimum rollout value. + maximum : int + Maximum rollout value. + Can be -1 for no maximum. + every_n : int + Number of steps or epochs to step the rollout value. + If `every_n` is 0, the rollout will stay at `minimum`. + increment : int | dict[int, int], optional + Value to increment the rollout by. + Can be an int or dictionary, where the keys represent the value of `step_type` + and the values represent the increment. + Will round down to the closest key. + i.e. {0: 1, 10: 2} will increment by 1 until 10, then by 2. + by default 1. + step_type : Literal['step', 'epoch'], optional + Type of step, either 'epoch' or 'step'. + by default 'epoch'. + + Example + ------- + ```python + from anemoi.training.schedulers.rollout.stepped import Stepped + + RollSched = Stepped(minimum = 1, maximum = 10, every_n = 5, increment = 1) + RollSched.rollout_at(epoch = 2) + # 1 + RollSched.rollout_at(epoch = 5) + # 2 + + RollSched = Stepped(minimum = 1, maximum = 10, every_n = 5, increment = 2) + RollSched.rollout_at(epoch = 2) + # 1 + RollSched.rollout_at(epoch = 5) + # 3 + + RollSched = Stepped(minimum = 1, maximum = 10, every_n = 1, increment = {0: 0, 10: 1}) + RollSched.rollout_at(epoch = 2) + # 1 + RollSched.rollout_at(epoch = 9) + # 1 + RollSched.rollout_at(epoch = 10) + # 2, and then increments of 1 + ``` + """ + super().__init__() + + if maximum <= -1: + maximum = float("inf") + + self._minimum = minimum + self._maximum = maximum + self._every_n = every_n + self._increment = increment + self._step_type = step_type + + @property + def rollout(self) -> int: + if self._every_n == 0: + return self._minimum + + count_of_n = self.count(self._every_n, self._step_type) + + if isinstance(self._increment, int): + return min(self._maximum, self._minimum + self._increment * count_of_n) + + sum_of_increments = [ + self._increment.get(get_closest_key(self._increment, i + 1 if self._step_type == "epoch" else i)) + for i in range(count_of_n) + ] + return min(self._maximum, self._minimum + sum(sum_of_increments)) + + @property + def maximum_rollout(self) -> int: + return self._maximum + + def description(self) -> str: + return ( + "Stepped rollout scheduler stepping between" + f"{self._minimum} and {self._maximum} by {self._increment} for {self._every_n} {self._step_type}s." + ) + + +class EpochStepped(Stepped): + """`EpochStepped` is a rollout scheduler that steps the rollout value at the end of each n epochs.""" + + def __init__(self, minimum: int, maximum: int, every_n_epochs: int = 1, increment: int = 1): + """ + `EpochStepped` is a rollout scheduler that steps the rollout value at the end of each n epochs. + + Parameters + ---------- + minimum : int + The minimum value for the scheduler. + maximum : int + The maximum value for the scheduler. + every_n_epochs : int, optional + The number of epochs after which the value is incremented, by default 1. + increment : int, optional + The amount by which the value is incremented, by default 1. + """ + super().__init__(minimum, maximum, every_n_epochs, increment, step_type="epoch") + + +class StepStepped(Stepped): + """`StepStepped` is a rollout scheduler that steps the rollout value at the end of each n steps.""" + + def __init__(self, minimum: int, maximum: int, every_n_steps: int = 1000, increment: int = 1): + """ + `StepStepped` is a rollout scheduler that steps the rollout value at the end of each n steps. + + Parameters + ---------- + minimum : int + The minimum value for the scheduler. + maximum : int + The maximum value for the scheduler. + every_n_steps : int, optional + The number of steps after which the value is incremented, by default 1000. + increment : int, optional + The amount by which the value is incremented, by default 1. + """ + super().__init__(minimum, maximum, every_n_steps, increment, step_type="step") diff --git a/src/anemoi/training/train/forecaster.py b/src/anemoi/training/train/forecaster.py index 0059d90a..4351fb35 100644 --- a/src/anemoi/training/train/forecaster.py +++ b/src/anemoi/training/train/forecaster.py @@ -12,6 +12,7 @@ from collections import defaultdict from collections.abc import Generator from collections.abc import Mapping +from typing import TYPE_CHECKING from typing import Optional from typing import Union @@ -38,6 +39,9 @@ LOGGER = logging.getLogger(__name__) +if TYPE_CHECKING: + from anemoi.training.training.schedulers.rollout import RolloutScheduler + class GraphForecaster(pl.LightningModule): """Graph neural network forecaster for PyTorch Lightning.""" @@ -146,18 +150,15 @@ def __init__( self.warmup_t = getattr(config.training.lr, "warmup_t", 1000) self.lr_iterations = config.training.lr.iterations self.lr_min = config.training.lr.min - self.rollout = config.training.rollout.start - self.rollout_epoch_increment = config.training.rollout.epoch_increment - self.rollout_max = config.training.rollout.max + + self.rollout: RolloutScheduler = instantiate(config.training.rollout) self.use_zero_optimizer = config.training.zero_optimizer self.model_comm_group = None self.reader_groups = None - LOGGER.debug("Rollout window length: %d", self.rollout) - LOGGER.debug("Rollout increase every : %d epochs", self.rollout_epoch_increment) - LOGGER.debug("Rollout max : %d", self.rollout_max) + LOGGER.debug("Rollout config: %d", self.rollout.description()) LOGGER.debug("Multistep: %d", self.multi_step) # lazy init model and reader group info, will be set by the DDPGroupStrategy: @@ -451,7 +452,7 @@ def rollout_step( ) assert batch.shape[1] >= rollout + self.multi_step, msg - for rollout_step in range(rollout or self.rollout): + for rollout_step in range(rollout or int(self.rollout)): # prediction at rollout step rollout_step, shape = (bs, latlon, nvar) y_pred = self(x) @@ -485,7 +486,7 @@ def _step( for loss_next, metrics_next, y_preds_next in self.rollout_step( batch, - rollout=self.rollout, + rollout=int(self.rollout), training_mode=True, validation_mode=validation_mode, ): @@ -493,7 +494,8 @@ def _step( metrics.update(metrics_next) y_preds.extend(y_preds_next) - loss *= 1.0 / self.rollout + loss *= 1.0 / int(self.rollout) + self.rollout.step() return loss, metrics, y_preds def allgather_batch(self, batch: torch.Tensor) -> torch.Tensor: @@ -619,7 +621,7 @@ def training_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor: ) self.log( "rollout", - float(self.rollout), + int(self.rollout), on_step=True, logger=self.logger_enabled, rank_zero_only=True, @@ -642,10 +644,7 @@ def lr_scheduler_step(self, scheduler: CosineLRScheduler, metric: None = None) - scheduler.step(epoch=self.trainer.global_step) def on_train_epoch_end(self) -> None: - if self.rollout_epoch_increment > 0 and self.current_epoch % self.rollout_epoch_increment == 0: - self.rollout += 1 - LOGGER.debug("Rollout window length: %d", self.rollout) - self.rollout = min(self.rollout, self.rollout_max) + self.rollout.step_epoch() def validation_step(self, batch: torch.Tensor, batch_idx: int) -> None: """ diff --git a/src/anemoi/training/train/train.py b/src/anemoi/training/train/train.py index 694fb2da..c638bc11 100644 --- a/src/anemoi/training/train/train.py +++ b/src/anemoi/training/train/train.py @@ -328,7 +328,7 @@ def _log_information(self) -> None: "Effective learning rate: %.3e", int(total_number_of_model_instances) * self.config.training.lr.rate, ) - LOGGER.debug("Rollout window length: %d", self.config.training.rollout.start) + LOGGER.debug("Rollout config: %d", self.config.training.rollout) if self.config.training.max_epochs is not None and self.config.training.max_steps not in (None, -1): LOGGER.info( diff --git a/tests/schedulers/__init__.py b/tests/schedulers/__init__.py new file mode 100644 index 00000000..c167afa2 --- /dev/null +++ b/tests/schedulers/__init__.py @@ -0,0 +1,8 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. diff --git a/tests/schedulers/rollout/__init__.py b/tests/schedulers/rollout/__init__.py new file mode 100644 index 00000000..c167afa2 --- /dev/null +++ b/tests/schedulers/rollout/__init__.py @@ -0,0 +1,8 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. From fcf1f1fe7347d450d21994cc9fc8896a4e11f0a7 Mon Sep 17 00:00:00 2001 From: Harrison Cook Date: Wed, 18 Dec 2024 16:53:48 +0000 Subject: [PATCH 2/6] Incrementer - Allow for complex incrementing setup --- .../training/config/training/default.yaml | 13 +- .../training/schedulers/rollout/__init__.py | 26 ++-- .../training/schedulers/rollout/indexed.py | 8 +- .../training/schedulers/rollout/randomise.py | 36 +++--- .../training/schedulers/rollout/stepped.py | 117 +++++++++++++++--- 5 files changed, 150 insertions(+), 50 deletions(-) diff --git a/src/anemoi/training/config/training/default.yaml b/src/anemoi/training/config/training/default.yaml index 6c915eb5..27ff7e22 100644 --- a/src/anemoi/training/config/training/default.yaml +++ b/src/anemoi/training/config/training/default.yaml @@ -86,11 +86,16 @@ scale_validation_metrics: # length of the "rollout" window (see Keisler's paper) rollout: - start: 1 + _target_: anemoi.training.schedulers.stepped.EpochStepped + minimum: 1 + maximum: 12 # increase rollout every n epochs - epoch_increment: 0 - # maximum rollout to use - max: 1 + every_n_epochs: 1 + # increment + increment: + step: + 0: 0 + 200000: 1 # Set max_epochs or max_steps. Training stops at the first limit reached. max_epochs: null diff --git a/src/anemoi/training/schedulers/rollout/__init__.py b/src/anemoi/training/schedulers/rollout/__init__.py index 9da65b2e..e4e9ec06 100644 --- a/src/anemoi/training/schedulers/rollout/__init__.py +++ b/src/anemoi/training/schedulers/rollout/__init__.py @@ -94,16 +94,16 @@ def step_epoch(self, count: int = 1, /) -> None: """Step the scheduler by a count of epochs.""" self._epoch += count - def count(self, every_n: int, step_type: Literal["step", "epoch"]) -> int: + def count(self, n_epochs: int | None = None, n_steps: int | None = None) -> int: """ Get the count of steps or epochs. Parameters ---------- - every_n : int - Every n steps or epochs. - step_type : _type_, optional - Which to count, by default Literal['step', 'epoch'] + n_epochs : int | None, optional + Number of epochs to count, by default None + n_steps : int | None, optional + Number of steps to count, by default None Returns ------- @@ -113,15 +113,17 @@ def count(self, every_n: int, step_type: Literal["step", "epoch"]) -> int: Raises ------ ValueError - If the step_type is not 'step' or 'epoch'. + If both `n_epochs` and `n_steps` are given, or if neither are given. """ - if step_type == "epoch": - return (self._epoch - 1) // every_n - if step_type == "step": - return self._step // every_n + if n_epochs is not None and n_steps is not None or n_epochs is None and n_steps is None: + error_msg = "Only one of `n_epochs` or `n_steps` can be given." + raise ValueError(error_msg) + + if n_epochs is not None: + return self._epoch // n_epochs + if n_steps is not None: + return self._step // n_steps - error_msg = "Invalid `step_type`. Must be 'epoch' or 'step'." - raise ValueError(error_msg) @abstractmethod def description(self) -> str: diff --git a/src/anemoi/training/schedulers/rollout/indexed.py b/src/anemoi/training/schedulers/rollout/indexed.py index 307d76f8..782aa27a 100644 --- a/src/anemoi/training/schedulers/rollout/indexed.py +++ b/src/anemoi/training/schedulers/rollout/indexed.py @@ -84,7 +84,13 @@ def __init__( @property def rollout(self) -> int: - count = self.count(self._num_times_per_element, self._step_type) + if self._step_type == "epoch": + count = self.count(n_epochs=self._num_times_per_element) + elif self._step_type == "step": + count = self.count(n_steps=self._num_times_per_element) + else: + error_msg = "Invalid step_type. Must be 'epoch' or 'step'." + raise ValueError(error_msg) return self._rollouts[min(len(self._rollouts), count)] @property diff --git a/src/anemoi/training/schedulers/rollout/randomise.py b/src/anemoi/training/schedulers/rollout/randomise.py index ec0eef71..efa0efaa 100644 --- a/src/anemoi/training/schedulers/rollout/randomise.py +++ b/src/anemoi/training/schedulers/rollout/randomise.py @@ -16,9 +16,9 @@ import numpy as np import pytorch_lightning as pl -from anemoi.training.schedulers.rollout import RolloutScheduler -from anemoi.training.schedulers.rollout.indexed import get_closest_key from anemoi.training.utils.seeding import get_base_seed +from anemoi.training.schedulers.rollout import RolloutScheduler +from anemoi.training.schedulers.rollout.stepped import BaseIncrementingRolloutScheduler, VALID_INCREMENT_TYPE, VALID_STEP_TYPES class BaseRandom(RolloutScheduler): @@ -150,7 +150,7 @@ def description(self) -> str: ) -class IncreasingRandom(BaseRandom): +class IncreasingRandom(BaseIncrementingRolloutScheduler, BaseRandom): """IncreasingRandom is a rollout scheduler that randomly selects a rollout from an increasing range of values.""" def __init__( @@ -159,8 +159,9 @@ def __init__( maximum: int = 1, range_step: int = 1, every_n: int = 1, - increment: int | dict[int, int] = 1, - step_type: Literal["step", "epoch"] = "epoch", + increment: VALID_INCREMENT_TYPE = 1, + *, + step_type: VALID_STEP_TYPES = "epoch", ): """ `IncreasingRandom` is a rollout scheduler that randomly selects a rollout from an increasing range of values. @@ -177,7 +178,7 @@ def __init__( every_n : int, optional Number of steps or epochs to step the rollout value. If `every_n` is 0, the rollout will stay at `minimum`. - increment : int | dict[int, int], optional + increment : int | dict[int, int] | dict[Literal['step', 'epoch'], dict[int, int]], optional Value to increment the rollout by `every_n_epochs`, by default 1 step_type : Literal['step', 'epoch'], optional Type of step, either 'epoch' or 'batch'. @@ -195,7 +196,7 @@ def __init__( # any value between 1 and 2 ``` """ - super().__init__() + super().__init__(every_n = every_n, increment = increment, step_type = step_type) if maximum <= -1: maximum = float("inf") @@ -203,26 +204,23 @@ def __init__( self._minimum = minimum self._maximum = maximum self._range_step = range_step - self._every_n = every_n - self._increment = increment - self._step_type = step_type @property def rollout(self) -> int: if self._every_n == 0: return self._minimum - count_of_n = self.count(self._every_n, self._step_type) + # count_of_n = self.count(self._every_n, self._step_type) - if isinstance(self._increment, int): - maximum_value = self._minimum + self._increment * count_of_n - else: - sum_of_increments = [ - self._increment.get(get_closest_key(self._increment, i + 1)) for i in range(count_of_n) - ] - maximum_value = self._minimum + sum(sum_of_increments) + # if isinstance(self._increment, int): + # maximum_value = self._minimum + self._increment * count_of_n + # else: + # sum_of_increments = [ + # self._increment.get(get_closest_key(self._increment, i + 1)) for i in range(count_of_n) + # ] + # maximum_value = self._minimum + sum(sum_of_increments) - rollouts = range(self._minimum, maximum_value + 1, self._range_step) + rollouts = range(self._minimum, self._minimum + self.total_increment, self._range_step) return self._randomly_pick(rollouts) diff --git a/src/anemoi/training/schedulers/rollout/stepped.py b/src/anemoi/training/schedulers/rollout/stepped.py index f07425e1..3a417395 100644 --- a/src/anemoi/training/schedulers/rollout/stepped.py +++ b/src/anemoi/training/schedulers/rollout/stepped.py @@ -14,7 +14,79 @@ from anemoi.training.schedulers.rollout.indexed import get_closest_key -class Stepped(RolloutScheduler): +VALID_STEP_TYPE = ["step", "epoch"] +VALID_STEP_TYPES = Literal["step", "epoch"] + +VALID_INCREMENT_TYPE = int | dict[int, int] | dict[VALID_STEP_TYPES, dict[int, int]] + +class BaseIncrementingRolloutScheduler(RolloutScheduler): + """Base class for schedulers that have an incrementing value.""" + _increment_value = 0 + + def __init__(self, every_n: int, step_type: VALID_STEP_TYPES, increment: VALID_INCREMENT_TYPE = 1): + super().__init__() + + if step_type not in VALID_STEP_TYPE: + error_msg = "Step type must be either 'step' or 'epoch'." + raise ValueError(error_msg) + + if isinstance(increment, dict): + if not len(increment) == 1: + error_msg = ( + "Increment dictionary cannot be empty, nor can it contain more then one entry." + "\nIt should either be a dictionary of ints or contain a single key of 'step' or 'epoch'." + ) + raise ValueError(error_msg) + + self._every_n = every_n + self._step_type = step_type + self._increment = increment + + + @property + def total_increment(self) -> int: + return self._increment_value + + def _get_current_increment(self): + if isinstance(self._increment, int): + return self._increment + + if isinstance(list(self._increment.keys())[0], int): + current_value = self._step if self._step_type == 'step' else self._epoch + return get_closest_key(self._increment, current_value) + + elif isinstance(list(self._increment.keys())[0], str): + step_type = list(self._increment.keys())[0] + if step_type not in ['step', 'epoch']: + error_msg = "Increment dictionary keys must be either 'step' or 'epoch'." + raise ValueError(error_msg) + + current_value = self._step if step_type == 'step' else self._epoch + increment_dict = self._increment[step_type] + return increment_dict.get(get_closest_key(increment_dict, current_value), 0) + else: + error_msg = "Increment dictionary keys must be either int or str." + raise ValueError(error_msg) + + + def step(self, count = 1): + super().step(count) + if self._every_n == 0: + return + + if self._step_type == 'step' and self._step % self._every_n == 0: + self._increment_value += self._get_current_increment() + + + def step_epoch(self, count = 1): + super().step_epoch(count) + if self._every_n == 0: + return + + if self._step_type == 'epoch' and self._epoch % self._every_n == 0: + self._increment_value += self._get_current_increment() + +class Stepped(BaseIncrementingRolloutScheduler): """`Stepped` is a base rollout scheduler that steps the rollout value at the end of each n steps or epochs.""" def __init__( @@ -22,8 +94,9 @@ def __init__( minimum: int, maximum: int, every_n: int, - increment: int | dict[int, int], - step_type: Literal["step", "epoch"] = "epoch", + increment: VALID_INCREMENT_TYPE = 1, + *, + step_type: VALID_STEP_TYPES = "epoch", ): """ `SteppedRollout` is a base rollout scheduler that steps the rollout value at the end of each n steps or epochs. @@ -38,7 +111,7 @@ def __init__( every_n : int Number of steps or epochs to step the rollout value. If `every_n` is 0, the rollout will stay at `minimum`. - increment : int | dict[int, int], optional + increment : int | dict[int, int] | dict[Literal['step', 'epoch'], dict[int, int]], optional Value to increment the rollout by. Can be an int or dictionary, where the keys represent the value of `step_type` and the values represent the increment. @@ -73,21 +146,27 @@ def __init__( # 1 RollSched.rollout_at(epoch = 10) # 2, and then increments of 1 + + RollSched = Stepped(minimum = 1, maximum = 10, every_n = 1, step_type = 'epoch', increment = {'step':{0: 0, 1000: 1}}) + RollSched.rollout_at(epoch = 2) + # 1 + RollSched.rollout_at(epoch = 2, step = 1000) + # 2 + ``` """ - super().__init__() + super().__init__(every_n=every_n, step_type=step_type, increment=increment) if maximum <= -1: maximum = float("inf") self._minimum = minimum self._maximum = maximum - self._every_n = every_n - self._increment = increment - self._step_type = step_type @property def rollout(self) -> int: + return min(self._maximum, self._minimum + self.total_increment) + if self._every_n == 0: return self._minimum @@ -116,7 +195,7 @@ def description(self) -> str: class EpochStepped(Stepped): """`EpochStepped` is a rollout scheduler that steps the rollout value at the end of each n epochs.""" - def __init__(self, minimum: int, maximum: int, every_n_epochs: int = 1, increment: int = 1): + def __init__(self, minimum: int, maximum: int, every_n_epochs: int = 1, increment: VALID_INCREMENT_TYPE = 1): """ `EpochStepped` is a rollout scheduler that steps the rollout value at the end of each n epochs. @@ -128,8 +207,13 @@ def __init__(self, minimum: int, maximum: int, every_n_epochs: int = 1, incremen The maximum value for the scheduler. every_n_epochs : int, optional The number of epochs after which the value is incremented, by default 1. - increment : int, optional - The amount by which the value is incremented, by default 1. + increment : int | dict[int, int] | dict[Literal['step', 'epoch'], dict[int, int]], optional + Value to increment the rollout by. + Can be an int or dictionary, where the keys represent the value of `step_type` + and the values represent the increment. + Will round down to the closest key. + i.e. {0: 1, 10: 2} will increment by 1 until 10, then by 2. + by default 1. """ super().__init__(minimum, maximum, every_n_epochs, increment, step_type="epoch") @@ -137,7 +221,7 @@ def __init__(self, minimum: int, maximum: int, every_n_epochs: int = 1, incremen class StepStepped(Stepped): """`StepStepped` is a rollout scheduler that steps the rollout value at the end of each n steps.""" - def __init__(self, minimum: int, maximum: int, every_n_steps: int = 1000, increment: int = 1): + def __init__(self, minimum: int, maximum: int, every_n_steps: int = 1000, increment: VALID_INCREMENT_TYPE = 1): """ `StepStepped` is a rollout scheduler that steps the rollout value at the end of each n steps. @@ -149,7 +233,12 @@ def __init__(self, minimum: int, maximum: int, every_n_steps: int = 1000, increm The maximum value for the scheduler. every_n_steps : int, optional The number of steps after which the value is incremented, by default 1000. - increment : int, optional - The amount by which the value is incremented, by default 1. + increment : int | dict[int, int] | dict[Literal['step', 'epoch'], dict[int, int]], optional + Value to increment the rollout by. + Can be an int or dictionary, where the keys represent the value of `step_type` + and the values represent the increment. + Will round down to the closest key. + i.e. {0: 1, 10: 2} will increment by 1 until 10, then by 2. + by default 1. """ super().__init__(minimum, maximum, every_n_steps, increment, step_type="step") From a712c4864422830e69fcffc0732d0b0ab57ccbfc Mon Sep 17 00:00:00 2001 From: Harrison Cook Date: Wed, 18 Dec 2024 17:39:20 +0000 Subject: [PATCH 3/6] Improve incrementor - Calculation based not step based --- .../training/config/training/default.yaml | 4 +- .../training/schedulers/rollout/randomise.py | 24 +-- .../training/schedulers/rollout/stepped.py | 154 ++++++++++-------- 3 files changed, 94 insertions(+), 88 deletions(-) diff --git a/src/anemoi/training/config/training/default.yaml b/src/anemoi/training/config/training/default.yaml index 27ff7e22..7eb79077 100644 --- a/src/anemoi/training/config/training/default.yaml +++ b/src/anemoi/training/config/training/default.yaml @@ -91,11 +91,11 @@ rollout: maximum: 12 # increase rollout every n epochs every_n_epochs: 1 - # increment + # Control the incrementing of the rollout window increment: step: 0: 0 - 200000: 1 + 200000: 1 # After 200k steps, increment by 1 every 1 epoch # Set max_epochs or max_steps. Training stops at the first limit reached. max_epochs: null diff --git a/src/anemoi/training/schedulers/rollout/randomise.py b/src/anemoi/training/schedulers/rollout/randomise.py index efa0efaa..901b0696 100644 --- a/src/anemoi/training/schedulers/rollout/randomise.py +++ b/src/anemoi/training/schedulers/rollout/randomise.py @@ -11,14 +11,14 @@ from __future__ import annotations -from typing import Literal - import numpy as np import pytorch_lightning as pl -from anemoi.training.utils.seeding import get_base_seed from anemoi.training.schedulers.rollout import RolloutScheduler -from anemoi.training.schedulers.rollout.stepped import BaseIncrementingRolloutScheduler, VALID_INCREMENT_TYPE, VALID_STEP_TYPES +from anemoi.training.schedulers.rollout.stepped import VALID_INCREMENT_TYPE +from anemoi.training.schedulers.rollout.stepped import VALID_STEP_TYPES +from anemoi.training.schedulers.rollout.stepped import IncrementMixin +from anemoi.training.utils.seeding import get_base_seed class BaseRandom(RolloutScheduler): @@ -150,7 +150,7 @@ def description(self) -> str: ) -class IncreasingRandom(BaseIncrementingRolloutScheduler, BaseRandom): +class IncreasingRandom(IncrementMixin, BaseRandom): """IncreasingRandom is a rollout scheduler that randomly selects a rollout from an increasing range of values.""" def __init__( @@ -196,7 +196,7 @@ def __init__( # any value between 1 and 2 ``` """ - super().__init__(every_n = every_n, increment = increment, step_type = step_type) + super().__init__(every_n=every_n, increment=increment, step_type=step_type) if maximum <= -1: maximum = float("inf") @@ -210,17 +210,7 @@ def rollout(self) -> int: if self._every_n == 0: return self._minimum - # count_of_n = self.count(self._every_n, self._step_type) - - # if isinstance(self._increment, int): - # maximum_value = self._minimum + self._increment * count_of_n - # else: - # sum_of_increments = [ - # self._increment.get(get_closest_key(self._increment, i + 1)) for i in range(count_of_n) - # ] - # maximum_value = self._minimum + sum(sum_of_increments) - - rollouts = range(self._minimum, self._minimum + self.total_increment, self._range_step) + rollouts = range(self._minimum, self._minimum + self.increment(self._step, self._epoch), self._range_step) return self._randomly_pick(rollouts) diff --git a/src/anemoi/training/schedulers/rollout/stepped.py b/src/anemoi/training/schedulers/rollout/stepped.py index 3a417395..6192fa30 100644 --- a/src/anemoi/training/schedulers/rollout/stepped.py +++ b/src/anemoi/training/schedulers/rollout/stepped.py @@ -13,15 +13,14 @@ from anemoi.training.schedulers.rollout import RolloutScheduler from anemoi.training.schedulers.rollout.indexed import get_closest_key - VALID_STEP_TYPE = ["step", "epoch"] VALID_STEP_TYPES = Literal["step", "epoch"] VALID_INCREMENT_TYPE = int | dict[int, int] | dict[VALID_STEP_TYPES, dict[int, int]] -class BaseIncrementingRolloutScheduler(RolloutScheduler): - """Base class for schedulers that have an incrementing value.""" - _increment_value = 0 + +class IncrementMixin: + """Mixin class for schedulers that have an incrementing value based on the steps and epochs.""" def __init__(self, every_n: int, step_type: VALID_STEP_TYPES, increment: VALID_INCREMENT_TYPE = 1): super().__init__() @@ -30,63 +29,91 @@ def __init__(self, every_n: int, step_type: VALID_STEP_TYPES, increment: VALID_I error_msg = "Step type must be either 'step' or 'epoch'." raise ValueError(error_msg) - if isinstance(increment, dict): - if not len(increment) == 1: - error_msg = ( - "Increment dictionary cannot be empty, nor can it contain more then one entry." - "\nIt should either be a dictionary of ints or contain a single key of 'step' or 'epoch'." - ) - raise ValueError(error_msg) + if isinstance(increment, dict) and len(increment) == 0: + error_msg = ( + "Increment dictionary cannot be empty." + "\nIt should either be a dictionary of ints or contain a single key of 'step' or 'epoch'." + ) + raise ValueError(error_msg) self._every_n = every_n self._step_type = step_type self._increment = increment - - @property - def total_increment(self) -> int: - return self._increment_value + def increment(self, step: int, epoch: int) -> int: + """ + Get the increment value for a particular step or epoch. + + Relies on the number of steps per epochs to calculate the increment + when the step_type of the increment is different from the stepper step_type. + - def _get_current_increment(self): + Parameters + ---------- + step : int + Step number. + epoch : int + Epoch number. + + Returns + ------- + int + Increment value. + + Raises + ------ + ValueError + If cannot parse the `increment` value given at init. + """ if isinstance(self._increment, int): return self._increment - if isinstance(list(self._increment.keys())[0], int): - current_value = self._step if self._step_type == 'step' else self._epoch - return get_closest_key(self._increment, current_value) - - elif isinstance(list(self._increment.keys())[0], str): - step_type = list(self._increment.keys())[0] - if step_type not in ['step', 'epoch']: + count = (step // self._every_n if self._step_type == "step" else epoch // self._every_n) + 1 + + if isinstance(next(iter(self._increment.keys())), int): + return sum( + (self._increment.get(get_closest_key(self._increment, i * self._every_n), 0) for i in range(count)), + ) + + if isinstance(next(iter(self._increment.keys())), str): + increment_step_type = next(iter(self._increment.keys())) + if increment_step_type not in ["step", "epoch"]: error_msg = "Increment dictionary keys must be either 'step' or 'epoch'." raise ValueError(error_msg) - - current_value = self._step if step_type == 'step' else self._epoch - increment_dict = self._increment[step_type] - return increment_dict.get(get_closest_key(increment_dict, current_value), 0) - else: - error_msg = "Increment dictionary keys must be either int or str." - raise ValueError(error_msg) - - - def step(self, count = 1): - super().step(count) - if self._every_n == 0: - return - - if self._step_type == 'step' and self._step % self._every_n == 0: - self._increment_value += self._get_current_increment() - - - def step_epoch(self, count = 1): - super().step_epoch(count) - if self._every_n == 0: - return - - if self._step_type == 'epoch' and self._epoch % self._every_n == 0: - self._increment_value += self._get_current_increment() - -class Stepped(BaseIncrementingRolloutScheduler): + + increment_dict = self._increment[increment_step_type] + + if increment_step_type == self._step_type: + return sum( + (increment_dict.get(get_closest_key(increment_dict, i * self._every_n), 0) for i in range(count)), + ) + + if epoch == 0 or step == 0: + return 0 + + num_steps_per_epoch = step / epoch + if increment_step_type == "step" and self._step_type == "epoch": + return sum( + increment_dict.get( + get_closest_key(increment_dict, (i * self._every_n) * num_steps_per_epoch), + 0, + ) + for i in range(count) + ) + if increment_step_type == "epoch" and self._step_type == "step": + return sum( + increment_dict.get( + get_closest_key(increment_dict, (i * self._every_n) // num_steps_per_epoch), + 0, + ) + for i in range(count) + ) + + error_msg = "Increment dictionary keys must be either int or a single str." + raise TypeError(error_msg) + + +class Stepped(RolloutScheduler, IncrementMixin): """`Stepped` is a base rollout scheduler that steps the rollout value at the end of each n steps or epochs.""" def __init__( @@ -147,8 +174,11 @@ def __init__( RollSched.rollout_at(epoch = 10) # 2, and then increments of 1 - RollSched = Stepped(minimum = 1, maximum = 10, every_n = 1, step_type = 'epoch', increment = {'step':{0: 0, 1000: 1}}) - RollSched.rollout_at(epoch = 2) + RollSched = Stepped( + minimum = 1, maximum = 10, every_n = 1, + step_type = 'epoch', increment = {'step':{0: 0, 1000: 1}} + ) + RollSched.rollout_at(epoch = 1, step = 500 ) # 1 RollSched.rollout_at(epoch = 2, step = 1000) # 2 @@ -165,21 +195,7 @@ def __init__( @property def rollout(self) -> int: - return min(self._maximum, self._minimum + self.total_increment) - - if self._every_n == 0: - return self._minimum - - count_of_n = self.count(self._every_n, self._step_type) - - if isinstance(self._increment, int): - return min(self._maximum, self._minimum + self._increment * count_of_n) - - sum_of_increments = [ - self._increment.get(get_closest_key(self._increment, i + 1 if self._step_type == "epoch" else i)) - for i in range(count_of_n) - ] - return min(self._maximum, self._minimum + sum(sum_of_increments)) + return min(self._maximum, self._minimum + self.increment(self._step, self._epoch)) @property def maximum_rollout(self) -> int: @@ -187,8 +203,8 @@ def maximum_rollout(self) -> int: def description(self) -> str: return ( - "Stepped rollout scheduler stepping between" - f"{self._minimum} and {self._maximum} by {self._increment} for {self._every_n} {self._step_type}s." + "Stepped rollout scheduler stepping between " + f"{self._minimum} and {self._maximum} by {self._increment} for every {self._every_n} {self._step_type}/s." ) From 72e0bf9e1c32349bd64d95fdcbbeb40f8b621fb9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 18 Dec 2024 17:41:36 +0000 Subject: [PATCH 4/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/anemoi/training/schedulers/rollout/__init__.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/anemoi/training/schedulers/rollout/__init__.py b/src/anemoi/training/schedulers/rollout/__init__.py index e4e9ec06..20bce2e3 100644 --- a/src/anemoi/training/schedulers/rollout/__init__.py +++ b/src/anemoi/training/schedulers/rollout/__init__.py @@ -115,7 +115,7 @@ def count(self, n_epochs: int | None = None, n_steps: int | None = None) -> int: ValueError If both `n_epochs` and `n_steps` are given, or if neither are given. """ - if n_epochs is not None and n_steps is not None or n_epochs is None and n_steps is None: + if (n_epochs is not None and n_steps is not None) or (n_epochs is None and n_steps is None): error_msg = "Only one of `n_epochs` or `n_steps` can be given." raise ValueError(error_msg) @@ -124,7 +124,6 @@ def count(self, n_epochs: int | None = None, n_steps: int | None = None) -> int: if n_steps is not None: return self._step // n_steps - @abstractmethod def description(self) -> str: """Description of the rollout scheduler.""" From c199c0e1306c3eb31690c7194a88016019adc36a Mon Sep 17 00:00:00 2001 From: Harrison Cook Date: Wed, 18 Dec 2024 17:44:13 +0000 Subject: [PATCH 5/6] Precommit fixes --- src/anemoi/training/schedulers/rollout/__init__.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/anemoi/training/schedulers/rollout/__init__.py b/src/anemoi/training/schedulers/rollout/__init__.py index 20bce2e3..464ed71f 100644 --- a/src/anemoi/training/schedulers/rollout/__init__.py +++ b/src/anemoi/training/schedulers/rollout/__init__.py @@ -11,7 +11,6 @@ from abc import ABC from abc import abstractmethod -from typing import Literal class RolloutScheduler(ABC): @@ -121,8 +120,7 @@ def count(self, n_epochs: int | None = None, n_steps: int | None = None) -> int: if n_epochs is not None: return self._epoch // n_epochs - if n_steps is not None: - return self._step // n_steps + return self._step // n_steps @abstractmethod def description(self) -> str: From 69a5d9a03f8c7560986b9e269f2bfd21a9d51c75 Mon Sep 17 00:00:00 2001 From: Harrison Cook Date: Wed, 18 Dec 2024 17:45:44 +0000 Subject: [PATCH 6/6] Add changelog entry --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 286fd915..e35e40b3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -29,6 +29,8 @@ Keep it human-readable, your future self will thank you! - Add supporting arrrays (numpy) to checkpoint - Support for masking out unconnected nodes in LAM [#171](https://github.com/ecmwf/anemoi-training/pull/171) - Improved validation metrics, allow 'all' to be scaled [#202](https://github.com/ecmwf/anemoi-training/pull/202) +- Rollout Schedulers [#206](https://github.com/ecmwf/anemoi-training/pull/206) + ### Changed