Skip to content
This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

Add expansion of params to logger #91

Merged
merged 22 commits into from
Nov 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
3e5b376
Add expansion of params to logger
HCookie Oct 15, 2024
df9c1a3
Update CHANGELOG
HCookie Oct 15, 2024
1b8d311
Merge branch 'develop' into fix/mlflow-log_params-string-truncation
HCookie Oct 15, 2024
6c8663d
Use ‘_’ as seperator
HCookie Oct 15, 2024
51c023b
Remove iterator
HCookie Oct 15, 2024
723e1e9
Provide threshold to expansion
HCookie Oct 16, 2024
d64a5c3
Expand nested lists
HCookie Oct 16, 2024
21bb8a7
Limit expansion
HCookie Oct 21, 2024
2005310
Merge branch 'develop' into fix/mlflow-log_params-string-truncation
HCookie Oct 22, 2024
2165d7c
Merge branch 'develop' into fix/mlflow-log_params-string-truncation
HCookie Oct 23, 2024
b91559c
Merge branch 'develop' into fix/mlflow-log_params-string-truncation
HCookie Oct 24, 2024
a4429ef
Merge branch 'develop' into fix/mlflow-log_params-string-truncation
HCookie Oct 24, 2024
68048b9
Merge branch 'develop' into fix/mlflow-log_params-string-truncation
HCookie Oct 25, 2024
7f44057
Merge branch 'develop' into fix/mlflow-log_params-string-truncation
HCookie Oct 28, 2024
745b812
Address review comments
HCookie Oct 28, 2024
6e38c9e
Allow for config of expand_keys
HCookie Oct 28, 2024
4b9b5dc
Merge branch 'develop' into fix/mlflow-log_params-string-truncation
HCookie Oct 28, 2024
ad7a5f6
Merge remote-tracking branch 'origin/develop' into fix/mlflow-log_par…
HCookie Oct 29, 2024
1975c86
Merge remote-tracking branch 'origin/develop' into fix/mlflow-log_par…
HCookie Oct 29, 2024
c06cadf
Update changelog
HCookie Oct 29, 2024
5410259
Merge branch 'develop' into fix/mlflow-log_params-string-truncation
HCookie Nov 6, 2024
3b52df1
Merge branch 'develop' into fix/mlflow-log_params-string-truncation
HCookie Nov 6, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ Keep it human-readable, your future self will thank you!
- Fix enabling LearningRateMonitor [#119](https://github.com/ecmwf/anemoi-training/pull/119)
- Refactored rollout [#87](https://github.com/ecmwf/anemoi-training/pulls/87)
- Enable longer validation rollout than training
- Expand iterables in logging [#91](https://github.com/ecmwf/anemoi-training/pull/91)
- Save entire config in mlflow
### Added
- Included more loss functions and allowed configuration [#70](https://github.com/ecmwf/anemoi-training/pull/70)
- Fix that applies the metric_ranges in the post-processed variable space [#116](https://github.com/ecmwf/anemoi-training/pull/116)
Expand Down
2 changes: 2 additions & 0 deletions src/anemoi/training/config/diagnostics/evaluation.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ log:
terminal: True
run_name: null # If set to null, the run name will be the a random UUID
on_resume_create_child: True
expand_hyperparams: # Which keys in hyperparams to expand
- config
interval: 100

enable_progress_bar: True
Expand Down
5 changes: 4 additions & 1 deletion src/anemoi/training/diagnostics/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,10 @@ def get_mlflow_logger(config: DictConfig) -> None:
)
config_params = OmegaConf.to_container(config, resolve=True)

logger.log_hyperparams(config_params)
logger.log_hyperparams(
config_params,
expand_keys=config.diagnostics.log.mlflow.get("expand_hyperparams", ["config"]),
)

if config.diagnostics.log.mlflow.terminal:
logger.log_terminal_output(artifact_save_dir=config.hardware.paths.plots)
Expand Down
63 changes: 56 additions & 7 deletions src/anemoi/training/diagnostics/mlflow/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from pytorch_lightning.utilities.rank_zero import rank_zero_only

from anemoi.training.diagnostics.mlflow.auth import TokenAuth
from anemoi.training.diagnostics.mlflow.utils import expand_iterables
from anemoi.training.diagnostics.mlflow.utils import health_check
from anemoi.training.utils.jsonify import map_config_to_primitives

Expand Down Expand Up @@ -483,26 +484,74 @@ def _clean_params(params: dict[str, Any]) -> dict[str, Any]:
return params

@rank_zero_only
def log_hyperparams(self, params: dict[str, Any] | Namespace) -> None:
"""Overwrite the log_hyperparams method to flatten config params using '.'."""
def log_hyperparams_as_artifact(self, params: dict[str, Any] | Namespace) -> None:
HCookie marked this conversation as resolved.
Show resolved Hide resolved
"""Log hyperparameters as an artifact."""
import json
import tempfile
from json import JSONEncoder

class StrEncoder(JSONEncoder):
def default(self, o: Any) -> str:
return str(o)

with tempfile.TemporaryDirectory() as tmp_dir:
path = Path(tmp_dir) / "config.json"
with Path.open(path, "w") as f:
json.dump(params, f, cls=StrEncoder)
self.experiment.log_artifact(run_id=self.run_id, local_path=path)

@rank_zero_only
def log_hyperparams(self, params: dict[str, Any] | Namespace, *, expand_keys: list[str] | None = None) -> None:
"""Overwrite the log_hyperparams method.

- flatten config params using '.'.
- expand keys within params to avoid truncation.
- log hyperparameters as an artifact.

Parameters
----------
params : dict[str, Any] | Namespace
params to log
expand_keys : list[str] | None, optional
keys to expand within params. Any key being expanded will
have lists converted according to `expand_iterables`,
by default None.
"""
if self._flag_log_hparams:
params = _convert_params(params)

# this is needed to resolve optional missing config values to a string, instead of raising a missing error
if config := params.get("config"):
params["config"] = map_config_to_primitives(config)

params = _flatten_dict(params, delimiter=".") # Flatten dict with '.' to not break API queries
params = self._clean_params(params)

import mlflow
from mlflow.entities import Param

# Truncate parameter values.
truncation_length = 250

if Version(mlflow.VERSION) >= Version("1.28.0"):
truncation_length = 500
params_list = [Param(key=k, value=str(v)[:truncation_length]) for k, v in params.items()]

self.log_hyperparams_as_artifact(params)

expanded_params = {}
params = params.copy()

for key in expand_keys or []:
if key in params:
HCookie marked this conversation as resolved.
Show resolved Hide resolved
expanded_params.update(
expand_iterables(params.pop(key), size_threshold=None, delimiter="."),
)
expanded_params.update(params)

expanded_params = _flatten_dict(
expanded_params,
delimiter=".",
) # Flatten dict with '.' to not break API queries
expanded_params = self._clean_params(expanded_params)

# Truncate parameter values.
params_list = [Param(key=k, value=str(v)[:truncation_length]) for k, v in expanded_params.items()]

for idx in range(0, len(params_list), 100):
self.experiment.log_batch(run_id=self.run_id, params=params_list[idx : idx + 100])
Expand Down
80 changes: 79 additions & 1 deletion src/anemoi/training/diagnostics/mlflow/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
from __future__ import annotations


import functools
import os
from typing import Any

import requests

Expand Down Expand Up @@ -36,3 +38,79 @@ def health_check(tracking_uri: str) -> None:
if not token:
error_msg += "The server may require authentication, did you forget to turn it on?"
raise ConnectionError(error_msg)


def expand_iterables(
params: dict[str, Any],
*,
size_threshold: int | None = None,
recursive: bool = True,
delimiter: str = ".",
) -> dict[str, Any]:
"""Expand any iterable values to the form {key.i: value_i}.

If expanded will also add {key.all: [value_0, value_1, ...], key.length: len([value_0, value_1, ...])}.

If `size_threshold` is not None, expand the iterable only if the length of str(value) is
greater than `size_threshold`.

Parameters
----------
params : dict[str, Any]
Parameters to be expanded.
size_threshold : int | None, optional
Threshold of str(value) to expand iterable at.
Default is None.
recursive : bool, optional
Expand nested dictionaries.
Default is True.
delimiter: str, optional
Delimiter to use for keys.
Default is ".".

Returns
-------
dict[str, Any]
Dictionary with all iterable values expanded.

Examples
--------
>>> expand_iterables({'a': ['a', 'b', 'c']})
{'a.0': 'a', 'a.1': 'b', 'a.2': 'c', 'a.all': ['a', 'b', 'c'], 'a.length': 3}
>>> expand_iterables({'a': {'b': ['a', 'b', 'c']}})
{'a': {'b.0': 'a', 'b.1': 'b', 'b.2': 'c', 'b.all': ['a', 'b', 'c'], 'b.length': 3}}
>>> expand_iterables({'a': ['a', 'b', 'c']}, size_threshold=100)
{'a': ['a', 'b', 'c']}
>>> expand_iterables({'a': [[0,1,2], 'b', 'c']})
{'a.0': {0: 0, 1: 1, 2: 2}, 'a.1': 'b', 'a.2': 'c', 'a.all': [[0, 1, 2], 'b', 'c'], 'a.length': 3}
"""

def should_be_expanded(x: Any) -> bool:
return size_threshold is None or len(str(x)) > size_threshold

nested_func = functools.partial(expand_iterables, size_threshold=size_threshold, recursive=recursive)

def expand(val: dict | list) -> dict[str, Any]:
if not recursive:
return val
if isinstance(val, dict):
return nested_func(val)
if isinstance(val, list):
return nested_func(dict(enumerate(val)))
HCookie marked this conversation as resolved.
Show resolved Hide resolved
return val

expanded_params = {}

for key, value in params.items():
if isinstance(value, (list, tuple)):
if should_be_expanded(value):
for i, v in enumerate(value):
expanded_params[f"{key}{delimiter}{i}"] = expand(v)

expanded_params[f"{key}{delimiter}all"] = value
expanded_params[f"{key}{delimiter}length"] = len(value)
else:
expanded_params[key] = value
else:
expanded_params[key] = expand(value)
return expanded_params
46 changes: 46 additions & 0 deletions tests/diagnostics/mlflow/test_expansion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from anemoi.training.diagnostics.mlflow.utils import expand_iterables


def test_expand_iterables_single_iterable() -> None:
# Test case with a single iterable
dictionary = {"a": ["a", "b", "c"]}
expanded = expand_iterables(dictionary)
assert expanded == {"a.0": "a", "a.1": "b", "a.2": "c", "a.all": ["a", "b", "c"], "a.length": 3}


def test_expand_iterables_size_threshold() -> None:
# Test case with a single iterable
dictionary = {"a": ["a", "b", "c"]}
expanded = expand_iterables(dictionary, size_threshold=100)
assert expanded == dictionary


def test_expand_iterables_with_nested_dict() -> None:
dictionary = {"a": {"b": ["a", "b", "c"]}}
expanded = expand_iterables(dictionary)
assert expanded == {"a": {"b.0": "a", "b.1": "b", "b.2": "c", "b.all": ["a", "b", "c"], "b.length": 3}}


def test_expand_iterables_with_nested_dict_thresholded() -> None:
dictionary = {"a": {"b": ["a", "b", "c"]}, "c": ["d"]}
expanded = expand_iterables(dictionary, size_threshold=5)
assert expanded == {"a": {"b.0": "a", "b.1": "b", "b.2": "c", "b.all": ["a", "b", "c"], "b.length": 3}, "c": ["d"]}


def test_expand_iterables_with_nested_list() -> None:
dictionary = {"a": [[0, 1, 2], "b", "c"]}
expanded = expand_iterables(dictionary)
assert expanded == {
"a.0": {0: 0, 1: 1, 2: 2},
"a.1": "b",
"a.2": "c",
"a.all": [[0, 1, 2], "b", "c"],
"a.length": 3,
}
Loading