Skip to content
This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

feat(rollout)!: Rollout Schedulers #206

Draft
wants to merge 6 commits into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ Keep it human-readable, your future self will thank you!
- Add supporting arrays (numpy) to checkpoint
- Support for masking out unconnected nodes in LAM [#171](https://github.com/ecmwf/anemoi-training/pull/171)
- Improved validation metrics, allow 'all' to be scaled [#202](https://github.com/ecmwf/anemoi-training/pull/202)
- Rollout Schedulers [#206](https://github.com/ecmwf/anemoi-training/pull/206)


### Changed

Expand Down
13 changes: 9 additions & 4 deletions src/anemoi/training/config/training/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,16 @@ scale_validation_metrics:

# length of the "rollout" window (see Keisler's paper)
rollout:
start: 1
_target_: anemoi.training.schedulers.stepped.EpochStepped
minimum: 1
maximum: 12
# increase rollout every n epochs
epoch_increment: 0
# maximum rollout to use
max: 1
every_n_epochs: 1
# Control the incrementing of the rollout window
increment:
step:
0: 0
200000: 1 # After 200k steps, increment by 1 every 1 epoch

# Set max_epochs or max_steps. Training stops at the first limit reached.
max_epochs: null
Expand Down
6 changes: 3 additions & 3 deletions src/anemoi/training/diagnostics/mlflow/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,12 +196,12 @@ def _log_collector(self) -> None:
log_capture_time_counter = 0

def _store_buffered_logs(self) -> None:
_buffer_size = self._io_buffer.tell()
if not _buffer_size:
buffer_size = self._io_buffer.tell()
if not buffer_size:
return
self._io_buffer.seek(0)
# read and reset the buffer
data = self._io_buffer.read(_buffer_size)
data = self._io_buffer.read(buffer_size)
self._io_buffer.seek(0)
# handle the buffered data and store
# split lines and keep \n at the end of each line
Expand Down
166 changes: 166 additions & 0 deletions src/anemoi/training/schedulers/rollout/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
# (C) Copyright 2024 Anemoi contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from __future__ import annotations

from abc import ABC
from abc import abstractmethod


class RolloutScheduler(ABC):
    """
    `RolloutScheduler` is an abstract base class for rollout schedulers.

    A rollout scheduler keeps a step counter and an epoch counter and exposes,
    through the `rollout` property, the rollout value to use at that point.
    Subclasses must implement `rollout`, `maximum_rollout` and `description`;
    the base class itself cannot be instantiated.

    ```python
    RollSched = Static(rollout_value=1)  # any concrete subclass

    for epoch in range(20):
        for step in range(100):
            y = model(x, rollout=RollSched.rollout)

            RollSched.step()
        RollSched.step_epoch()
    ```
    """

    # Class-level defaults. `step` / `step_epoch` / `rollout_at` rebind these
    # on the instance, so counters are never shared between schedulers.
    _epoch: int = 0
    _step: int = 0

    @property
    @abstractmethod
    def rollout(self) -> int:
        """Get the current rollout value."""
        error_msg = "`rollout` property not implemented by parent class."
        raise NotImplementedError(error_msg)

    @property
    @abstractmethod
    def maximum_rollout(self) -> int:
        """Get maximum rollout possible."""
        error_msg = "`maximum_rollout` property not implemented by parent class."
        raise NotImplementedError(error_msg)

    @property
    def current_maximum(self) -> int:
        """Get the current maximum rollout value (defaults to the current rollout)."""
        return self.rollout

    def __int__(self) -> int:
        """Allow `int(scheduler)` as a shorthand for the current rollout."""
        return int(self.rollout)

    def rollout_at(self, step: int | None = None, epoch: int | None = None) -> int:
        """
        Get the rollout at a specific step and epoch.

        The internal counters are temporarily overridden while `rollout` is
        evaluated and are always restored afterwards, even if `rollout` raises.

        Parameters
        ----------
        step : int, optional
            Step value to override with, by default None
        epoch : int, optional
            Epoch value to override with, by default None

        Returns
        -------
        int
            Rollout value at the specified step and epoch.
        """
        saved_step, saved_epoch = self._step, self._epoch

        if step is not None:
            self._step = step
        if epoch is not None:
            self._epoch = epoch

        try:
            return self.rollout
        finally:
            # Restore unconditionally so a failing `rollout` cannot leave the
            # scheduler stuck on the probe values.
            self._step, self._epoch = saved_step, saved_epoch

    def step(self, count: int = 1, /) -> None:
        """Step the scheduler by a count."""
        self._step += count

    def step_epoch(self, count: int = 1, /) -> None:
        """Step the scheduler by a count of epochs."""
        self._epoch += count

    def count(self, n_epochs: int | None = None, n_steps: int | None = None) -> int:
        """
        Get the number of completed `n_epochs`-epoch or `n_steps`-step periods.

        Parameters
        ----------
        n_epochs : int | None, optional
            Length of a period in epochs, by default None
        n_steps : int | None, optional
            Length of a period in steps, by default None

        Returns
        -------
        int
            Current epoch floor-divided by `n_epochs`, or current step
            floor-divided by `n_steps`.

        Raises
        ------
        ValueError
            If both `n_epochs` and `n_steps` are given, or if neither are given.
        """
        # Exactly one of the two arguments must be supplied.
        if (n_epochs is None) == (n_steps is None):
            error_msg = "Only one of `n_epochs` or `n_steps` can be given."
            raise ValueError(error_msg)

        if n_epochs is not None:
            return self._epoch // n_epochs
        return self._step // n_steps

    @abstractmethod
    def description(self) -> str:
        """Description of the rollout scheduler."""
        error_msg = "`description` method not implemented by parent class."
        raise NotImplementedError(error_msg)


class Static(RolloutScheduler):
    """Rollout scheduler whose rollout never changes."""

    def __init__(self, rollout_value: int):
        """
        Create a scheduler that reports the same rollout at every step and epoch.

        Parameters
        ----------
        rollout_value : int
            The constant rollout value to report.

        Example
        -------
        ```python
        from anemoi.training.schedulers.rollout import Static
        RollSched = Static(rollout_value = 5)
        RollSched.rollout_at(epoch = 1)
        # 5
        RollSched.rollout_at(epoch = 5)
        # 5
        ```
        """
        self._rollout_value = rollout_value

    @property
    def rollout(self) -> int:
        """Current rollout: the constant given at construction."""
        return self._rollout_value

    @property
    def maximum_rollout(self) -> int:
        """Maximum rollout: the same constant, since the value never changes."""
        return self._rollout_value

    def description(self) -> str:
        """Human-readable summary of this scheduler."""
        value = self._rollout_value
        return f"Static rollout value of {value}."
178 changes: 178 additions & 0 deletions src/anemoi/training/schedulers/rollout/indexed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
# (C) Copyright 2024 Anemoi contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from typing import Any
from typing import Literal

from anemoi.training.schedulers.rollout import RolloutScheduler


def get_closest_key(dictionary: dict[int, Any], key: int) -> int:
    """
    Find the dictionary key nearest to `key` without exceeding it.

    Among the keys that are less than or equal to `key`, the one with the
    smallest distance to `key` (i.e. the largest such key) is returned.

    Note that if no key is less than or equal to `key`, every candidate ranks
    equally and the first key in the dictionary's iteration order is returned.
    An empty dictionary makes the underlying `min` call raise `ValueError`.

    Parameters
    ----------
    dictionary : dict[int, Any]
        Dictionary whose keys are searched.
    key : int
        Target value to look up.

    Returns
    -------
    int
        The selected key of `dictionary`.
    """

    def _distance(candidate: int) -> float:
        # Keys above the target are pushed to the back of the ranking.
        if candidate <= key:
            return key - candidate
        return float("inf")

    return min(dictionary, key=_distance)


class PositionalIndexed(RolloutScheduler):
    """
    `PositionalIndexed` retrieves the rollout value from a list of rollouts based on the current epoch or step.

    Once the list is exhausted, the rollout will remain at the last value.
    """

    def __init__(
        self,
        rollouts: list[int],
        num_times_per_element: int = 1,
        step_type: Literal["step", "epoch"] = "epoch",
    ):
        """
        `PositionalIndexed` retrieves the rollout value from a list of rollouts based on the current epoch or step.

        Once the list is exhausted, the rollout will remain at the last value.

        Parameters
        ----------
        rollouts : list[int]
            List of rollout values. Must not be empty.
        num_times_per_element: int, optional
            Number of epochs / steps to remain at an element, by default 1.
            Must be at least 1.
        step_type : Literal['step', 'epoch'], optional
            Type of step, either 'epoch' or 'step'.
            by default 'epoch'.

        Raises
        ------
        ValueError
            If `rollouts` is empty or `num_times_per_element` is less than 1.

        Example
        -------
        ```python
        from anemoi.training.schedulers.rollout.indexed import PositionalIndexed

        RollSched = PositionalIndexed(rollouts = [1, 2, 3, 4], num_times_per_element = 2, step_type = 'epoch')
        RollSched.rollout_at(epoch = 0)
        # 1
        RollSched.rollout_at(epoch = 1)
        # 1
        RollSched.rollout_at(epoch = 2)
        # 2
        RollSched.rollout_at(epoch = 100)
        # 4
        ```
        """
        super().__init__()

        # Fail at construction rather than with an opaque IndexError /
        # ZeroDivisionError the first time `rollout` is read.
        if len(rollouts) == 0:
            error_msg = "`rollouts` must contain at least one value."
            raise ValueError(error_msg)
        if num_times_per_element < 1:
            error_msg = "`num_times_per_element` must be at least 1."
            raise ValueError(error_msg)

        self._rollouts = rollouts
        self._num_times_per_element = num_times_per_element
        self._step_type = step_type

    @property
    def rollout(self) -> int:
        """Rollout at the current position, clamped to the last list element."""
        if self._step_type == "epoch":
            count = self.count(n_epochs=self._num_times_per_element)
        elif self._step_type == "step":
            count = self.count(n_steps=self._num_times_per_element)
        else:
            error_msg = "Invalid step_type. Must be 'epoch' or 'step'."
            raise ValueError(error_msg)

        # Clamp to the last valid index so the final value is held once the
        # list has been exhausted.
        last_index = len(self._rollouts) - 1
        return self._rollouts[min(last_index, count)]

    @property
    def maximum_rollout(self) -> int:
        """Largest rollout value in the list."""
        return max(self._rollouts)

    def description(self) -> str:
        """Human-readable summary of this scheduler."""
        return (
            f"PositionalIndexed with rollouts {list(self._rollouts)}, "
            f"advancing every {self._num_times_per_element} {self._step_type}(s)."
        )


class EpochPositionalIndexed(PositionalIndexed):
    """Epoch based PositionalIndexed."""

    def __init__(self, rollouts: list[int], num_times_per_element: int = 1):
        """
        Create a `PositionalIndexed` scheduler that advances with the epoch counter.

        Parameters
        ----------
        rollouts : list[int]
            List of rollout values.
        num_times_per_element : int, optional
            Number of epochs to remain at an element, by default 1
        """
        super().__init__(rollouts, num_times_per_element=num_times_per_element, step_type="epoch")


class StepPositionalIndexed(PositionalIndexed):
    """Step based PositionalIndexed."""

    def __init__(self, rollouts: list[int], num_times_per_element: int = 1):
        """
        Create a `PositionalIndexed` scheduler that advances with the step counter.

        Parameters
        ----------
        rollouts : list[int]
            List of rollout values.
        num_times_per_element : int, optional
            Number of steps to remain at an element, by default 1
        """
        super().__init__(rollouts, num_times_per_element=num_times_per_element, step_type="step")


class Lookup(RolloutScheduler):
    """
    `Lookup` retrieves the rollout value from a dictionary of rollouts based on the current epoch or step.

    It will return the closest key that is less than or equal to the current epoch or step.
    If no such key exists, the rollout is 1.
    """

    def __init__(self, rollouts: dict[int, int], step_type: Literal["step", "epoch"] = "epoch"):
        """
        `Lookup` retrieves the rollout value from a dictionary of rollouts based on the current epoch or step.

        It will return the closest key that is less than or equal to the current epoch or step.
        If no such key exists, the rollout is 1.

        Parameters
        ----------
        rollouts : dict[int, int]
            Dictionary mapping an epoch / step to the rollout that applies from there on.
        step_type : Literal['step', 'epoch'], optional
            Type of step, either 'epoch' or 'step'.
            by default 'epoch'

        Example
        -------
        ```python
        from anemoi.training.schedulers.rollout.indexed import Lookup

        RollSched = Lookup(rollouts = {0: 1, 5: 2, 10: 3}, step_type = 'epoch')
        RollSched.rollout_at(epoch = 1)
        # 1
        RollSched.rollout_at(epoch = 5)
        # 2
        ```
        """
        super().__init__()
        self._rollouts = rollouts
        self._step_type = step_type

    @property
    def rollout(self) -> int:
        """Rollout of the closest key at or below the current position, else 1."""
        if self._step_type == "epoch":
            position = self._epoch
        elif self._step_type == "step":
            position = self._step
        else:
            error_msg = "Invalid step_type. Must be 'epoch' or 'step'."
            raise ValueError(error_msg)

        # `get_closest_key` always hands back an existing key, even when every
        # key lies beyond `position`; handle that case explicitly so the
        # scheduler falls back to a rollout of 1 before the first entry.
        if not any(key <= position for key in self._rollouts):
            return 1
        return self._rollouts[get_closest_key(self._rollouts, position)]

    @property
    def maximum_rollout(self) -> int:
        """Largest rollout value in the table (1 if the table is empty)."""
        return max(self._rollouts.values(), default=1)

    def description(self) -> str:
        """Human-readable summary of this scheduler."""
        return f"Lookup with rollouts {dict(self._rollouts)}, indexed by {self._step_type}."


class EpochLookup(Lookup):
    """`Lookup` scheduler keyed on the epoch counter."""

    def __init__(self, rollouts: dict[int, int]):
        """
        Build a `Lookup` whose keys are interpreted as epochs.

        Parameters
        ----------
        rollouts : dict[int, int]
            Dictionary of rollouts.
        """
        super().__init__(rollouts=rollouts, step_type="epoch")


class StepLookup(Lookup):
    """`Lookup` scheduler keyed on the step counter."""

    def __init__(self, rollouts: dict[int, int]):
        """
        Build a `Lookup` whose keys are interpreted as steps.

        Parameters
        ----------
        rollouts : dict[int, int]
            Dictionary of rollouts.
        """
        super().__init__(rollouts=rollouts, step_type="step")
Loading
Loading