This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

Update sanity checks for training data consistency #120

Merged 15 commits on Dec 3, 2024

39 changes: 12 additions & 27 deletions CHANGELOG.md
@@ -9,78 +9,60 @@ Please add your functional changes to the appropriate section in the PR.
Keep it human-readable, your future self will thank you!

## [Unreleased](https://github.com/ecmwf/anemoi-training/compare/0.3.1...HEAD)
### Fixed

### Added
- Added a check for the variable sorting on pre-trained/finetuned models [#120](https://github.com/ecmwf/anemoi-training/pull/120)

### Changed

### Removed
- Removed the resolution config entry [#120](https://github.com/ecmwf/anemoi-training/pull/120)

## [0.3.1 - AIFS v0.3 Compatibility](https://github.com/ecmwf/anemoi-training/compare/0.3.0...0.3.1) - 2024-11-28

### Changed
- Perform full shuffle of training dataset [#153](https://github.com/ecmwf/anemoi-training/pull/153)

### Fixed

- Update `n_pixel` used by datashader to better adapt across resolutions #152

- Fixed bug in power spectra plotting for the n320 resolution.

- Allow histogram and spectrum plot for one variable [#165](https://github.com/ecmwf/anemoi-training/pull/165)


### Added

- Introduce variable to configure (Cosine Annealing) optimizer warm up [#155](https://github.com/ecmwf/anemoi-training/pull/155)
- Add reader groups to reduce CPU memory usage and increase dataloader throughput [#76](https://github.com/ecmwf/anemoi-training/pull/76)
- Bump `anemoi-graphs` version to 0.4.1 [#159](https://github.com/ecmwf/anemoi-training/pull/159)

### Changed

## [0.3.0 - Loss & Callback Refactors](https://github.com/ecmwf/anemoi-training/compare/0.2.2...0.3.0) - 2024-11-14

### Changed

- Increase the default MlFlow HTTP max retries [#111](https://github.com/ecmwf/anemoi-training/pull/111)

### Fixed

- Rename loss_scaling to variable_loss_scaling [#138](https://github.com/ecmwf/anemoi-training/pull/138)

- Refactored callbacks. [#60](https://github.com/ecmwf/anemoi-training/pulls/60)

- Updated docs [#115](https://github.com/ecmwf/anemoi-training/pull/115)
- Fix enabling LearningRateMonitor [#119](https://github.com/ecmwf/anemoi-training/pull/119)

- Refactored rollout [#87](https://github.com/ecmwf/anemoi-training/pulls/87)

- Enable longer validation rollout than training

- Expand iterables in logging [#91](https://github.com/ecmwf/anemoi-training/pull/91)

- Save entire config in mlflow


### Added

- Included more loss functions and allowed configuration [#70](https://github.com/ecmwf/anemoi-training/pull/70)

- Include option to use datashader and optimised asynchronous callbacks [#102](https://github.com/ecmwf/anemoi-training/pull/102)

- Fix that applies the metric_ranges in the post-processed variable space [#116](https://github.com/ecmwf/anemoi-training/pull/116)

- Allow updates to scalars [#137](https://github.com/ecmwf/anemoi-training/pulls/137)

- Add without subsetting in ScaleTensor

- Sub-hour datasets [#63](https://github.com/ecmwf/anemoi-training/pull/63)

- Add synchronisation workflow [#92](https://github.com/ecmwf/anemoi-training/pull/92)

- Feat: Anemoi Profiler compatible with mlflow and using Pytorch (Kineto) Profiler for memory report [#38](https://github.com/ecmwf/anemoi-training/pull/38/)

- Feat: Save a gif for longer rollouts in validation [#65](https://github.com/ecmwf/anemoi-training/pull/65)

- New limited area config file added, limited_area.yaml. [#134](https://github.com/ecmwf/anemoi-training/pull/134/)

- New stretched grid config added, stretched_grid.yaml [#133](https://github.com/ecmwf/anemoi-training/pull/133)
- Functionality to change the weight attribute of nodes in the graph at the start of training without re-generating the graph. [#136](https://github.com/ecmwf/anemoi-training/pull/136)

- Custom System monitor for Nvidia and AMD GPUs [#147](https://github.com/ecmwf/anemoi-training/pull/147)


@@ -89,6 +71,9 @@ Keep it human-readable, your future self will thank you!
- Renamed frequency keys in callbacks configuration. [#118](https://github.com/ecmwf/anemoi-training/pull/118)
- Modified training configuration to support max_steps and tied lr iterations to max_steps by default [#67](https://github.com/ecmwf/anemoi-training/pull/67)
- Merged node & edge trainable feature callbacks into one. [#135](https://github.com/ecmwf/anemoi-training/pull/135)
- Increase the default MlFlow HTTP max retries [#111](https://github.com/ecmwf/anemoi-training/pull/111)

### Removed

## [0.2.2 - Maintenance: pin python <3.13](https://github.com/ecmwf/anemoi-training/compare/0.2.1...0.2.2) - 2024-10-28

3 changes: 1 addition & 2 deletions src/anemoi/training/config/data/zarr.yaml
@@ -1,5 +1,4 @@
 format: zarr
-resolution: o96
 # Time frequency requested from dataset
 frequency: 6h
 # Time step of model (must be multiple of frequency)
@@ -82,5 +81,5 @@ processors:
 # _convert_: all
 # config: ${data.remapper}

-# Values set in the code
+# Values set in the code
 num_features: null # number of features in the forecast state
9 changes: 1 addition & 8 deletions src/anemoi/training/data/datamodule.py
@@ -60,11 +60,6 @@ def __init__(self, config: DictConfig) -> None:
         if not self.config.dataloader.get("pin_memory", True):
             LOGGER.info("Data loader memory pinning disabled.")

-    def _check_resolution(self, resolution: str) -> None:
-        assert (
-            self.config.data.resolution.lower() == resolution.lower()
-        ), f"Network resolution {self.config.data.resolution=} does not match dataset resolution {resolution=}"
-
     @cached_property
     def statistics(self) -> dict:
         return self.ds_train.statistics
@@ -153,16 +148,14 @@ def _get_dataset(
         label: str = "generic",
     ) -> NativeGridDataset:
         r = max(rollout, self.rollout)
-        data = NativeGridDataset(
+        return NativeGridDataset(
             data_reader=data_reader,
             rollout=r,
             multistep=self.config.training.multistep_input,
             timeincrement=self.timeincrement,
             shuffle=shuffle,
             label=label,
         )
-        self._check_resolution(data.resolution)
-        return data

     def _get_dataloader(self, ds: NativeGridDataset, stage: str) -> DataLoader:
         assert stage in {"training", "validation", "test"}
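
For reference, the deleted `_check_resolution` enforced a case-insensitive match between the configured resolution and the resolution reported by the dataset. A standalone sketch of that behaviour follows; `check_resolution` and its argument names are hypothetical and are not code from the repository.

```python
def check_resolution(configured: str, dataset: str) -> None:
    """Raise AssertionError unless both resolutions match, ignoring case.

    Hypothetical free-function version of the removed method.
    """
    assert configured.lower() == dataset.lower(), (
        f"Network resolution {configured=} does not match dataset resolution {dataset=}"
    )


# Passes: the comparison ignores case.
check_resolution("o96", "O96")

# Fails: the two resolutions differ.
try:
    check_resolution("o96", "n320")
except AssertionError as err:
    print(err)
```

With the `resolution` entry gone from `zarr.yaml`, this PR deletes the check outright, so `_get_dataset` can return the dataset directly.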
9 changes: 8 additions & 1 deletion src/anemoi/training/diagnostics/callbacks/__init__.py
@@ -23,6 +23,7 @@
 from anemoi.training.diagnostics.callbacks.optimiser import LearningRateMonitor
 from anemoi.training.diagnostics.callbacks.optimiser import StochasticWeightAveraging
 from anemoi.training.diagnostics.callbacks.provenance import ParentUUIDCallback
+from anemoi.training.diagnostics.callbacks.sanity import CheckVariableOrder

 if TYPE_CHECKING:
     from pytorch_lightning.callbacks import Callback
@@ -196,7 +197,13 @@ def get_callbacks(config: DictConfig) -> list[Callback]:
     trainer_callbacks.extend(_get_config_enabled_callbacks(config))

     # Parent UUID callback
-    trainer_callbacks.append(ParentUUIDCallback(config))
+    # Check variable order callback
+    trainer_callbacks.extend(
+        (
+            ParentUUIDCallback(config),
+            CheckVariableOrder(),
+        ),
+    )

     return trainer_callbacks
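
`CheckVariableOrder` is appended unconditionally, so its hooks fire on every run regardless of the callback config. The sketch below imitates the hook sequence without PyTorch Lightning to show which mapping becomes the reference: the checkpoint's mapping when a checkpoint is loaded, otherwise the datamodule's. `OrderCheckSketch`, its `matches` helper, and the plain-dict arguments are hypothetical stand-ins, not the Lightning API. The real callback also compares against the datamodule right after loading a checkpoint, which is left out here.

```python
from typing import Optional


class OrderCheckSketch:
    """Hypothetical stand-in for how CheckVariableOrder caches its reference mapping."""

    def __init__(self) -> None:
        self.reference: Optional[dict[str, int]] = None

    def on_load_checkpoint(self, checkpoint_mapping: dict[str, int]) -> None:
        # Fine-tuning or resuming: the checkpoint defines the expected order.
        self.reference = dict(checkpoint_mapping)

    def on_sanity_check_start(self, datamodule_mapping: dict[str, int]) -> None:
        # Fresh run: fall back to the datamodule, never overwriting a checkpoint mapping.
        if self.reference is None:
            self.reference = dict(datamodule_mapping)

    def matches(self, dataset_mapping: dict[str, int]) -> bool:
        # Stands in for the per-epoch comparison against the train/valid/test datasets.
        return self.reference == dataset_mapping


# Variable names are illustrative only.
fresh = OrderCheckSketch()
fresh.on_sanity_check_start({"2t": 0, "msl": 1})

resumed = OrderCheckSketch()
resumed.on_load_checkpoint({"msl": 0, "2t": 1})
resumed.on_sanity_check_start({"2t": 0, "msl": 1})  # no effect: reference already set

print(fresh.matches({"2t": 0, "msl": 1}))
print(resumed.matches({"2t": 0, "msl": 1}))
```

A run started from scratch therefore compares the datasets against the datamodule's own indices, while a run started from a checkpoint compares them against the order the model was trained with.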

169 changes: 169 additions & 0 deletions src/anemoi/training/diagnostics/callbacks/sanity.py
@@ -0,0 +1,169 @@
# (C) Copyright 2024 Anemoi contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import logging

import pytorch_lightning as pl

LOGGER = logging.getLogger(__name__)


class CheckVariableOrder(pl.callbacks.Callback):
"""Check the order of the variables in a pre-trained / fine-tuning model."""

def __init__(self) -> None:
super().__init__()
self._model_name_to_index = None

def on_load_checkpoint(self, trainer: pl.Trainer, _: pl.LightningModule, checkpoint: dict) -> None:
"""Cache the model mapping from the checkpoint.

Parameters
----------
trainer : pl.Trainer
Pytorch Lightning trainer
_ : pl.LightningModule
Not used
checkpoint : dict
Pytorch Lightning checkpoint
"""
self._model_name_to_index = checkpoint["hyper_parameters"]["data_indices"].name_to_index
data_name_to_index = trainer.datamodule.data_indices.name_to_index

self._compare_variables(data_name_to_index)

def on_sanity_check_start(self, trainer: pl.Trainer, _: pl.LightningModule) -> None:
"""Cache the model mapping from the datamodule if not loaded from checkpoint.

Parameters
----------
trainer : pl.Trainer
Pytorch Lightning trainer
_ : pl.LightningModule
Not used
"""
if self._model_name_to_index is None:
self._model_name_to_index = trainer.datamodule.data_indices.name_to_index

def on_train_epoch_start(self, trainer: pl.Trainer, _: pl.LightningModule) -> None:
"""Check the order of the variables in the model from checkpoint and the training data.

Parameters
----------
trainer : pl.Trainer
Pytorch Lightning trainer
_ : pl.LightningModule
Not used
"""
data_name_to_index = trainer.datamodule.ds_train.name_to_index

self._compare_variables(data_name_to_index)

def on_validation_epoch_start(self, trainer: pl.Trainer, _: pl.LightningModule) -> None:
"""Check the order of the variables in the model from checkpoint and the validation data.

Parameters
----------
trainer : pl.Trainer
Pytorch Lightning trainer
_ : pl.LightningModule
Not used
"""
data_name_to_index = trainer.datamodule.ds_valid.name_to_index

self._compare_variables(data_name_to_index)

def on_test_epoch_start(self, trainer: pl.Trainer, _: pl.LightningModule) -> None:
"""Check the order of the variables in the model from checkpoint and the test data.

Parameters
----------
trainer : pl.Trainer
Pytorch Lightning trainer
_ : pl.LightningModule
Not used
"""
data_name_to_index = trainer.datamodule.ds_test.name_to_index

self._compare_variables(data_name_to_index)

def _compare_variables(self, data_name_to_index: dict[str, int]) -> None:
"""Compare the order of the variables in the model from checkpoint and the data.

Parameters
----------
data_name_to_index : dict[str, int]
The dictionary mapping variable names to their indices in the data.

Raises
------
ValueError
If the variable order in the model and data is verifiably different.
"""
if self._model_name_to_index is None:
LOGGER.info("No variable order to compare. Skipping variable order check.")
return

if self._model_name_to_index == data_name_to_index:
LOGGER.info("The order of the variables in the model matches the order in the data.")
LOGGER.debug("%s, %s", self._model_name_to_index, data_name_to_index)
return

keys1 = set(self._model_name_to_index.keys())
keys2 = set(data_name_to_index.keys())

error_msg = ""

# Find keys unique to each dictionary
only_in_model = {key: self._model_name_to_index[key] for key in (keys1 - keys2)}
only_in_data = {key: data_name_to_index[key] for key in (keys2 - keys1)}

# Find common keys
common_keys = keys1 & keys2

# Compare values for common keys
different_values = {
k: (self._model_name_to_index[k], data_name_to_index[k])
for k in common_keys
if self._model_name_to_index[k] != data_name_to_index[k]
}

LOGGER.warning(
"The variables in the model do not match the variables in the data. "
"If you're fine-tuning or pre-training, you may have to adjust the "
"variable order and naming in your config.",
)
if only_in_model:
LOGGER.warning("Variables only in model: %s", only_in_model)
if only_in_data:
LOGGER.warning("Variables only in data: %s", only_in_data)
if set(only_in_model.values()) == set(only_in_data.values()):
# This checks if the order is the same, but the naming is different. This is not treated as an error.
LOGGER.warning(
"The variable naming is different, but the order appears to be the same. Continuing with training.",
)
else:
# If the renamed variables are not in the same index locations, raise an error.
error_msg += (
"The variable order in the model and data is different.\n"
"Please adjust the variable order in your config, you may need to "
"use the 'reorder' and 'rename' key in the dataloader config.\n"
"Refer to the Anemoi Datasets documentation for more information.\n"
)
if different_values:
# If the variables are named the same but in different order, raise an error.
error_msg += (
f"Detected a different sort order of the same variables: {different_values}.\n"
"Please adjust the variable order in your config, you may need to use the "
f"'reorder' key in the dataloader config. With:\n `reorder: {self._model_name_to_index}`\n"
)

if error_msg:
LOGGER.error(error_msg)
raise ValueError(error_msg)
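
The comparison in `_compare_variables` can be tried outside of PyTorch Lightning. What follows is a minimal sketch, not the code from this PR: `compare_variable_order` is a hypothetical free function that returns a status string instead of logging. Because this flattened diff does not preserve indentation, the sketch assumes that the index-set comparison applies whenever either side has names the other lacks. The variable names are illustrative only.

```python
def compare_variable_order(model: dict[str, int], data: dict[str, int]) -> str:
    """Return "match" or "renamed"; raise ValueError if the order differs.

    Hypothetical stand-in for CheckVariableOrder._compare_variables.
    """
    if model == data:
        return "match"

    model_names, data_names = set(model), set(data)
    only_in_model = {name: model[name] for name in model_names - data_names}
    only_in_data = {name: data[name] for name in data_names - model_names}
    # Names present on both sides but stored at different indices.
    moved = {
        name: (model[name], data[name])
        for name in model_names & data_names
        if model[name] != data[name]
    }

    errors = []
    # Assumption: unmatched names are tolerated only when they occupy
    # the same set of index positions on both sides (a pure rename).
    unmatched = only_in_model or only_in_data
    if unmatched and set(only_in_model.values()) != set(only_in_data.values()):
        errors.append("unmatched variable names sit at different indices")
    if moved:
        errors.append(f"same variables, different order: {sorted(moved)}")

    if errors:
        raise ValueError("; ".join(errors))
    return "renamed"


checkpoint = {"10u": 0, "10v": 1, "2t": 2}

print(compare_variable_order(checkpoint, {"10u": 0, "10v": 1, "2t": 2}))
print(compare_variable_order(checkpoint, {"u10": 0, "10v": 1, "2t": 2}))
try:
    compare_variable_order(checkpoint, {"10v": 0, "10u": 1, "2t": 2})
except ValueError as err:
    print(f"rejected: {err}")
```

The three calls cover the outcomes the callback distinguishes: identical mappings, a rename that keeps every index in place (warning only), and a reordering of the same names (error).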