diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index a2c619f5..74bdac0a 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1,6 +1,6 @@ # CODEOWNERS file # Protect workflow files -/.github/ @theissenhelen @jesperdramsch @gmertes -/.pre-commit-config.yaml @theissenhelen @jesperdramsch @gmertes -/pyproject.toml @theissenhelen @jesperdramsch @gmertes +/.github/ @theissenhelen @jesperdramsch @gmertes @b8raoult @floriankrb @anaprietonem @HCookie @JPXKQX @mchantry +/.pre-commit-config.yaml @theissenhelen @jesperdramsch @gmertes @b8raoult @floriankrb @anaprietonem @HCookie @JPXKQX @mchantry +/pyproject.toml @theissenhelen @jesperdramsch @gmertes @b8raoult @floriankrb @anaprietonem @HCookie @JPXKQX @mchantry diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml index db316803..1db47a1b 100644 --- a/.github/workflows/python-publish.yml +++ b/.github/workflows/python-publish.yml @@ -15,7 +15,12 @@ jobs: skip-hooks: "no-commit-to-branch" checks: + strategy: + matrix: + python-version: ["3.9", "3.10", "3.11", "3.12"] uses: ecmwf-actions/reusable-workflows/.github/workflows/qa-pytest-pyproject.yml@v2 + with: + python-version: ${{ matrix.python-version }} deploy: needs: [checks, quality] diff --git a/.github/workflows/python-pull-request.yml b/.github/workflows/python-pull-request.yml index 3488f55c..cef24795 100644 --- a/.github/workflows/python-pull-request.yml +++ b/.github/workflows/python-pull-request.yml @@ -16,4 +16,9 @@ jobs: skip-hooks: "no-commit-to-branch" checks: + strategy: + matrix: + python-version: ["3.9", "3.10", "3.11", "3.12"] uses: ecmwf-actions/reusable-workflows/.github/workflows/qa-pytest-pyproject.yml@v2 + with: + python-version: ${{ matrix.python-version }} diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f3c39623..e01d6a37 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,12 +5,12 @@ repos: - id: clear-notebooks-output name: clear-notebooks-output files: tools/.*\.ipynb$ - stages: [commit] + stages: [pre-commit] language: python entry: jupyter nbconvert --ClearOutputPreprocessor.enabled=True --inplace additional_dependencies: [jupyter] - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.6.0 + rev: v5.0.0 hooks: - id: check-yaml # Check YAML files for syntax errors only args: [--unsafe, --allow-multiple-documents] @@ -40,7 +40,7 @@ repos: - --force-single-line-imports - --profile black - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.6.4 + rev: v0.6.9 hooks: - id: ruff # Next line if for documenation cod snippets @@ -66,11 +66,11 @@ repos: - id: docconvert args: ["numpy"] - repo: https://github.com/tox-dev/pyproject-fmt - rev: "2.2.3" + rev: "2.2.4" hooks: - id: pyproject-fmt - repo: https://github.com/jshwi/docsig # Check docstrings against function sig - rev: v0.60.1 + rev: v0.64.0 hooks: - id: docsig args: diff --git a/CHANGELOG.md b/CHANGELOG.md index d6e664c2..4bd6b142 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,11 +8,44 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 Please add your functional changes to the appropriate section in the PR. Keep it human-readable, your future self will thank you! 
-## [Unreleased](https://github.com/ecmwf/anemoi-training/compare/0.1.0...HEAD) +## [Unreleased](https://github.com/ecmwf/anemoi-training/compare/0.2.2...HEAD) + +## [0.2.2 - Maintenance: pin python <3.13](https://github.com/ecmwf/anemoi-training/compare/0.2.1...0.2.2) - 2024-10-28 + +### Changed + +- Lock python version <3.13 [#107](https://github.com/ecmwf/anemoi-training/pull/107) + +## [0.2.1 - Bugfix: resuming mlflow runs](https://github.com/ecmwf/anemoi-training/compare/0.2.0...0.2.1) - 2024-10-24 + +### Added + +- Mlflow-sync to include new tag for server to server syncing [#83](https://github.com/ecmwf/anemoi-training/pull/83) +- Mlflow-sync to include functionality to resume and fork server2server runs [#83](https://github.com/ecmwf/anemoi-training/pull/83) +- Rollout training for Limited Area Models. [#79](https://github.com/ecmwf/anemoi-training/pulls/79) +- Feature: New `Boolean1DMask` class. Enables rollout training for limited area models. [#79](https://github.com/ecmwf/anemoi-training/pulls/79) + +### Fixed + +- Mlflow-sync to handle creation of new experiments in the remote server [#83](https://github.com/ecmwf/anemoi-training/pull/83) +- Fix for multi-gpu when using mlflow due to refactoring of _get_mlflow_run_params function [#99](https://github.com/ecmwf/anemoi-training/pull/99) +- ci: fix pyshtools install error [#100](https://github.com/ecmwf/anemoi-training/pull/100) +- Fix `__version__` import in init + +### Changed + +- Update copyright notice + +## [0.2.0 - Feature release](https://github.com/ecmwf/anemoi-training/compare/0.1.0...0.2.0) - 2024-10-16 + +- Make pin_memory of the Dataloader configurable (#64) ### Added + +- Add anemoi-transform link to documentation - Codeowners file (#56) - Changelog merge strategy (#56) +- Contributors file (#106) #### Miscellaneous @@ -26,7 +59,9 @@ Keep it human-readable, your future self will thank you! - Enforce same binning for histograms comparing true data to predicted data - Fix: Inference checkpoints are now saved according the frequency settings defined in the config [#37](https://github.com/ecmwf/anemoi-training/pull/37) - Feature: Add configurable models [#50](https://github.com/ecmwf/anemoi-training/pulls/50) +- Feature: Authentication support for mlflow sync - [#51](https://github.com/ecmwf/anemoi-training/pull/51) - Feature: Support training for datasets with missing time steps [#48](https://github.com/ecmwf/anemoi-training/pulls/48) +- Feature: `AnemoiMlflowClient`, an mlflow client with authentication support [#86](https://github.com/ecmwf/anemoi-training/pull/86) - Long Rollout Plots ### Fixed @@ -35,10 +70,14 @@ Keep it human-readable, your future self will thank you! 
- Bugfixes for CI (#56) - Fix `mlflow` subcommand on python 3.9 [#62](https://github.com/ecmwf/anemoi-training/pull/62) - Show correct subcommand in MLFlow - Addresses [#39](https://github.com/ecmwf/anemoi-training/issues/39) in [#61](https://github.com/ecmwf/anemoi-training/pull/61) +- Fix interactive multi-GPU training [#82](https://github.com/ecmwf/anemoi-training/pull/82) +- Allow 500 characters in mlflow logging [#88](https://github.com/ecmwf/anemoi-training/pull/88) ### Changed - Updated configuration examples in documentation and corrected links - [#46](https://github.com/ecmwf/anemoi-training/pull/46) +- Remove credential prompt from mlflow login, replace with seed refresh token via web - [#78](https://github.com/ecmwf/anemoi-training/pull/78) +- Update CODEOWNERS ## [0.1.0 - Anemoi training - First release](https://github.com/ecmwf/anemoi-training/releases/tag/0.1.0) - 2024-08-16 @@ -52,6 +91,7 @@ Keep it human-readable, your future self will thank you! - Subcommand for checkpoint handling #### Functionality + - Searchpaths for Hydra configs, to enable configs in CWD, `ANEMOI_CONFIG_PATH` env, and `.config/anemoi/training` in addition to package defaults - MlFlow token authentication - Configurable pressure level scaling diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md new file mode 100644 index 00000000..72541b37 --- /dev/null +++ b/CONTRIBUTORS.md @@ -0,0 +1,13 @@ +## How to Contribute + +Please see the [read the docs](https://anemoi-training.readthedocs.io/en/latest/dev/contributing.html). + + +## Contributors + +Thank you to all the wonderful people who have contributed to Anemoi. Contributions can come in many forms, including code, documentation, bug reports, feature suggestions, design, and more. A list of code-based contributors can be found [here](https://github.com/ecmwf/anemoi-training/graphs/contributors). + + +## Contributing Organisations + +Significant contributions have been made by the following organisations: [DWD](https://www.dwd.de/), [MET Norway](https://www.met.no/), [MeteoSwiss](https://www.meteoswiss.admin.ch/), [RMI](https://www.meteo.be/) & [ECMWF](https://www.ecmwf.int/) diff --git a/README.md b/README.md index 5ee7892f..e51c7697 100644 --- a/README.md +++ b/README.md @@ -22,7 +22,7 @@ $ pip install anemoi-training ## License ``` -Copyright 2022, European Centre for Medium Range Weather Forecasts. +Copyright 2024, Anemoi contributors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/docs/conf.py b/docs/conf.py index 294ed98f..12e25dd5 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -101,6 +101,10 @@ "https://anemoi-registry.readthedocs.io/en/latest/", ("../../anemoi-registry/docs/_build/html/objects.inv", None), ), + "anemoi-transform": ( + "https://anemoi-transform.readthedocs.io/en/latest/", + ("../../anemoi-transform/docs/_build/html/objects.inv", None), + ), } # -- Options for HTML output ------------------------------------------------- diff --git a/docs/index.rst b/docs/index.rst index 9ca3bbe5..bfbd2dbf 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -67,6 +67,7 @@ This package provides the *Anemoi* training functionality. 
***************** - :ref:`anemoi-utils ` +- :ref:`anemoi-transform ` - :ref:`anemoi-datasets ` - :ref:`anemoi-models ` - :ref:`anemoi-graphs ` diff --git a/pyproject.toml b/pyproject.toml index 2baeea22..9ffa2db1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,7 @@ authors = [ { name = "European Centre for Medium-Range Weather Forecasts (ECMWF)", email = "software.support@ecmwf.int" }, ] -requires-python = ">=3.9" +requires-python = ">=3.9,<3.13" # Unable to use 3.13 until pyshtools updates classifiers = [ "Development Status :: 4 - Beta", @@ -83,6 +83,10 @@ urls.Documentation = "https://anemoi-training.readthedocs.io/" urls.Homepage = "https://github.com/ecmwf/anemoi-training/" urls.Issues = "https://github.com/ecmwf/anemoi-training/issues" urls.Repository = "https://github.com/ecmwf/anemoi-training/" +# command for interactive DDP (not supposed to be used directly) +# the dot is intentional, so it doesn't trigger autocomplete +scripts.".anemoi-training-train" = "anemoi.training.commands.train:main" + # Add subcommand in the `commands` directory scripts.anemoi-training = "anemoi.training.__main__:main" diff --git a/src/anemoi/training/__init__.py b/src/anemoi/training/__init__.py index d9a51e0b..9733be26 100644 --- a/src/anemoi/training/__init__.py +++ b/src/anemoi/training/__init__.py @@ -6,4 +6,10 @@ # nor does it submit to any jurisdiction. -from ._version import __version__ # noqa: F401 +try: + # NOTE: the `_version.py` file must not be present in the git repository + # as it is generated by setuptools at install time + from ._version import __version__ # type: ignore +except ImportError: # pragma: no cover + # Local copy or not installed with setuptools + __version__ = "999" diff --git a/src/anemoi/training/commands/checkpoint.py b/src/anemoi/training/commands/checkpoint.py index af1aa296..82c95e9e 100644 --- a/src/anemoi/training/commands/checkpoint.py +++ b/src/anemoi/training/commands/checkpoint.py @@ -1,10 +1,13 @@ -# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# (C) Copyright 2024 Anemoi contributors. +# # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. + import argparse import logging diff --git a/src/anemoi/training/commands/config.py b/src/anemoi/training/commands/config.py index e79602d6..221d76dd 100644 --- a/src/anemoi/training/commands/config.py +++ b/src/anemoi/training/commands/config.py @@ -1,10 +1,13 @@ -# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# (C) Copyright 2024 Anemoi contributors. +# # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. 
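As a quick illustration of the `__version__` fallback added above: with a regular install, setuptools generates `_version.py` and the real version is reported; from a bare source checkout the file is absent and the sentinel value is used. A minimal check (sketch only):

```python
# Sketch: exercising the fallback added in src/anemoi/training/__init__.py.
# With a regular install, _version.py exists and the real version is reported;
# from a plain source checkout it does not, so "999" is returned instead.
import anemoi.training

print(anemoi.training.__version__)  # e.g. "0.2.2" when installed, "999" otherwise
```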
+ from __future__ import annotations import importlib.resources as pkg_resources diff --git a/src/anemoi/training/commands/mlflow.py b/src/anemoi/training/commands/mlflow.py index 4b7e9e86..545b8a10 100644 --- a/src/anemoi/training/commands/mlflow.py +++ b/src/anemoi/training/commands/mlflow.py @@ -1,10 +1,13 @@ -# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# (C) Copyright 2024 Anemoi contributors. +# # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. + import argparse from anemoi.training.commands import Command @@ -45,6 +48,7 @@ def add_arguments(command_parser: argparse.ArgumentParser) -> None: "--source", "-s", help="The MLflow logs source directory.", + metavar="DIR", required=True, default=argparse.SUPPRESS, ) @@ -52,16 +56,31 @@ def add_arguments(command_parser: argparse.ArgumentParser) -> None: "--destination", "-d", help="The destination MLflow tracking URI.", + metavar="URI", + required=True, + default=argparse.SUPPRESS, + ) + sync.add_argument( + "--run-id", + "-r", + help="The run ID to sync.", + metavar="ID", required=True, default=argparse.SUPPRESS, ) - sync.add_argument("--run-id", "-r", help="The run ID to sync.", required=True, default=argparse.SUPPRESS) sync.add_argument( "--experiment-name", "-e", help="The experiment name to sync to.", + metavar="NAME", default="anemoi-debug", ) + sync.add_argument( + "--authentication", + "-a", + action="store_true", + help="The destination server requires authentication.", + ) sync.add_argument( "--export-deleted-runs", "-x", @@ -88,8 +107,18 @@ def run(args: argparse.Namespace) -> None: return if args.subcommand == "sync": + from anemoi.training.diagnostics.mlflow.utils import health_check from anemoi.training.utils.mlflow_sync import MlFlowSync + if args.authentication: + from anemoi.training.diagnostics.mlflow.auth import TokenAuth + + auth = TokenAuth(url=args.destination) + auth.login() + auth.authenticate() + + health_check(args.destination) + log_level = "DEBUG" if args.verbose else "INFO" MlFlowSync( diff --git a/src/anemoi/training/commands/train.py b/src/anemoi/training/commands/train.py index 46052b90..44ce186f 100644 --- a/src/anemoi/training/commands/train.py +++ b/src/anemoi/training/commands/train.py @@ -1,15 +1,19 @@ -# (C) Copyright 2024 ECMWF. +# (C) Copyright 2024 Anemoi contributors. # # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. 
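For context on the new `--authentication` flag of `anemoi-training mlflow sync`: when passed, the sync subcommand logs in, refreshes the access token and checks the destination server before any run is exported. A minimal programmatic sketch of the same flow, assuming a placeholder tracking URI:

```python
# Sketch of the authentication + health-check steps the sync subcommand now performs.
from anemoi.training.diagnostics.mlflow.auth import TokenAuth
from anemoi.training.diagnostics.mlflow.utils import health_check

destination = "https://mlflow.example.int"  # placeholder destination tracking URI

auth = TokenAuth(url=destination)
auth.login()         # reuses a cached refresh token, or asks for a seed token
auth.authenticate()  # exchanges it for an access token before talking to the server

health_check(destination)  # raises ConnectionError early if the server is unreachable
```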
+ from __future__ import annotations import logging +import os import sys +from pathlib import Path from typing import TYPE_CHECKING from anemoi.training.commands import Command @@ -30,7 +34,8 @@ def add_arguments(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: return parser def run(self, args: argparse.Namespace, unknown_args: list[str] | None = None) -> None: - + # This will be picked up by the logger + os.environ["ANEMOI_TRAINING_CMD"] = f"{sys.argv[0]} {args.command}" # Merge the known subcommands with a non-whitespace character for hydra new_sysargv = self._merge_sysargv(args) @@ -40,15 +45,15 @@ def run(self, args: argparse.Namespace, unknown_args: list[str] | None = None) - else: sys.argv = [new_sysargv] - # Import and run the training command LOGGER.info("Running anemoi training command with overrides: %s", sys.argv[1:]) - from anemoi.training.train.train import main as anemoi_train - - anemoi_train() + main() def _merge_sysargv(self, args: argparse.Namespace) -> str: """Merge the sys.argv with the known subcommands to pass to hydra. + This is done for interactive DDP, which will spawn the rank > 0 processes from sys.argv[0] + and for hydra, which ingests sys.argv[1:] + Parameters ---------- args : argparse.Namespace @@ -59,10 +64,26 @@ def _merge_sysargv(self, args: argparse.Namespace) -> str: str Modified sys.argv as string """ - modified_sysargv = f"{sys.argv[0]} {args.command}" + argv = Path(sys.argv[0]) + + # this will turn "/env/bin/anemoi-training train" into "/env/bin/.anemoi-training-train" + # the dot at the beginning is intentional to not interfere with autocomplete + modified_sysargv = argv.with_name(f".{argv.name}-{args.command}") + if hasattr(args, "subcommand"): - modified_sysargv += f" {args.subcommand}" - return modified_sysargv + modified_sysargv += f"-{args.subcommand}" + return str(modified_sysargv) + + +def main() -> None: + # Use the environment variable to check if main is being called from the subcommand, not from the ddp entrypoint + if not os.environ.get("ANEMOI_TRAINING_CMD"): + error = "This entrypoint should not be called directly. Use `anemoi-training train` instead." + raise RuntimeError(error) + + from anemoi.training.train.train import main as anemoi_train + + anemoi_train() command = Train diff --git a/src/anemoi/training/config/dataloader/native_grid.yaml b/src/anemoi/training/config/dataloader/native_grid.yaml index 35ed7e35..e6d50801 100644 --- a/src/anemoi/training/config/dataloader/native_grid.yaml +++ b/src/anemoi/training/config/dataloader/native_grid.yaml @@ -1,4 +1,5 @@ prefetch_factor: 2 +pin_memory: True num_workers: training: 8 diff --git a/src/anemoi/training/data/datamodule.py b/src/anemoi/training/data/datamodule.py index 1e119892..f64a3091 100644 --- a/src/anemoi/training/data/datamodule.py +++ b/src/anemoi/training/data/datamodule.py @@ -1,10 +1,13 @@ -# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# (C) Copyright 2024 Anemoi contributors. +# # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. 
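To make the entrypoint rewrite concrete: `_merge_sysargv` now replaces `sys.argv[0]` with the hidden `.anemoi-training-train` console script declared in `pyproject.toml`, so interactive DDP workers spawned from `sys.argv[0]` re-enter training directly while Hydra still ingests `sys.argv[1:]`. A small sketch of the path manipulation in isolation (the install path is hypothetical):

```python
from pathlib import Path

# What the rewrite in _merge_sysargv boils down to, on a hypothetical install path.
argv0 = Path("/env/bin/anemoi-training")
command = "train"

hidden_entrypoint = argv0.with_name(f".{argv0.name}-{command}")
print(hidden_entrypoint)  # /env/bin/.anemoi-training-train
```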
+ import logging import os from functools import cached_property @@ -96,6 +99,9 @@ def __init__(self, config: DictConfig) -> None: ) self.config.dataloader.training.end = self.config.dataloader.validation.start - 1 + if not self.config.dataloader.get("pin_memory", True): + LOGGER.info("Data loader memory pinning disabled.") + def _check_resolution(self, resolution: str) -> None: assert ( self.config.data.resolution.lower() == resolution.lower() @@ -185,7 +191,7 @@ def _get_dataloader(self, ds: NativeGridDataset, stage: str) -> DataLoader: num_workers=self.config.dataloader.num_workers[stage], # use of pinned memory can speed up CPU-to-GPU data transfers # see https://pytorch.org/docs/stable/notes/cuda.html#cuda-memory-pinning - pin_memory=True, + pin_memory=self.config.dataloader.get("pin_memory", True), # worker initializer worker_init_fn=worker_init_func, # prefetch batches diff --git a/src/anemoi/training/data/dataset.py b/src/anemoi/training/data/dataset.py index e2aa12bd..9e368f9c 100644 --- a/src/anemoi/training/data/dataset.py +++ b/src/anemoi/training/data/dataset.py @@ -1,9 +1,12 @@ -# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# (C) Copyright 2024 Anemoi contributors. +# # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. + from __future__ import annotations import logging diff --git a/src/anemoi/training/data/scaling.py b/src/anemoi/training/data/scaling.py index 74ba9c23..83419a88 100644 --- a/src/anemoi/training/data/scaling.py +++ b/src/anemoi/training/data/scaling.py @@ -1,10 +1,13 @@ -# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# (C) Copyright 2024 Anemoi contributors. +# # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. 
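The datamodule now reads the new `dataloader.pin_memory` option (default `True`, see `native_grid.yaml` above) and forwards it to PyTorch's `DataLoader`. A stripped-down sketch of the same pattern, with a plain dict standing in for the Hydra config:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in for config.dataloader; in anemoi-training this comes from the dataloader config group.
dataloader_config = {"pin_memory": True}

dataset = TensorDataset(torch.zeros(16, 4))
loader = DataLoader(
    dataset,
    batch_size=4,
    num_workers=0,  # kept at 0 so the sketch runs anywhere
    # pinned memory can speed up CPU-to-GPU transfers; it is now switchable from the config
    pin_memory=dataloader_config.get("pin_memory", True),
)
```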
+ import logging from abc import ABC from abc import abstractmethod diff --git a/src/anemoi/training/diagnostics/callbacks/__init__.py b/src/anemoi/training/diagnostics/callbacks/__init__.py index f2195b5f..cf085eab 100644 --- a/src/anemoi/training/diagnostics/callbacks/__init__.py +++ b/src/anemoi/training/diagnostics/callbacks/__init__.py @@ -139,6 +139,13 @@ def teardown(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: st if self._executor is not None: self._executor.shutdown(wait=True) + def apply_output_mask(self, pl_module: pl.LightningModule, data: torch.Tensor) -> torch.Tensor: + if hasattr(pl_module, "output_mask") and pl_module.output_mask is not None: + # Fill with NaNs values where the mask is False + data[:, :, ~pl_module.output_mask, :] = np.nan + + return data + @abstractmethod @rank_zero_only def _plot( @@ -682,12 +689,16 @@ def _plot( ..., pl_module.data_indices.internal_data.output.full, ].cpu() - data = self.post_processors(input_tensor).numpy() + data = self.post_processors(input_tensor) output_tensor = self.post_processors( torch.cat(tuple(x[self.sample_idx : self.sample_idx + 1, ...].cpu() for x in outputs[1])), in_place=False, - ).numpy() + ) + + output_tensor = pl_module.output_mask.apply(output_tensor, dim=2, fill_value=np.nan).numpy() + data[1:, ...] = pl_module.output_mask.apply(data[1:, ...], dim=2, fill_value=np.nan) + data = data.numpy() for rollout_step in range(pl_module.rollout): fig = plot_predicted_multilevel_flat_sample( @@ -776,11 +787,15 @@ def _plot( ..., pl_module.data_indices.internal_data.output.full, ].cpu() - data = self.post_processors(input_tensor).numpy() + data = self.post_processors(input_tensor) output_tensor = self.post_processors( torch.cat(tuple(x[self.sample_idx : self.sample_idx + 1, ...].cpu() for x in outputs[1])), in_place=False, - ).numpy() + ) + + output_tensor = pl_module.output_mask.apply(output_tensor, dim=2, fill_value=np.nan).numpy() + data[1:, ...] = pl_module.output_mask.apply(data[1:, ...], dim=2, fill_value=np.nan) + data = data.numpy() for rollout_step in range(pl_module.rollout): if self.config.diagnostics.plot.parameters_histogram is not None: diff --git a/src/anemoi/training/diagnostics/logger.py b/src/anemoi/training/diagnostics/logger.py index 599dad36..4e4a35c1 100644 --- a/src/anemoi/training/diagnostics/logger.py +++ b/src/anemoi/training/diagnostics/logger.py @@ -1,9 +1,12 @@ -# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# (C) Copyright 2024 Anemoi contributors. +# # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. 
+ from __future__ import annotations import logging @@ -25,7 +28,6 @@ def get_mlflow_logger(config: DictConfig) -> None: return None from anemoi.training.diagnostics.mlflow.logger import AnemoiMLflowLogger - from anemoi.training.diagnostics.mlflow.logger import get_mlflow_run_params resumed = config.training.run_id is not None forked = config.training.fork_run_id is not None @@ -39,7 +41,6 @@ def get_mlflow_logger(config: DictConfig) -> None: tracking_uri = save_dir # create directory if it does not exist Path(config.hardware.paths.logs.mlflow).mkdir(parents=True, exist_ok=True) - run_id, run_name, tags = get_mlflow_run_params(config, tracking_uri) log_hyperparams = True if resumed and not config.diagnostics.log.mlflow.on_resume_create_child: @@ -53,19 +54,22 @@ def get_mlflow_logger(config: DictConfig) -> None: ) log_hyperparams = False + LOGGER.info("AnemoiMLFlow logging to %s", tracking_uri) logger = AnemoiMLflowLogger( experiment_name=config.diagnostics.log.mlflow.experiment_name, + project_name=config.diagnostics.log.mlflow.project_name, tracking_uri=tracking_uri, save_dir=save_dir, - run_name=run_name, - run_id=run_id, + run_name=config.diagnostics.log.mlflow.run_name, + run_id=config.training.run_id, + fork_run_id=config.training.fork_run_id, log_model=config.diagnostics.log.mlflow.log_model, offline=offline, - tags=tags, resumed=resumed, forked=forked, log_hyperparams=log_hyperparams, authentication=config.diagnostics.log.mlflow.authentication, + on_resume_create_child=config.diagnostics.log.mlflow.on_resume_create_child, ) config_params = OmegaConf.to_container(config, resolve=True) diff --git a/src/anemoi/training/diagnostics/maps.py b/src/anemoi/training/diagnostics/maps.py index 265d814b..338a9059 100644 --- a/src/anemoi/training/diagnostics/maps.py +++ b/src/anemoi/training/diagnostics/maps.py @@ -1,10 +1,13 @@ -# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# (C) Copyright 2024 Anemoi contributors. +# # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. + import copy import json import logging diff --git a/src/anemoi/training/diagnostics/mlflow/auth.py b/src/anemoi/training/diagnostics/mlflow/auth.py index 9c34ffe0..144a967e 100644 --- a/src/anemoi/training/diagnostics/mlflow/auth.py +++ b/src/anemoi/training/diagnostics/mlflow/auth.py @@ -1,10 +1,13 @@ -# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# (C) Copyright 2024 Anemoi contributors. +# # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. + from __future__ import annotations import logging @@ -30,11 +33,7 @@ class TokenAuth: config_file = "mlflow-token.json" - def __init__( - self, - url: str, - enabled: bool = True, - ) -> None: + def __init__(self, url: str, enabled: bool = True, target_env_var: str = "MLFLOW_TRACKING_TOKEN") -> None: """Initialise the token authentication object. Parameters @@ -43,9 +42,13 @@ def __init__( URL of the authentication server. 
enabled : bool, optional Set this to False to turn off authentication, by default True + target_env_var : str, optional + The environment variable to store the access token in after authenticating, + by default `MLFLOW_TRACKING_TOKEN` """ self.url = url + self.target_env_var = target_env_var self._enabled = enabled config = self.load_config() @@ -91,14 +94,16 @@ def login(self, force_credentials: bool = False, **kwargs: dict) -> None: """Acquire a new refresh token and save it to disk. If an existing valid refresh token is already on disk it will be used. - If not, or the token has expired, the user will be prompted for credentials. + If not, or the token has expired, the user will be asked to obtain one from the API. + + Refresh token expiry time is set in the `REFRESH_EXPIRE_DAYS` constant (default 29 days). This function should be called once, interactively, right before starting a training run. Parameters ---------- force_credentials : bool, optional - Force a username/password prompt even if a refreh token is available, by default False. + Force a credential login even if a refreh token is available, by default False. kwargs : dict Additional keyword arguments. @@ -116,11 +121,12 @@ def login(self, force_credentials: bool = False, **kwargs: dict) -> None: new_refresh_token = self._token_request(ignore_exc=True).get("refresh_token") if not new_refresh_token: - self.log.info("📝 Please sign in with your credentials.") - username = input("Username: ") - password = getpass("Password: ") + self.log.info("📝 Please obtain a seed refresh token from %s/seed", self.url) + self.log.info("📝 and paste it here (you will not see the output, just press enter after pasting):") + self.refresh_token = getpass("Refresh Token: ") - new_refresh_token = self._token_request(username=username, password=password).get("refresh_token") + # perform a new refresh token request to check if the seed refresh token is valid + new_refresh_token = self._token_request().get("refresh_token") if not new_refresh_token: msg = "❌ Failed to log in. Please try again." @@ -133,9 +139,11 @@ def login(self, force_credentials: bool = False, **kwargs: dict) -> None: @enabled def authenticate(self, **kwargs: dict) -> None: - """Check the access token and refresh it if necessary. + """Check the access token and refresh it if necessary. A new refresh token will also be acquired upon refresh. + + This requires a valid refresh token to be available, obtained from the `login` method. - The access token is stored in memory and in the environment variable `MLFLOW_TRACKING_TOKEN`. + The access token is stored in memory and in an environment variable. If the access token is still valid, this function does nothing. This function should be called before every MLflow API request. 
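With the credential prompt removed, `login()` now expects a seed refresh token obtained from the `<url>/seed` endpoint, and `authenticate()` exports the short-lived access token into the environment variable named by `target_env_var`. A hedged sketch of the intended call sequence (the server URL is illustrative):

```python
import os

from anemoi.training.diagnostics.mlflow.auth import TokenAuth

# Illustrative auth server URL; in practice this is the MLflow tracking URI from the config.
auth = TokenAuth(url="https://mlflow.example.int", target_env_var="MLFLOW_TRACKING_TOKEN")

auth.login()         # prompts for a seed refresh token (pasted from <url>/seed) if none is cached
auth.authenticate()  # refreshes the access token and exports it

assert os.environ.get("MLFLOW_TRACKING_TOKEN")  # later consumed by the MLflow client
```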
@@ -161,7 +169,7 @@ def authenticate(self, **kwargs: dict) -> None: self.access_expires = time.time() + (response.get("expires_in") * 0.7) # bit of buffer self.refresh_token = response.get("refresh_token") - os.environ["MLFLOW_TRACKING_TOKEN"] = self.access_token + os.environ[self.target_env_var] = self.access_token @enabled def save(self, **kwargs: dict) -> None: @@ -183,16 +191,10 @@ def save(self, **kwargs: dict) -> None: def _token_request( self, - username: str | None = None, - password: str | None = None, ignore_exc: bool = False, ) -> dict: - if username is not None and password is not None: - path = "newtoken" - payload = {"username": username, "password": password} - else: - path = "refreshtoken" - payload = {"refresh_token": self.refresh_token} + path = "refreshtoken" + payload = {"refresh_token": self.refresh_token} try: response = self._request(path, payload) diff --git a/src/anemoi/training/diagnostics/mlflow/client.py b/src/anemoi/training/diagnostics/mlflow/client.py new file mode 100644 index 00000000..2d97cb4b --- /dev/null +++ b/src/anemoi/training/diagnostics/mlflow/client.py @@ -0,0 +1,59 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +from __future__ import annotations + +from typing import Any + +from mlflow import MlflowClient + +from anemoi.training.diagnostics.mlflow.auth import TokenAuth +from anemoi.training.diagnostics.mlflow.utils import health_check + + +class AnemoiMlflowClient(MlflowClient): + """Anemoi extension of the MLflow client with token authentication support.""" + + def __init__( + self, + tracking_uri: str, + *args, + authentication: bool = False, + check_health: bool = True, + **kwargs, + ) -> None: + """Behaves like a normal `mlflow.MlflowClient` but with token authentication injected on every call. + + Parameters + ---------- + tracking_uri : str + The URI of the MLflow tracking server. + authentication : bool, optional + Enable token authentication, by default False + check_health : bool, optional + Check the health of the MLflow server on init, by default True + *args : Any + Additional arguments to pass to the MLflow client. + **kwargs : Any + Additional keyword arguments to pass to the MLflow client. + + """ + self.anemoi_auth = TokenAuth(tracking_uri, enabled=authentication) + if check_health: + super().__getattribute__("anemoi_auth").authenticate() + health_check(tracking_uri) + super().__init__(tracking_uri, *args, **kwargs) + + def __getattribute__(self, name: str) -> Any: + """Intercept attribute access and inject authentication.""" + attr = super().__getattribute__(name) + if callable(attr) and name != "anemoi_auth": + super().__getattribute__("anemoi_auth").authenticate() + return attr diff --git a/src/anemoi/training/diagnostics/mlflow/logger.py b/src/anemoi/training/diagnostics/mlflow/logger.py index 7854c172..183e7a0d 100644 --- a/src/anemoi/training/diagnostics/mlflow/logger.py +++ b/src/anemoi/training/diagnostics/mlflow/logger.py @@ -1,10 +1,13 @@ -# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# (C) Copyright 2024 Anemoi contributors. 
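The new `AnemoiMlflowClient` (full file above) intercepts attribute access and re-authenticates before every call, so it can be used as a drop-in replacement for `mlflow.MlflowClient`. A brief usage sketch with a placeholder tracking URI:

```python
from anemoi.training.diagnostics.mlflow.client import AnemoiMlflowClient

client = AnemoiMlflowClient(
    "https://mlflow.example.int",  # placeholder tracking URI
    authentication=True,           # inject token authentication on every call
    check_health=False,            # skip the /health probe for this sketch
)

# Behaves like mlflow.MlflowClient; authentication happens transparently beforehand.
experiment = client.get_experiment_by_name("anemoi-debug")
```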
+# # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. + from __future__ import annotations import io @@ -20,90 +23,24 @@ from typing import Literal from weakref import WeakValueDictionary -import requests +from packaging.version import Version from pytorch_lightning.loggers.mlflow import MLFlowLogger from pytorch_lightning.loggers.mlflow import _convert_params from pytorch_lightning.loggers.mlflow import _flatten_dict from pytorch_lightning.utilities.rank_zero import rank_zero_only from anemoi.training.diagnostics.mlflow.auth import TokenAuth +from anemoi.training.diagnostics.mlflow.utils import health_check from anemoi.training.utils.jsonify import map_config_to_primitives if TYPE_CHECKING: from argparse import Namespace - from omegaconf import OmegaConf + import mlflow LOGGER = logging.getLogger(__name__) -def health_check(tracking_uri: str) -> None: - """Query the health endpoint of an MLflow server. - - If the server is not reachable, raise an error and remind the user that authentication may be required. - - Raises - ------ - ConnectionError - If the server is not reachable. - - """ - token = os.getenv("MLFLOW_TRACKING_TOKEN") - - headers = {"Authorization": f"Bearer {token}"} - response = requests.get(f"{tracking_uri}/health", headers=headers, timeout=60) - - if response.text == "OK": - return - - error_msg = f"Could not connect to MLflow server at {tracking_uri}. " - if not token: - error_msg += "The server may require authentication, did you forget to turn it on in the config?" - raise ConnectionError(error_msg) - - -def get_mlflow_run_params(config: OmegaConf, tracking_uri: str) -> tuple[str | None, str, dict[str, Any]]: - run_id = None - tags = {"projectName": config.diagnostics.log.mlflow.project_name} - # create a tag with the command used to run the script - tags["command"] = sys.argv[0].split("/")[-1] # get the python script name - if len(sys.argv) > 1: - # add the arguments to the command tag - tags["command"] = tags["command"] + " " + " ".join(sys.argv[1:]) - if config.training.run_id or config.training.fork_run_id: - "Either run_id or fork_run_id must be provided to resume a run." 
- - import mlflow - - if config.diagnostics.log.mlflow.authentication and not config.diagnostics.log.mlflow.offline: - TokenAuth(tracking_uri).authenticate() - - mlflow_client = mlflow.MlflowClient(tracking_uri) - - if config.training.run_id and config.diagnostics.log.mlflow.on_resume_create_child: - parent_run_id = config.training.run_id # parent_run_id - run_name = mlflow_client.get_run(parent_run_id).info.run_name - tags["mlflow.parentRunId"] = parent_run_id - tags["resumedRun"] = "True" # tags can't take boolean values - elif config.training.run_id and not config.diagnostics.log.mlflow.on_resume_create_child: - run_id = config.training.run_id - run_name = mlflow_client.get_run(run_id).info.run_name - mlflow_client.update_run(run_id=run_id, status="RUNNING") - tags["resumedRun"] = "True" - else: - parent_run_id = config.training.fork_run_id - tags["forkedRun"] = "True" - tags["forkedRunId"] = parent_run_id - - if config.diagnostics.log.mlflow.run_name: - run_name = config.diagnostics.log.mlflow.run_name - else: - import uuid - - run_name = f"{uuid.uuid4()!s}" - return run_id, run_name, tags - - class LogsMonitor: """Class for logging terminal output. @@ -112,7 +49,7 @@ class LogsMonitor: Note: If there is an error, the terminal output logging ends before the error message is printed into the log file. In order for the user to see the error message, the user must look at the slurm output file. - We provide the SLRM job id in the very beginning of the log file and print the final status of the run in the end. + We provide the SLURM job id in the very beginning of the log file and print the final status of the run in the end. Parameters ---------- @@ -213,7 +150,7 @@ def start(self) -> None: self._buffer_registry[id(self)] = self._io_buffer # Start thread to asynchronously collect logs self._th_collector.start() - LOGGER.info("Termial Log Path: %s", self.file_save_path) + LOGGER.info("Terminal Log Path: %s", self.file_save_path) if os.getenv("SLURM_JOB_ID"): LOGGER.info("SLURM job id: %s", os.getenv("SLURM_JOB_ID")) @@ -310,18 +247,20 @@ class AnemoiMLflowLogger(MLFlowLogger): def __init__( self, experiment_name: str = "lightning_logs", + project_name: str = "anemoi", run_name: str | None = None, tracking_uri: str | None = os.getenv("MLFLOW_TRACKING_URI"), - tags: dict[str, Any] | None = None, save_dir: str | None = "./mlruns", log_model: Literal[True, False, "all"] = False, prefix: str = "", resumed: bool | None = False, forked: bool | None = False, run_id: str | None = None, + fork_run_id: str | None = None, offline: bool | None = False, authentication: bool | None = None, log_hyperparams: bool | None = True, + on_resume_create_child: bool | None = True, ) -> None: """Initialize the AnemoiMLflowLogger. 
@@ -329,12 +268,12 @@ def __init__( ---------- experiment_name : str, optional Name of experiment, by default "lightning_logs" + project_name : str, optional + Name of the project, by default "anemoi" run_name : str | None, optional Name of run, by default None tracking_uri : str | None, optional Tracking URI of server, by default os.getenv("MLFLOW_TRACKING_URI") - tags : dict[str, Any] | None, optional - Tags to apply, by default None save_dir : str | None, optional Directory to save logs to, by default "./mlruns" log_model : Literal[True, False, "all"], optional @@ -347,31 +286,28 @@ def __init__( Whether the run was forked or not, by default False run_id : str | None, optional Run id of current run, by default None + fork_run_id : str | None, optional + Fork Run id from parent run, by default None offline : bool | None, optional Whether to run offline or not, by default False authentication : bool | None, optional Whether to authenticate with server or not, by default None log_hyperparams : bool | None, optional Whether to log hyperparameters, by default True - + on_resume_create_child: bool | None, optional + Whether to create a child run when resuming a run, by default False """ - if offline: - # OFFLINE - When we run offline we can pass a save_dir pointing to a local path - tracking_uri = None - - else: - # ONLINE - When we pass a tracking_uri to mlflow then it will ignore the - # saving dir and save all artifacts/metrics to the remote server database - save_dir = None - self._resumed = resumed self._forked = forked self._flag_log_hparams = log_hyperparams - if rank_zero_only.rank == 0: - enabled = authentication and not offline - self.auth = TokenAuth(tracking_uri, enabled=enabled) + self._fork_run_server2server = None + self._parent_run_server2server = None + enabled = authentication and not offline + self.auth = TokenAuth(tracking_uri, enabled=enabled) + + if rank_zero_only.rank == 0: if offline: LOGGER.info("MLflow is logging offline.") else: @@ -379,6 +315,24 @@ def __init__( self.auth.authenticate() health_check(tracking_uri) + run_id, run_name, tags = self._get_mlflow_run_params( + project_name=project_name, + run_name=run_name, + config_run_id=run_id, + fork_run_id=fork_run_id, + tracking_uri=tracking_uri, + on_resume_create_child=on_resume_create_child, + ) + # Before creating the run we need to overwrite the tracking_uri and save_dir if offline + if offline: + # OFFLINE - When we run offline we can pass a save_dir pointing to a local path + tracking_uri = None + + else: + # ONLINE - When we pass a tracking_uri to mlflow then it will ignore the + # saving dir and save all artifacts/metrics to the remote server database + save_dir = None + super().__init__( experiment_name=experiment_name, run_name=run_name, @@ -390,6 +344,84 @@ def __init__( run_id=run_id, ) + def _check_server2server_lineage(self, run: mlflow.entities.Run) -> bool: + """Address lineage and metadata for server2server runs. 
+ + Those are runs that have been sync from one remote server to another + """ + server2server = run.data.tags.get("server2server", "False") == "True" + LOGGER.info("Server2Server: %s", server2server) + if server2server: + parent_run_across_servers = run.data.params.get( + "metadata.offline_run_id", + run.data.params.get("metadata.server2server_run_id"), + ) + if self._forked: + # if we want to fork a resume run we need to set the parent_run_across_servers + # but just to restore the checkpoint + self._fork_run_server2server = parent_run_across_servers + else: + self._parent_run_server2server = parent_run_across_servers + + def _get_mlflow_run_params( + self, + project_name: str, + run_name: str, + config_run_id: str, + fork_run_id: str, + tracking_uri: str, + on_resume_create_child: bool, + ) -> tuple[str | None, str, dict[str, Any]]: + + run_id = None + tags = {"projectName": project_name} + + # create a tag with the command used to run the script + command = os.environ.get("ANEMOI_TRAINING_CMD", sys.argv[0]) + tags["command"] = command.split("/")[-1] # get the python script name + tags["mlflow.source.name"] = command + if len(sys.argv) > 1: + # add the arguments to the command tag + tags["command"] = tags["command"] + " " + " ".join(sys.argv[1:]) + + if config_run_id or fork_run_id: + "Either run_id or fork_run_id must be provided to resume a run." + import mlflow + + self.auth.authenticate() + mlflow_client = mlflow.MlflowClient(tracking_uri) + + if config_run_id and on_resume_create_child: + parent_run_id = config_run_id # parent_run_id + parent_run = mlflow_client.get_run(parent_run_id) + run_name = parent_run.info.run_name + self._check_server2server_lineage(parent_run) + tags["mlflow.parentRunId"] = parent_run_id + tags["resumedRun"] = "True" # tags can't take boolean values + elif config_run_id and not on_resume_create_child: + run_id = config_run_id + run = mlflow_client.get_run(run_id) + run_name = run.info.run_name + self._check_server2server_lineage(run) + mlflow_client.update_run(run_id=run_id, status="RUNNING") + tags["resumedRun"] = "True" + else: + parent_run_id = fork_run_id + tags["forkedRun"] = "True" + tags["forkedRunId"] = parent_run_id + run = mlflow_client.get_run(parent_run_id) + self._check_server2server_lineage(run) + + if not run_name: + import uuid + + run_name = f"{uuid.uuid4()!s}" + + if os.getenv("SLURM_JOB_ID"): + tags["SLURM_JOB_ID"] = os.getenv("SLURM_JOB_ID") + + return run_id, run_name, tags + @property def experiment(self) -> MLFlowLogger.experiment: if rank_zero_only.rank == 0: @@ -463,12 +495,14 @@ def log_hyperparams(self, params: dict[str, Any] | Namespace) -> None: params = _flatten_dict(params, delimiter=".") # Flatten dict with '.' to not break API queries params = self._clean_params(params) + import mlflow from mlflow.entities import Param - # Truncate parameter values to 250 characters. - # TODO (Ana Prieto Nemesio): MLflow 1.28 allows up to 500 characters: - # https://github.com/mlflow/mlflow/releases/tag/v1.28.0 - params_list = [Param(key=k, value=str(v)[:250]) for k, v in params.items()] + # Truncate parameter values. 
+ truncation_length = 250 + if Version(mlflow.VERSION) >= Version("1.28.0"): + truncation_length = 500 + params_list = [Param(key=k, value=str(v)[:truncation_length]) for k, v in params.items()] for idx in range(0, len(params_list), 100): self.experiment.log_batch(run_id=self.run_id, params=params_list[idx : idx + 100]) diff --git a/src/anemoi/training/diagnostics/mlflow/utils.py b/src/anemoi/training/diagnostics/mlflow/utils.py new file mode 100644 index 00000000..89f6e002 --- /dev/null +++ b/src/anemoi/training/diagnostics/mlflow/utils.py @@ -0,0 +1,38 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +import os + +import requests + + +def health_check(tracking_uri: str) -> None: + """Query the health endpoint of an MLflow server. + + If the server is not reachable, raise an error and remind the user that authentication may be required. + + Raises + ------ + ConnectionError + If the server is not reachable. + + """ + token = os.getenv("MLFLOW_TRACKING_TOKEN") + + headers = {"Authorization": f"Bearer {token}"} + response = requests.get(f"{tracking_uri}/health", headers=headers, timeout=60) + + if response.text == "OK": + return + + error_msg = f"Could not connect to MLflow server at {tracking_uri}. " + if not token: + error_msg += "The server may require authentication, did you forget to turn it on?" + raise ConnectionError(error_msg) diff --git a/src/anemoi/training/diagnostics/plots.py b/src/anemoi/training/diagnostics/plots.py index b2004cf4..7b4ba711 100644 --- a/src/anemoi/training/diagnostics/plots.py +++ b/src/anemoi/training/diagnostics/plots.py @@ -1,10 +1,13 @@ -# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# (C) Copyright 2024 Anemoi contributors. +# # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. + from __future__ import annotations import logging diff --git a/src/anemoi/training/distributed/strategy.py b/src/anemoi/training/distributed/strategy.py index c15828ca..c6509795 100644 --- a/src/anemoi/training/distributed/strategy.py +++ b/src/anemoi/training/distributed/strategy.py @@ -1,10 +1,13 @@ -# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# (C) Copyright 2024 Anemoi contributors. +# # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. + import logging import os diff --git a/src/anemoi/training/losses/mse.py b/src/anemoi/training/losses/mse.py index b72a7aea..88ad0d0b 100644 --- a/src/anemoi/training/losses/mse.py +++ b/src/anemoi/training/losses/mse.py @@ -1,11 +1,12 @@ -# (C) Copyright 2024 ECMWF. +# (C) Copyright 2024 Anemoi contributors. 
# # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. -# + from __future__ import annotations diff --git a/src/anemoi/training/losses/utils.py b/src/anemoi/training/losses/utils.py index 9a866a0a..5ddef3d6 100644 --- a/src/anemoi/training/losses/utils.py +++ b/src/anemoi/training/losses/utils.py @@ -1,11 +1,12 @@ -# (C) Copyright 2024 ECMWF. +# (C) Copyright 2024 Anemoi contributors. # # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. -# + from __future__ import annotations diff --git a/src/anemoi/training/train/forecaster.py b/src/anemoi/training/train/forecaster.py index ff1acfd7..36f6fa9a 100644 --- a/src/anemoi/training/train/forecaster.py +++ b/src/anemoi/training/train/forecaster.py @@ -1,11 +1,12 @@ -# (C) Copyright 2024 ECMWF. +# (C) Copyright 2024 Anemoi contributors. # # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. -# + import logging import math @@ -31,6 +32,8 @@ from anemoi.training.losses.mse import WeightedMSELoss from anemoi.training.losses.utils import grad_scaler from anemoi.training.utils.jsonify import map_config_to_primitives +from anemoi.training.utils.masks import Boolean1DMask +from anemoi.training.utils.masks import NoOutputMask LOGGER = logging.getLogger(__name__) @@ -82,6 +85,12 @@ def __init__( self.latlons_data = graph_data[config.graph.data].x self.loss_weights = graph_data[config.graph.data][config.model.node_loss_weight].squeeze() + if config.model.get("output_mask", None) is not None: + self.output_mask = Boolean1DMask(graph_data[config.graph.data][config.model.output_mask]) + else: + self.output_mask = NoOutputMask() + self.loss_weights = self.output_mask.apply(self.loss_weights, dim=0, fill_value=0.0) + self.logger_enabled = config.diagnostics.log.wandb.enabled or config.diagnostics.log.mlflow.enabled self.metric_ranges, self.metric_ranges_validation, loss_scaling = self.metrics_loss_scaling( @@ -202,6 +211,8 @@ def advance_input( self.data_indices.internal_model.output.prognostic, ] + x[:, -1] = self.output_mask.rollout_boundary(x[:, -1], batch[:, -1], self.data_indices) + # get new "constants" needed for time-varying fields x[:, -1, :, :, self.data_indices.internal_model.input.forcing] = batch[ :, diff --git a/src/anemoi/training/train/train.py b/src/anemoi/training/train/train.py index f48b9467..b772eb2a 100644 --- a/src/anemoi/training/train/train.py +++ b/src/anemoi/training/train/train.py @@ -1,11 +1,12 @@ -# (C) Copyright 2024 ECMWF. +# (C) Copyright 2024 Anemoi contributors. # # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 
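The forecaster above now selects `Boolean1DMask` when `config.model.output_mask` names a graph attribute, and `NoOutputMask` otherwise, then zeroes the loss weights outside the output domain. A self-contained sketch of that selection and of the loss-weight masking, using hypothetical tensors in place of the graph data:

```python
import torch

from anemoi.training.utils.masks import Boolean1DMask, NoOutputMask

# Hypothetical stand-ins for graph_data[...][config.model.output_mask] and the node loss weights.
mask_attribute = torch.tensor([1, 1, 0, 0, 1])  # 1 = node inside the output domain
loss_weights = torch.ones(5)

use_output_mask = True  # corresponds to config.model.output_mask being set
output_mask = Boolean1DMask(mask_attribute) if use_output_mask else NoOutputMask()

# Loss contributions outside the output domain are zeroed, as in the forecaster.
loss_weights = output_mask.apply(loss_weights, dim=0, fill_value=0.0)
print(loss_weights)  # tensor([1., 1., 0., 0., 1.])
```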
+# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. -# + from __future__ import annotations @@ -68,6 +69,10 @@ def __init__(self, config: DictConfig) -> None: self.config.training.run_id = self.run_id LOGGER.info("Run id: %s", self.config.training.run_id) + + # Get the server2server lineage + self._get_server2server_lineage() + # Update paths to contain the run ID self._update_paths() @@ -147,7 +152,7 @@ def model(self) -> GraphForecaster: @rank_zero_only def _get_mlflow_run_id(self) -> str: run_id = self.mlflow_logger.run_id - # for resumed runs or offline runs logging this can be uesful + # for resumed runs or offline runs logging this can be useful LOGGER.info("Mlflow Run id: %s", run_id) return run_id @@ -188,18 +193,21 @@ def last_checkpoint(self) -> str | None: if not self.start_from_checkpoint: return None + fork_id = self.fork_run_server2server or self.config.training.fork_run_id checkpoint = Path( self.config.hardware.paths.checkpoints.parent, - self.config.training.fork_run_id or self.run_id, + fork_id or self.lineage_run, self.config.hardware.files.warm_start or "last.ckpt", ) - # Check if the last checkpoint exists if Path(checkpoint).exists(): LOGGER.info("Resuming training from last checkpoint: %s", checkpoint) return checkpoint - LOGGER.warning("Could not find last checkpoint: %s", checkpoint) + if rank_zero_only.rank == 0: + msg = "Could not find last checkpoint: %s", checkpoint + raise RuntimeError(msg) + return None @cached_property @@ -252,10 +260,13 @@ def profiler(self) -> PyTorchProfiler | None: def loggers(self) -> list: loggers = [] if self.config.diagnostics.log.wandb.enabled: + LOGGER.info("W&B logger enabled") loggers.append(self.wandb_logger) if self.config.diagnostics.log.tensorboard.enabled: + LOGGER.info("TensorBoard logger enabled") loggers.append(self.tensorboard_logger) if self.config.diagnostics.log.mlflow.enabled: + LOGGER.info("MLFlow logger enabled") loggers.append(self.mlflow_logger) return loggers @@ -291,17 +302,33 @@ def _log_information(self) -> None: LOGGER.debug("Effective learning rate: %.3e", total_number_of_model_instances * self.config.training.lr.rate) LOGGER.debug("Rollout window length: %d", self.config.training.rollout.start) + def _get_server2server_lineage(self) -> None: + """Get the server2server lineage.""" + self.parent_run_server2server = None + self.fork_run_server2server = None + if self.config.diagnostics.log.mlflow.enabled: + self.parent_run_server2server = self.mlflow_logger._parent_run_server2server + LOGGER.info("Parent run server2server: %s", self.parent_run_server2server) + self.fork_run_server2server = self.mlflow_logger._fork_run_server2server + LOGGER.info("Fork run server2server: %s", self.fork_run_server2server) + def _update_paths(self) -> None: """Update the paths in the configuration.""" + self.lineage_run = None if self.run_id: # when using mlflow only rank0 will have a run_id except when resuming runs # Multi-gpu new runs or forked runs - only rank 0 # Multi-gpu resumed runs - all ranks - self.config.hardware.paths.checkpoints = Path(self.config.hardware.paths.checkpoints, self.run_id) - self.config.hardware.paths.plots = Path(self.config.hardware.paths.plots, self.run_id) + self.lineage_run = self.parent_run_server2server or self.run_id + self.config.hardware.paths.checkpoints = Path(self.config.hardware.paths.checkpoints, self.lineage_run) + 
self.config.hardware.paths.plots = Path(self.config.hardware.paths.plots, self.lineage_run) elif self.config.training.fork_run_id: + # WHEN USING MANY NODES/GPUS + self.lineage_run = self.parent_run_server2server or self.config.training.fork_run_id # Only rank non zero in the forked run will go here - parent_run = self.config.training.fork_run_id - self.config.hardware.paths.checkpoints = Path(self.config.hardware.paths.checkpoints, parent_run) + self.config.hardware.paths.checkpoints = Path(self.config.hardware.paths.checkpoints, self.lineage_run) + + LOGGER.info("Checkpoints path: %s", self.config.hardware.paths.checkpoints) + LOGGER.info("Plots path: %s", self.config.hardware.paths.plots) @cached_property def strategy(self) -> DDPGroupStrategy: diff --git a/src/anemoi/training/utils/checkpoint.py b/src/anemoi/training/utils/checkpoint.py index 697a92de..ddb5a1c8 100644 --- a/src/anemoi/training/utils/checkpoint.py +++ b/src/anemoi/training/utils/checkpoint.py @@ -1,10 +1,13 @@ -# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# (C) Copyright 2024 Anemoi contributors. +# # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. + from __future__ import annotations from pathlib import Path diff --git a/src/anemoi/training/utils/jsonify.py b/src/anemoi/training/utils/jsonify.py index ddf9b86c..c44092b1 100644 --- a/src/anemoi/training/utils/jsonify.py +++ b/src/anemoi/training/utils/jsonify.py @@ -1,11 +1,13 @@ -# (C) Copyright 2024 ECMWF. +# (C) Copyright 2024 Anemoi contributors. # # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. + import datetime from pathlib import Path diff --git a/src/anemoi/training/utils/masks.py b/src/anemoi/training/utils/masks.py new file mode 100644 index 00000000..fd0581a0 --- /dev/null +++ b/src/anemoi/training/utils/masks.py @@ -0,0 +1,115 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +from __future__ import annotations + +from abc import abstractmethod +from typing import TYPE_CHECKING + +import numpy as np +import torch + +if TYPE_CHECKING: + from anemoi.models.data_indices.collection import IndexCollection + + +class BaseMask: + """Base class for masking model output.""" + + @abstractmethod + def apply(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: + error_message = "Method `apply` must be implemented in subclass." + raise NotImplementedError(error_message) + + @abstractmethod + def rollout_boundary(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: + error_message = "Method `rollout_boundary` must be implemented in subclass." 
+        raise NotImplementedError(error_message)
+
+
+class Boolean1DMask(BaseMask):
+    """1D Boolean mask."""
+
+    def __init__(self, values: torch.Tensor) -> None:
+        self.mask = values.bool().squeeze()
+
+    def broadcast_like(self, x: torch.Tensor, dim: int) -> torch.Tensor:
+        assert x.shape[dim] == len(
+            self.mask,
+        ), f"Dimension mismatch: dimension {dim} has size {x.shape[dim]}, but mask length is {len(self.mask)}."
+        target_shape = [1 for _ in range(x.ndim)]
+        target_shape[dim] = len(self.mask)
+        mask = self.mask.reshape(target_shape)
+        return mask.to(x.device)
+
+    @staticmethod
+    def _fill_masked_tensor(x: torch.Tensor, mask: torch.Tensor, fill_value: float | torch.Tensor) -> torch.Tensor:
+        if isinstance(fill_value, torch.Tensor):
+            return x.masked_scatter(mask, fill_value)
+        return x.masked_fill(mask, fill_value)
+
+    def apply(self, x: torch.Tensor, dim: int, fill_value: float | torch.Tensor = np.nan) -> torch.Tensor:
+        """Apply the mask to the input tensor.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            The input tensor to be masked.
+        dim : int
+            The dimension along which to apply the mask.
+        fill_value : float | torch.Tensor, optional
+            The value to fill in the masked positions, by default np.nan.
+
+        Returns
+        -------
+        torch.Tensor
+            The masked tensor with fill_value in the positions where the mask is False.
+        """
+        mask = self.broadcast_like(x, dim)
+        return Boolean1DMask._fill_masked_tensor(x, ~mask, fill_value)
+
+    def rollout_boundary(
+        self,
+        pred_state: torch.Tensor,
+        true_state: torch.Tensor,
+        data_indices: IndexCollection,
+    ) -> torch.Tensor:
+        """Rollout the boundary forcing.
+
+        Parameters
+        ----------
+        pred_state : torch.Tensor
+            The predicted state tensor of shape (bs, ens, latlon, nvar)
+        true_state : torch.Tensor
+            The true state tensor of shape (bs, ens, latlon, nvar)
+        data_indices : IndexCollection
+            Collection of data indices.
+
+        Returns
+        -------
+        torch.Tensor
+            The updated predicted state tensor with boundary forcing applied.
+        """
+        pred_state[..., data_indices.model.input.prognostic] = self.apply(
+            pred_state[..., data_indices.model.input.prognostic],
+            dim=2,
+            fill_value=true_state[..., data_indices.data.output.prognostic],
+        )
+
+        return pred_state
+
+
+class NoOutputMask(BaseMask):
+    """No output mask."""
+
+    def apply(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:  # noqa: ARG002
+        return x
+
+    def rollout_boundary(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:  # noqa: ARG002
+        return x
diff --git a/src/anemoi/training/utils/mlflow_sync.py b/src/anemoi/training/utils/mlflow_sync.py
index 3ff1adf2..534b4e3f 100644
--- a/src/anemoi/training/utils/mlflow_sync.py
+++ b/src/anemoi/training/utils/mlflow_sync.py
@@ -1,15 +1,22 @@
-# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts.
+# (C) Copyright 2024 Anemoi contributors.
+#
 # This software is licensed under the terms of the Apache Licence Version 2.0
 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
 # In applying this licence, ECMWF does not waive the privileges and immunities
 # granted to it by virtue of its status as an intergovernmental organisation
 # nor does it submit to any jurisdiction.
+
 import logging
 import os
+import shutil
 import tempfile
 from itertools import starmap
 from pathlib import Path
+from urllib.parse import urlparse
+
+import mlflow.entities


 def export_log_output_file_path() -> tempfile._TemporaryFileWrapper:
@@ -26,11 +33,21 @@ def export_log_output_file_path() -> tempfile._TemporaryFileWrapper:
     Path(tmpdir).mkdir(parents=True, exist_ok=True)
     temp = tempfile.NamedTemporaryFile(dir=tmpdir, prefix=f"{user}_")  # noqa: SIM115
     os.environ["MLFLOW_EXPORT_IMPORT_LOG_OUTPUT_FILE"] = temp.name
+    os.environ["MLFLOW_EXPORT_IMPORT_TMP_DIRECTORY"] = tmpdir
     return temp


+def close_and_clean_temp(server2server: str, artifact_path: Path) -> None:
+    temp.close()
+    os.environ.pop("MLFLOW_EXPORT_IMPORT_LOG_OUTPUT_FILE")
+    os.environ.pop("MLFLOW_EXPORT_IMPORT_TMP_DIRECTORY")
+    if server2server:
+        shutil.rmtree(artifact_path)
+
+
 temp = export_log_output_file_path()
+
 import mlflow  # noqa: E402
 from mlflow.entities import RunStatus  # noqa: E402
 from mlflow.entities import RunTag  # noqa: E402
@@ -49,7 +66,7 @@ def export_log_output_file_path() -> tempfile._TemporaryFileWrapper:
     from mlflow_export_import.run.run_data_importer import _log_metrics
     from mlflow_export_import.run.run_data_importer import _log_params
 except ImportError:
-    msg = "The 'mlflow_export_import' package is not installed. Please install it from https://github.com/mlflow/mlflow-export-import"
+    msg = "The 'mlflow-export-import' package is not installed. Please install it from https://github.com/mlflow/mlflow-export-import"
     raise ImportError(msg) from None

 LOGGER = logging.getLogger(__name__)
@@ -110,22 +127,22 @@ def __init__(
         LOGGER.setLevel(self.log_level)

     @staticmethod
-    def update_run_id(params: dict, key: str, new_run_id: str, offline_run_id: str) -> dict:
+    def update_run_id(params: dict, key: str, new_run_id: str, src_run_id: str, run_type: str) -> dict:
         params[f"config.training.{key}"] = new_run_id
-        params[f"config.training.offline_{key}"] = offline_run_id
+        params[f"config.training.{run_type}_{key}"] = src_run_id

         if key == "run_id":
-            params[f"metadata.offline_{key}"] = offline_run_id
+            params[f"metadata.{run_type}_{key}"] = src_run_id
             params[f"metadata.{key}"] = new_run_id
         return params

-    def update_parent_run_info(self, tags: dict, tag_key: str, tag_dest: str, dst_run_id: str) -> dict:
+    def update_parent_run_info(self, tags: dict, tag_key: str, tag_dest: str, dst_run_id: str, run_type: str) -> dict:
         mlflow.set_tracking_uri(self.dest_tracking_uri)
         # Check if there is already a parent run in the destination tracking uri
         runs = mlflow.search_runs(
             experiment_ids=mlflow.get_experiment_by_name(self.experiment_name).experiment_id,
-            filter_string=f"params.metadata.offline_run_id = '{tags[tag_key]}'",
+            filter_string=f"params.metadata.{run_type}_run_id = '{tags[tag_key]}'",
         )

         if not runs.empty:
@@ -139,19 +156,128 @@ def update_parent_run_info(self, tags: dict, tag_key: str, tag_dest: str, dst_ru
             tags[tag_key] = new_parent_run_id  # update new online parent run_id
         return tags

-    def check_run_is_logged(self, status: str = "FINISHED") -> bool:
+    def check_run_is_logged(self, status: str = "FINISHED", server2server: bool = False) -> bool:
         """Blocks sync if top-level parent run or single runs are unavailable."""
         run_logged = False
         if status == "FINISHED":
             mlflow.set_tracking_uri(self.dest_tracking_uri)
-            synced_runs = mlflow.search_runs(
-                experiment_ids=mlflow.get_experiment_by_name(self.experiment_name).experiment_id,
-                filter_string=f"params.metadata.offline_run_id = '{self.run_id}'",
-            )
-            if not synced_runs.empty:  # single run (no child) already logged
-                run_logged = True
+            experiment = mlflow.get_experiment_by_name(self.experiment_name)
+            run_type = "server2server" if server2server else "offline"
+            if experiment:
+                synced_runs = mlflow.search_runs(
+                    experiment_ids=experiment.experiment_id,
+                    filter_string=f"params.metadata.{run_type}_run_id = '{self.run_id}'",
+                )
+                if not synced_runs.empty:  # single run (no child) already logged
+                    run_logged = True
         return run_logged

+    def _check_source_tracking_uri(self) -> bool:
+        parsed_url = urlparse(self.source_tracking_uri)
+        return all([parsed_url.scheme, parsed_url.netloc])  # True if source_tracking_uri is a remote server
+
+    def _get_dst_experiment_id(self, dest_mlflow_client: str) -> str:
+        experiment = dest_mlflow_client.get_experiment_by_name(self.experiment_name)
+        if not experiment:
+            return dest_mlflow_client.create_experiment(self.experiment_name)
+        return experiment.experiment_id
+
+    def _get_artifacts_path(self, server2server: str, run: mlflow.entities.Run) -> Path:
+        if server2server:
+            # Download each artifact
+            temp_dir = os.getenv("MLFLOW_EXPORT_IMPORT_TMP_DIRECTORY")
+            artifact_path = Path(temp_dir, run.info.run_id)
+            artifact_path.mkdir(parents=True, exist_ok=True)
+        else:
+            artifact_path = Path(self.source_tracking_uri, run.info.experiment_id, run.info.run_id, "artifacts")
+
+        return artifact_path
+
+    def _download_artifacts(
+        self,
+        client: mlflow.tracking.client.MlflowClient,
+        run_id: mlflow.entities.Run,
+        artifact_path: Path,
+    ) -> None:
+
+        mlflow.set_tracking_uri(self.source_tracking_uri)  # OTHERWISE IT WILL NOT WORK
+        artifacts = client.list_artifacts(run_id)
+        LOGGER.info("Downloading artifacts %s for run %s to %s", len(artifacts), run_id, artifact_path)
+        for artifact in artifacts:
+            # Download artifact file from the server
+            mlflow.artifacts.download_artifacts(run_id=run_id, artifact_path=artifact.path, dst_path=artifact_path)
+
+    def _update_params_tags_runs(
+        self,
+        params: dict,
+        tags: dict,
+        dst_run_id: str,
+        src_run_id: str,
+        run_type: str = "offline",
+    ) -> (dict, dict):
+
+        if (params["config.training.fork_run_id"] == "None") and (params["metadata.run_id"] == src_run_id):
+            params = self.update_run_id(
+                params,
+                "run_id",
+                new_run_id=dst_run_id,
+                src_run_id=src_run_id,
+                run_type=run_type,
+            )
+
+        elif "forkedRun" in tags:
+            try:
+                tags = self.update_parent_run_info(
+                    tags=tags,
+                    tag_key="forkedRunId",
+                    tag_dest=f"{run_type}.forkedRunId",
+                    dst_run_id=dst_run_id,
+                    run_type=run_type,
+                )
+                params = self.update_run_id(
+                    params,
+                    "fork_run_id",
+                    new_run_id=tags["forkedRunId"],
+                    src_run_id=tags[f"{run_type}.forkedRunId"],
+                    run_type=run_type,
+                )
+                params = self.update_run_id(
+                    params,
+                    "run_id",
+                    new_run_id=dst_run_id,
+                    src_run_id=src_run_id,
+                    run_type=run_type,
+                )
+
+            except AttributeError:
+                LOGGER.warning("No forked run parent found")
+
+        elif "resumedRun" in tags:
+            try:
+                tags = self.update_parent_run_info(
+                    tags=tags,
+                    tag_key="mlflow.parentRunId",
+                    tag_dest=f"mlflow.{run_type}.parentRunId",
+                    dst_run_id=dst_run_id,
+                    run_type=run_type,
+                )
+                params = self.update_run_id(
+                    params,
+                    "run_id",
+                    new_run_id=tags["mlflow.parentRunId"],
+                    src_run_id=tags[f"mlflow.{run_type}.parentRunId"],
+                    run_type=run_type,
+                )
+
+                # in the offline case that's the local folder name for the resumed run
+                # in the server2server case that's the source server run_id of the resumed run
+                params[f"config.training.{run_type}_self_run_id"] = src_run_id
+
+            except AttributeError:
+                LOGGER.warning("No parent run found")
+
+        return params, tags
+
     def sync(
         self,
     ) -> None:
@@ -161,6 +287,7 @@ def sync(
         http_client = create_http_client(dest_mlflow_client)
         # GET SOURCE RUN ##
         run = src_mlflow_client.get_run(self.run_id)
+        server2server = self._check_source_tracking_uri()
         run_logged = self.check_run_is_logged(status=run.info.status)
         if run_logged:
             LOGGER.info("Run already imported %s into experiment %s", self.run_id, self.experiment_name)
@@ -174,6 +301,7 @@ def sync(
             )
             return
+
         msg = {
             "run_id": run.info.run_id,
             "lifecycle_stage": run.info.lifecycle_stage,
@@ -184,69 +312,44 @@ def sync(
         run_info = mlflow_utils.strip_underscores(run.info)
         src_user_id = run_info["user_id"]

-        exp = dest_mlflow_client.get_experiment_by_name(self.experiment_name)
-        dst_run = dest_mlflow_client.create_run(exp.experiment_id)
+        exp_id = self._get_dst_experiment_id(dest_mlflow_client=dest_mlflow_client)
+        dst_run = dest_mlflow_client.create_run(exp_id)
         dst_run_id = dst_run.info.run_id

         tags = dict(sorted(run.data.tags.items()))
         params = run.data.params
         # So far there is no easy way to force mlflow to use a specific run_id, that means
-        # that when we online sync the offline runs those will have run run_ids. To keep
+        # that when we online sync the offline runs those will have different run_ids. To keep
         # track of online and offline governance in that case we update run_ids info
-        if (params["config.training.fork_run_id"] == "None") and (params["metadata.run_id"] == run.info.run_id):
-            params = self.update_run_id(params, "run_id", new_run_id=dst_run_id, offline_run_id=run.info.run_id)
+        artifact_path = self._get_artifacts_path(server2server, run)

-        elif "forkedRun" in tags:
-            try:
-                tags = self.update_parent_run_info(
-                    tags=tags,
-                    tag_key="forkedRunId",
-                    tag_dest="offline.forkedRunId",
-                    dst_run_id=dst_run_id,
-                )
-                params = self.update_run_id(
-                    params,
-                    "fork_run_id",
-                    new_run_id=tags["forkedRunId"],
-                    offline_run_id=tags["offline.forkedRunId"],
-                )
-                params = self.update_run_id(params, "run_id", new_run_id=dst_run_id, offline_run_id=run.info.run_id)
-
-            except AttributeError:
-                LOGGER.warning("No forked run parent found")
-
-        elif "resumedRun" in tags:
-            try:
-                tags = self.update_parent_run_info(
-                    tags=tags,
-                    tag_key="mlflow.parentRunId",
-                    tag_dest="mlflow.offline.parentRunId",
-                    dst_run_id=dst_run_id,
-                )
-                params = self.update_run_id(
-                    params,
-                    "run_id",
-                    new_run_id=tags["mlflow.parentRunId"],
-                    offline_run_id=tags["mlflow.offline.parentRunId"],
-                )
-
-                params["config.training.offline_run_id_folder"] = run.info.run_id
-
-            except AttributeError:
-                LOGGER.warning("No parent run found")
+        if server2server:
+            tags["server2server"] = "True"
+            self._download_artifacts(src_mlflow_client, run.info.run_id, artifact_path)
+            params, tags = self._update_params_tags_runs(
+                params,
+                tags,
+                dst_run_id,
+                run.info.run_id,
+                run_type="server2server",
+            )

-        tags["offlineRun"] = "True"
+        else:
+            tags["offlineRun"] = "True"
+            params, tags = self._update_params_tags_runs(params, tags, dst_run_id, run.info.run_id, run_type="offline")

         src_run_dct = {
-            "params": run.data.params,
+            "params": params,
             "metrics": _get_metrics_with_steps(src_mlflow_client, run),
             "tags": tags,
             "inputs": _inputs_to_dict(run.inputs),
         }

         try:
+            LOGGER.info("Starting to export run data")
+
             import_run_data(
                 dest_mlflow_client,
                 src_run_dct,
@@ -255,19 +358,22 @@ def sync(
             )
             _import_inputs(http_client, src_run_dct, dst_run_id)

-            path = Path(self.source_tracking_uri, run.info.experiment_id, self.run_id, "artifacts")
-            if path.exists():
-                mlflow.set_tracking_uri(self.dest_tracking_uri)
-                dest_mlflow_client.log_artifacts(dst_run_id, path)
+            mlflow.set_tracking_uri(self.dest_tracking_uri)
+            dest_mlflow_client.log_artifacts(dst_run_id, artifact_path)
             dest_mlflow_client.set_terminated(dst_run_id, RunStatus.to_string(RunStatus.FINISHED))

-        except Exception as e:
+        except BaseException:
             dest_mlflow_client.set_terminated(dst_run_id, RunStatus.to_string(RunStatus.FAILED))
             import traceback
             traceback.print_exc()
-            raise Exception(e, "Importing run %s of experiment %s failed", dst_run_id, exp.name) from e  # noqa: TRY002
+            LOGGER.exception(
+                "Importing run %s of experiment %s failed",
+                dst_run_id,
+                self.experiment_name,
+            )

-        LOGGER.info("Imported run %s into experiment %s", dst_run_id, self.experiment_name)
+        finally:
+            close_and_clean_temp(server2server, artifact_path)

-        temp.close()
+        LOGGER.info("Imported run %s into experiment %s", dst_run_id, self.experiment_name)
diff --git a/src/anemoi/training/utils/seeding.py b/src/anemoi/training/utils/seeding.py
index d766bd1a..3b4afd47 100644
--- a/src/anemoi/training/utils/seeding.py
+++ b/src/anemoi/training/utils/seeding.py
@@ -1,7 +1,8 @@
-# (C) Copyright 2024 ECMWF.
+# (C) Copyright 2024 Anemoi contributors.
 #
 # This software is licensed under the terms of the Apache Licence Version 2.0
 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
 # In applying this licence, ECMWF does not waive the privileges and immunities
 # granted to it by virtue of its status as an intergovernmental organisation
 # nor does it submit to any jurisdiction.
diff --git a/src/anemoi/training/utils/usable_indices.py b/src/anemoi/training/utils/usable_indices.py
index 7bdd5cbd..0d97f25f 100644
--- a/src/anemoi/training/utils/usable_indices.py
+++ b/src/anemoi/training/utils/usable_indices.py
@@ -1,10 +1,13 @@
-# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts.
+# (C) Copyright 2024 Anemoi contributors.
+#
 # This software is licensed under the terms of the Apache Licence Version 2.0
 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
 # In applying this licence, ECMWF does not waive the privileges and immunities
 # granted to it by virtue of its status as an intergovernmental organisation
 # nor does it submit to any jurisdiction.
+
 from __future__ import annotations

 import numpy as np
diff --git a/src/hydra_plugins/anemoi_searchpath/anemoi_searchpath_plugin.py b/src/hydra_plugins/anemoi_searchpath/anemoi_searchpath_plugin.py
index 75efae9a..db44dff7 100644
--- a/src/hydra_plugins/anemoi_searchpath/anemoi_searchpath_plugin.py
+++ b/src/hydra_plugins/anemoi_searchpath/anemoi_searchpath_plugin.py
@@ -1,11 +1,13 @@
-# (C) Copyright 2024 ECMWF.
+# (C) Copyright 2024 Anemoi contributors.
 #
 # This software is licensed under the terms of the Apache Licence Version 2.0
 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
 # In applying this licence, ECMWF does not waive the privileges and immunities
 # granted to it by virtue of its status as an intergovernmental organisation
 # nor does it submit to any jurisdiction.
+
 import logging
 import os
 from pathlib import Path
diff --git a/tests/conftest.py b/tests/conftest.py
index 711b1e2d..46163e8d 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,7 +1,8 @@
-# (C) Copyright 2024 ECMWF.
+# (C) Copyright 2024 Anemoi contributors.
 #
 # This software is licensed under the terms of the Apache Licence Version 2.0
 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
 # In applying this licence, ECMWF does not waive the privileges and immunities
 # granted to it by virtue of its status as an intergovernmental organisation
 # nor does it submit to any jurisdiction.
diff --git a/tests/diagnostics/mlflow/test_auth.py b/tests/diagnostics/mlflow/test_auth.py
index f65d836b..329f7217 100644
--- a/tests/diagnostics/mlflow/test_auth.py
+++ b/tests/diagnostics/mlflow/test_auth.py
@@ -1,11 +1,15 @@
-# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts.
+# (C) Copyright 2024 Anemoi contributors.
+#
 # This software is licensed under the terms of the Apache Licence Version 2.0
 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
 # In applying this licence, ECMWF does not waive the privileges and immunities
 # granted to it by virtue of its status as an intergovernmental organisation
 # nor does it submit to any jurisdiction.
+
 from __future__ import annotations

+import os
 import time

 import pytest
@@ -46,13 +50,9 @@ def mocks(
     mocker.patch(
         "anemoi.training.diagnostics.mlflow.auth.save_config",
     )
-    mocker.patch(
-        "anemoi.training.diagnostics.mlflow.auth.input",
-        return_value="username",
-    )
     mocker.patch(
         "anemoi.training.diagnostics.mlflow.auth.getpass",
-        return_value="password",
+        return_value="seed_refresh_token",
     )
     mocker.patch("os.environ")
@@ -103,21 +103,20 @@ def test_login(mocker: pytest.MockerFixture) -> None:
     auth = TokenAuth("https://test.url")
     auth.login()

-    mock_token_request.assert_called_once_with(username="username", password="password")  # noqa: S106
+    mock_token_request.assert_called_once()

     # forced credential login
     mock_token_request = mocks(mocker)
     auth = TokenAuth("https://test.url")
     auth.login(force_credentials=True)

-    mock_token_request.assert_called_once_with(username="username", password="password")  # noqa: S106
+    mock_token_request.assert_called_once()

     # failed login
     mock_token_request = mocks(mocker, token_request={"refresh_token": None})
     auth = TokenAuth("https://test.url")

     pytest.raises(RuntimeError, auth.login)
-    mock_token_request.assert_called_with(username="username", password="password")  # noqa: S106
     assert mock_token_request.call_count == 2
@@ -153,3 +152,11 @@ def test_api(mocker: pytest.MockerFixture) -> None:

     with pytest.raises(RuntimeError):
         auth._request("path", {"key": "value"})
+
+
+def test_target_env_var(mocker: pytest.MockerFixture) -> None:
+    mocks(mocker)
+    auth = TokenAuth("https://test.url", target_env_var="MLFLOW_TEST_ENV_VAR")
+    auth.authenticate()
+
+    os.environ.__setitem__.assert_called_once_with("MLFLOW_TEST_ENV_VAR", "access_token")
diff --git a/tests/diagnostics/mlflow/test_client.py b/tests/diagnostics/mlflow/test_client.py
new file mode 100644
index 00000000..f6dedbce
--- /dev/null
+++ b/tests/diagnostics/mlflow/test_client.py
@@ -0,0 +1,41 @@
+# (C) Copyright 2024 Anemoi contributors.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import pytest
+
+if TYPE_CHECKING:
+    import pytest_mock
+
+from anemoi.training.diagnostics.mlflow.client import AnemoiMlflowClient
+
+
+@pytest.fixture(autouse=True)
+def mocks(mocker: pytest_mock.MockerFixture) -> None:
+    mocker.patch("anemoi.training.diagnostics.mlflow.client.TokenAuth")
+    mocker.patch("anemoi.training.diagnostics.mlflow.client.health_check")
+    mocker.patch("anemoi.training.diagnostics.mlflow.client.AnemoiMlflowClient.search_experiments")
+
+
+def test_auth_injected() -> None:
+    client = AnemoiMlflowClient("http://localhost:5000", authentication=True, check_health=False)
+    client.search_experiments()
+    client.search_experiments()
+
+    assert client.anemoi_auth.authenticate.call_count == 2
+
+
+def test_health_check() -> None:
+    # the internal health check will trigger an authenticate call
+    client = AnemoiMlflowClient("http://localhost:5000", authentication=True, check_health=True)
+
+    client.anemoi_auth.authenticate.assert_called_once()
diff --git a/tests/diagnostics/test_checkpoint.py b/tests/diagnostics/test_checkpoint.py
index e2dce376..63e6ccc9 100644
--- a/tests/diagnostics/test_checkpoint.py
+++ b/tests/diagnostics/test_checkpoint.py
@@ -1,10 +1,13 @@
-# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts.
+# (C) Copyright 2024 Anemoi contributors.
+#
 # This software is licensed under the terms of the Apache Licence Version 2.0
 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
 # In applying this licence, ECMWF does not waive the privileges and immunities
 # granted to it by virtue of its status as an intergovernmental organisation
 # nor does it submit to any jurisdiction.
+
 from __future__ import annotations

 import datetime
diff --git a/tests/hydra/test_search_path_plugins.py b/tests/hydra/test_search_path_plugins.py
index dd981b66..48666502 100644
--- a/tests/hydra/test_search_path_plugins.py
+++ b/tests/hydra/test_search_path_plugins.py
@@ -1,11 +1,13 @@
-# (C) Copyright 2024 ECMWF.
+# (C) Copyright 2024 Anemoi contributors.
 #
 # This software is licensed under the terms of the Apache Licence Version 2.0
 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
 # In applying this licence, ECMWF does not waive the privileges and immunities
 # granted to it by virtue of its status as an intergovernmental organisation
 # nor does it submit to any jurisdiction.
+
 from hydra import initialize
 from hydra.core.global_hydra import GlobalHydra
 from hydra.core.plugins import Plugins
diff --git a/tests/train/test_loss_scaling.py b/tests/train/test_loss_scaling.py
index 84ca2189..2da5ae00 100644
--- a/tests/train/test_loss_scaling.py
+++ b/tests/train/test_loss_scaling.py
@@ -1,10 +1,13 @@
-# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts.
+# (C) Copyright 2024 Anemoi contributors.
+#
 # This software is licensed under the terms of the Apache Licence Version 2.0
 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
 # In applying this licence, ECMWF does not waive the privileges and immunities
 # granted to it by virtue of its status as an intergovernmental organisation
 # nor does it submit to any jurisdiction.
+
 import pytest
 import torch
 from _pytest.fixtures import SubRequest
diff --git a/tests/utils/test_usable_indices.py b/tests/utils/test_usable_indices.py
index 6bc5c83f..0aff358a 100644
--- a/tests/utils/test_usable_indices.py
+++ b/tests/utils/test_usable_indices.py
@@ -1,10 +1,13 @@
-# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts.
+# (C) Copyright 2024 Anemoi contributors.
+#
 # This software is licensed under the terms of the Apache Licence Version 2.0
 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
 # In applying this licence, ECMWF does not waive the privileges and immunities
 # granted to it by virtue of its status as an intergovernmental organisation
 # nor does it submit to any jurisdiction.
+
 import numpy as np

 from anemoi.training.utils.usable_indices import get_usable_indices
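
Note for reviewers: the patch above adds src/anemoi/training/utils/masks.py to support rollout training for limited area models. The snippet below is a minimal usage sketch, not part of the patch; the mask values, tensor shapes and variable names are assumptions chosen only to illustrate the API added in that file.

    # Illustrative sketch (assumed inputs): masking boundary points with the new utilities.
    import torch

    from anemoi.training.utils.masks import Boolean1DMask
    from anemoi.training.utils.masks import NoOutputMask

    # 1D boolean field over the grid points: True inside the limited area,
    # False on boundary points supplied by the driving model (assumed values).
    inside_lam = torch.tensor([True, True, False, True, False])
    mask = Boolean1DMask(inside_lam)

    # Tensor shaped (bs, ens, latlon, nvar), as in the rollout_boundary docstring.
    x = torch.randn(2, 1, 5, 3)

    # Fill the points where the mask is False with NaN along the grid dimension (dim=2).
    masked = mask.apply(x, dim=2, fill_value=float("nan"))

    # NoOutputMask is the pass-through variant: it returns the tensor unchanged.
    assert torch.equal(NoOutputMask().apply(x), x)

In the same way, Boolean1DMask.rollout_boundary uses apply with a tensor fill_value to overwrite the prognostic variables at boundary points of the predicted state with the true state between rollout steps, while NoOutputMask leaves global (non-LAM) training behaviour unchanged.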