This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

Update sanity checks for training data consistency (#120)
* fix: remove resolution check

* feat: first implementation of Callback to check variable order in pre-training

* feat: add variable order checks for pre-training and current training

* tests: implement tests for variable order

* docs: changelog

* tests: make the number of fixed callbacks a variable

* refactor: remove nested if as per review

* fix: remove resolution from config

* Fix linting issues

---------

Co-authored-by: Harrison Cook <[email protected]>
JesperDramsch and HCookie authored Dec 3, 2024
1 parent 460b604 commit bb30beb
Showing 7 changed files with 486 additions and 42 deletions.
39 changes: 12 additions & 27 deletions CHANGELOG.md
@@ -9,78 +9,60 @@ Please add your functional changes to the appropriate section in the PR.
Keep it human-readable, your future self will thank you!

## [Unreleased](https://github.com/ecmwf/anemoi-training/compare/0.3.1...HEAD)
### Fixed

### Added
- Added a check for the variable sorting on pre-trained/finetuned models [#120](https://github.com/ecmwf/anemoi-training/pull/120)

### Changed

### Removed
- Removed the resolution config entry [#120](https://github.com/ecmwf/anemoi-training/pull/120)

## [0.3.1 - AIFS v0.3 Compatibility](https://github.com/ecmwf/anemoi-training/compare/0.3.0...0.3.1) - 2024-11-28

### Changed
- Perform full shuffle of training dataset [#153](https://github.com/ecmwf/anemoi-training/pull/153)

### Fixed

- Update `n_pixel` used by datashader to better adapt across resolutions #152

- Fixed bug in power spectra plotting for the n320 resolution.

- Allow histogram and spectrum plot for one variable [#165](https://github.com/ecmwf/anemoi-training/pull/165)


### Added

- Introduce variable to configure (Cosine Annealing) optimizer warm up [#155](https://github.com/ecmwf/anemoi-training/pull/155)
- Add reader groups to reduce CPU memory usage and increase dataloader throughput [#76](https://github.com/ecmwf/anemoi-training/pull/76)
- Bump `anemoi-graphs` version to 0.4.1 [#159](https://github.com/ecmwf/anemoi-training/pull/159)

### Changed

## [0.3.0 - Loss & Callback Refactors](https://github.com/ecmwf/anemoi-training/compare/0.2.2...0.3.0) - 2024-11-14

### Changed

- Increase the default MlFlow HTTP max retries [#111](https://github.com/ecmwf/anemoi-training/pull/111)

### Fixed

- Rename loss_scaling to variable_loss_scaling [#138](https://github.com/ecmwf/anemoi-training/pull/138)

- Refactored callbacks. [#60](https://github.com/ecmwf/anemoi-training/pulls/60)

- Updated docs [#115](https://github.com/ecmwf/anemoi-training/pull/115)
- Fix enabling LearningRateMonitor [#119](https://github.com/ecmwf/anemoi-training/pull/119)

- Refactored rollout [#87](https://github.com/ecmwf/anemoi-training/pulls/87)

- Enable longer validation rollout than training

- Expand iterables in logging [#91](https://github.com/ecmwf/anemoi-training/pull/91)

- Save entire config in mlflow


### Added

- Included more loss functions and allowed configuration [#70](https://github.com/ecmwf/anemoi-training/pull/70)

- Include option to use datashader and optimised asynchronous callbacks [#102](https://github.com/ecmwf/anemoi-training/pull/102)

- Fix that applies the metric_ranges in the post-processed variable space [#116](https://github.com/ecmwf/anemoi-training/pull/116)

- Allow updates to scalars [#137](https://github.com/ecmwf/anemoi-training/pulls/137)

- Add without subsetting in ScaleTensor

- Sub-hour datasets [#63](https://github.com/ecmwf/anemoi-training/pull/63)

- Add synchronisation workflow [#92](https://github.com/ecmwf/anemoi-training/pull/92)

- Feat: Anemoi Profiler compatible with mlflow and using Pytorch (Kineto) Profiler for memory report [#38](https://github.com/ecmwf/anemoi-training/pull/38/)

- Feat: Save a gif for longer rollouts in validation [#65](https://github.com/ecmwf/anemoi-training/pull/65)

- New limited area config file added, limited_area.yaml. [#134](https://github.com/ecmwf/anemoi-training/pull/134/)

- New stretched grid config added, stretched_grid.yaml [#133](https://github.com/ecmwf/anemoi-training/pull/133)
- Functionality to change the weight attribute of nodes in the graph at the start of training without re-generating the graph. [#136](https://github.com/ecmwf/anemoi-training/pull/136)

- Custom System monitor for Nvidia and AMD GPUs [#147](https://github.com/ecmwf/anemoi-training/pull/147)


@@ -89,6 +71,9 @@ Keep it human-readable, your future self will thank you!
- Renamed frequency keys in callbacks configuration. [#118](https://github.com/ecmwf/anemoi-training/pull/118)
- Modified training configuration to support max_steps and tied lr iterations to max_steps by default [#67](https://github.com/ecmwf/anemoi-training/pull/67)
- Merged node & edge trainable feature callbacks into one. [#135](https://github.com/ecmwf/anemoi-training/pull/135)
- Increase the default MlFlow HTTP max retries [#111](https://github.com/ecmwf/anemoi-training/pull/111)

### Removed

## [0.2.2 - Maintenance: pin python <3.13](https://github.com/ecmwf/anemoi-training/compare/0.2.1...0.2.2) - 2024-10-28

3 changes: 1 addition & 2 deletions src/anemoi/training/config/data/zarr.yaml
@@ -1,5 +1,4 @@
format: zarr
resolution: o96
# Time frequency requested from dataset
frequency: 6h
# Time step of model (must be multiple of frequency)
@@ -82,5 +81,5 @@ processors:
# _convert_: all
# config: ${data.remapper}

# Values set in the code
# Values set in the code
num_features: null # number of features in the forecast state
9 changes: 1 addition & 8 deletions src/anemoi/training/data/datamodule.py
@@ -60,11 +60,6 @@ def __init__(self, config: DictConfig) -> None:
if not self.config.dataloader.get("pin_memory", True):
LOGGER.info("Data loader memory pinning disabled.")

def _check_resolution(self, resolution: str) -> None:
assert (
self.config.data.resolution.lower() == resolution.lower()
), f"Network resolution {self.config.data.resolution=} does not match dataset resolution {resolution=}"

@cached_property
def statistics(self) -> dict:
return self.ds_train.statistics
@@ -153,16 +148,14 @@ def _get_dataset(
label: str = "generic",
) -> NativeGridDataset:
r = max(rollout, self.rollout)
data = NativeGridDataset(
return NativeGridDataset(
data_reader=data_reader,
rollout=r,
multistep=self.config.training.multistep_input,
timeincrement=self.timeincrement,
shuffle=shuffle,
label=label,
)
self._check_resolution(data.resolution)
return data

def _get_dataloader(self, ds: NativeGridDataset, stage: str) -> DataLoader:
assert stage in {"training", "validation", "test"}
9 changes: 8 additions & 1 deletion src/anemoi/training/diagnostics/callbacks/__init__.py
@@ -23,6 +23,7 @@
from anemoi.training.diagnostics.callbacks.optimiser import LearningRateMonitor
from anemoi.training.diagnostics.callbacks.optimiser import StochasticWeightAveraging
from anemoi.training.diagnostics.callbacks.provenance import ParentUUIDCallback
from anemoi.training.diagnostics.callbacks.sanity import CheckVariableOrder

if TYPE_CHECKING:
from pytorch_lightning.callbacks import Callback
@@ -196,7 +197,13 @@ def get_callbacks(config: DictConfig) -> list[Callback]:
trainer_callbacks.extend(_get_config_enabled_callbacks(config))

# Parent UUID callback
trainer_callbacks.append(ParentUUIDCallback(config))
# Check variable order callback
trainer_callbacks.extend(
(
ParentUUIDCallback(config),
CheckVariableOrder(),
),
)

return trainer_callbacks

169 changes: 169 additions & 0 deletions src/anemoi/training/diagnostics/callbacks/sanity.py
@@ -0,0 +1,169 @@
# (C) Copyright 2024 Anemoi contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import logging

import pytorch_lightning as pl

LOGGER = logging.getLogger(__name__)


class CheckVariableOrder(pl.callbacks.Callback):
"""Check the order of the variables in a pre-trained / fine-tuning model."""

def __init__(self) -> None:
super().__init__()
self._model_name_to_index = None

def on_load_checkpoint(self, trainer: pl.Trainer, _: pl.LightningModule, checkpoint: dict) -> None:
"""Cache the model mapping from the checkpoint.

        Parameters
----------
trainer : pl.Trainer
Pytorch Lightning trainer
_ : pl.LightningModule
Not used
checkpoint : dict
Pytorch Lightning checkpoint
"""
self._model_name_to_index = checkpoint["hyper_parameters"]["data_indices"].name_to_index
data_name_to_index = trainer.datamodule.data_indices.name_to_index

self._compare_variables(data_name_to_index)

def on_sanity_check_start(self, trainer: pl.Trainer, _: pl.LightningModule) -> None:
"""Cache the model mapping from the datamodule if not loaded from checkpoint.

        Parameters
----------
trainer : pl.Trainer
Pytorch Lightning trainer
_ : pl.LightningModule
Not used
"""
if self._model_name_to_index is None:
self._model_name_to_index = trainer.datamodule.data_indices.name_to_index

def on_train_epoch_start(self, trainer: pl.Trainer, _: pl.LightningModule) -> None:
"""Check the order of the variables in the model from checkpoint and the training data.

        Parameters
----------
trainer : pl.Trainer
Pytorch Lightning trainer
_ : pl.LightningModule
Not used
"""
data_name_to_index = trainer.datamodule.ds_train.name_to_index

self._compare_variables(data_name_to_index)

def on_validation_epoch_start(self, trainer: pl.Trainer, _: pl.LightningModule) -> None:
"""Check the order of the variables in the model from checkpoint and the validation data.

        Parameters
----------
trainer : pl.Trainer
Pytorch Lightning trainer
_ : pl.LightningModule
Not used
"""
data_name_to_index = trainer.datamodule.ds_valid.name_to_index

self._compare_variables(data_name_to_index)

def on_test_epoch_start(self, trainer: pl.Trainer, _: pl.LightningModule) -> None:
"""Check the order of the variables in the model from checkpoint and the test data.

        Parameters
----------
trainer : pl.Trainer
Pytorch Lightning trainer
_ : pl.LightningModule
Not used
"""
data_name_to_index = trainer.datamodule.ds_test.name_to_index

self._compare_variables(data_name_to_index)

def _compare_variables(self, data_name_to_index: dict[str, int]) -> None:
"""Compare the order of the variables in the model from checkpoint and the data.

        Parameters
----------
data_name_to_index : dict[str, int]
The dictionary mapping variable names to their indices in the data.

        Raises
------
ValueError
If the variable order in the model and data is verifiably different.
"""
if self._model_name_to_index is None:
LOGGER.info("No variable order to compare. Skipping variable order check.")
return

if self._model_name_to_index == data_name_to_index:
LOGGER.info("The order of the variables in the model matches the order in the data.")
LOGGER.debug("%s, %s", self._model_name_to_index, data_name_to_index)
return

keys1 = set(self._model_name_to_index.keys())
keys2 = set(data_name_to_index.keys())

error_msg = ""

# Find keys unique to each dictionary
only_in_model = {key: self._model_name_to_index[key] for key in (keys1 - keys2)}
only_in_data = {key: data_name_to_index[key] for key in (keys2 - keys1)}

# Find common keys
common_keys = keys1 & keys2

# Compare values for common keys
different_values = {
k: (self._model_name_to_index[k], data_name_to_index[k])
for k in common_keys
if self._model_name_to_index[k] != data_name_to_index[k]
}

LOGGER.warning(
"The variables in the model do not match the variables in the data. "
"If you're fine-tuning or pre-training, you may have to adjust the "
"variable order and naming in your config.",
)
if only_in_model:
LOGGER.warning("Variables only in model: %s", only_in_model)
if only_in_data:
LOGGER.warning("Variables only in data: %s", only_in_data)
if set(only_in_model.values()) == set(only_in_data.values()):
            # This checks if the order is the same, but the naming is different. This is not treated as an error.
LOGGER.warning(
"The variable naming is different, but the order appears to be the same. Continuing with training.",
)
else:
# If the renamed variables are not in the same index locations, raise an error.
error_msg += (
"The variable order in the model and data is different.\n"
"Please adjust the variable order in your config, you may need to "
"use the 'reorder' and 'rename' key in the dataloader config.\n"
"Refer to the Anemoi Datasets documentation for more information.\n"
)
if different_values:
# If the variables are named the same but in different order, raise an error.
error_msg += (
f"Detected a different sort order of the same variables: {different_values}.\n"
"Please adjust the variable order in your config, you may need to use the "
f"'reorder' key in the dataloader config. With:\n `reorder: {self._model_name_to_index}`\n"
)

if error_msg:
LOGGER.error(error_msg)
raise ValueError(error_msg)
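
For illustration, the comparison logic above can be exercised directly, outside a Trainer run. The sketch below is not part of the commit: it assumes an environment where anemoi-training (and therefore pytorch-lightning) is installed, and it uses hypothetical variable names ("2t", "10u", "10v") in hand-built mappings to show the three outcomes of `_compare_variables`: a match, a rename with identical ordering (warning only), and a genuine reorder (raises ValueError).

import logging

from anemoi.training.diagnostics.callbacks.sanity import CheckVariableOrder

logging.basicConfig(level=logging.INFO)

callback = CheckVariableOrder()

# Pretend this mapping was cached from a checkpoint (hypothetical variables).
callback._model_name_to_index = {"2t": 0, "10u": 1, "10v": 2}

# 1. Identical mapping: logged as a match, nothing raised.
callback._compare_variables({"2t": 0, "10u": 1, "10v": 2})

# 2. Same positions, different name ("t2m" instead of "2t"): warning only, training continues.
callback._compare_variables({"t2m": 0, "10u": 1, "10v": 2})

# 3. Same names, different order: raises ValueError and suggests a 'reorder' mapping.
try:
    callback._compare_variables({"2t": 1, "10u": 0, "10v": 2})
except ValueError as err:
    print(err)

In a real training run the same checks are triggered automatically by the checkpoint-load and epoch-start hooks shown above.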