diff --git a/CHANGELOG.md b/CHANGELOG.md
index cccb15a1..9871a6fb 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,47 +8,67 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 Please add your functional changes to the appropriate section in the PR.
 Keep it human-readable, your future self will thank you!
 
-## [Unreleased](https://github.com/ecmwf/anemoi-training/compare/0.2.2...HEAD)
+## [Unreleased](https://github.com/ecmwf/anemoi-training/compare/0.3.0...HEAD)
+### Fixed
+
+### Added
+- Added a check for the variable sorting on pre-trained/finetuned models [#120](https://github.com/ecmwf/anemoi-training/pull/120)
+
+### Changed
+
+### Removed
+- Removed the resolution config entry [#120](https://github.com/ecmwf/anemoi-training/pull/120)
+
+
+
+## [0.3.0 - Loss & Callback Refactors](https://github.com/ecmwf/anemoi-training/compare/0.2.2...0.3.0) - 2024-11-14
+
+### Changed
+- Increase the default MLflow HTTP max retries [#111](https://github.com/ecmwf/anemoi-training/pull/111)
 
 ### Fixed
+
 - Rename loss_scaling to variable_loss_scaling [#138](https://github.com/ecmwf/anemoi-training/pull/138)
 - Refactored callbacks. [#60](https://github.com/ecmwf/anemoi-training/pulls/60)
-  - Updated docs [#115](https://github.com/ecmwf/anemoi-training/pull/115)
-  - Fix enabling LearningRateMonitor [#119](https://github.com/ecmwf/anemoi-training/pull/119)
+  - Updated docs [#115](https://github.com/ecmwf/anemoi-training/pull/115)
+  - Fix enabling LearningRateMonitor [#119](https://github.com/ecmwf/anemoi-training/pull/119)
+
 - Refactored rollout [#87](https://github.com/ecmwf/anemoi-training/pulls/87)
-  - Enable longer validation rollout than training
+  - Enable longer validation rollout than training
+
 - Expand iterables in logging [#91](https://github.com/ecmwf/anemoi-training/pull/91)
-  - Save entire config in mlflow
+  - Save entire config in mlflow
+
 
 ### Added
+
 - Included more loss functions and allowed configuration [#70](https://github.com/ecmwf/anemoi-training/pull/70)
 - Include option to use datashader and optimised asynchronous callbacks [#102](https://github.com/ecmwf/anemoi-training/pull/102)
 - Fix that applies the metric_ranges in the post-processed variable space [#116](https://github.com/ecmwf/anemoi-training/pull/116)
 - Allow updates to scalars [#137](https://github.com/ecmwf/anemoi-training/pulls/137)
-  - Add without subsetting in ScaleTensor
+  - Add without subsetting in ScaleTensor
+
 - Sub-hour datasets [#63](https://github.com/ecmwf/anemoi-training/pull/63)
 - Add synchronisation workflow [#92](https://github.com/ecmwf/anemoi-training/pull/92)
 - Feat: Anemoi Profiler compatible with mlflow and using Pytorch (Kineto) Profiler for memory report [38](https://github.com/ecmwf/anemoi-training/pull/38/)
-- Added a check for the variable sorting on pre-trained/finetuned models [#120](https://github.com/ecmwf/anemoi-training/pull/120)
+- Feat: Save a gif for longer rollouts in validation [#65](https://github.com/ecmwf/anemoi-training/pull/65)
 - New limited area config file added, limited_area.yaml. [#134](https://github.com/ecmwf/anemoi-training/pull/134/)
 - New stretched grid config added, stretched_grid.yaml [#133](https://github.com/ecmwf/anemoi-training/pull/133)
 
 ### Changed
+
 - Renamed frequency keys in callbacks configuration. [#118](https://github.com/ecmwf/anemoi-training/pull/118)
 - Modified training configuration to support max_steps and tied lr iterations to max_steps by default [#67](https://github.com/ecmwf/anemoi-training/pull/67)
 - Merged node & edge trainable feature callbacks into one. [#135](https://github.com/ecmwf/anemoi-training/pull/135)
 
 ### Removed
-- Removed the resolution config entry [#120](https://github.com/ecmwf/anemoi-training/pull/120)
 
 ## [0.2.2 - Maintenance: pin python <3.13](https://github.com/ecmwf/anemoi-training/compare/0.2.1...0.2.2) - 2024-10-28
-
 ### Changed
 - Lock python version <3.13 [#107](https://github.com/ecmwf/anemoi-training/pull/107)
-
-
 ## [0.2.1 - Bugfix: resuming mlflow runs](https://github.com/ecmwf/anemoi-training/compare/0.2.0...0.2.1) - 2024-10-24
 
 ### Added
@@ -90,6 +110,7 @@ Keep it human-readable, your future self will thank you!
 
 - Variable Bounding as configurable model layers [#13](https://github.com/ecmwf/anemoi-models/issues/13)
 
+
 #### Functionality
 
 - Enable the callback for plotting a histogram for variables containing NaNs
@@ -101,7 +122,6 @@ Keep it human-readable, your future self will thank you!
 - Feature: `AnemoiMlflowClient`, an mlflow client with authentication support [#86](https://github.com/ecmwf/anemoi-training/pull/86)
 - Long Rollout Plots
 
-
 ### Fixed
 
 - Fix `TypeError` raised when trying to JSON serialise `datetime.timedelta` object - [#43](https://github.com/ecmwf/anemoi-training/pull/43)
diff --git a/src/anemoi/training/config/diagnostics/plot/rollout_eval.yaml b/src/anemoi/training/config/diagnostics/plot/rollout_eval.yaml
index 642e6e6b..5eece654 100644
--- a/src/anemoi/training/config/diagnostics/plot/rollout_eval.yaml
+++ b/src/anemoi/training/config/diagnostics/plot/rollout_eval.yaml
@@ -60,8 +60,10 @@ callbacks:
       - 10u
       - 10v
   - _target_: anemoi.training.diagnostics.callbacks.plot.LongRolloutPlots
+    # for rollout and video_rollout pick any integers at or below dataloader.validation_rollout
    rollout:
      - ${dataloader.validation_rollout}
+    video_rollout: ${dataloader.validation_rollout}
    every_n_epochs: 20
    sample_idx: ${diagnostics.plot.sample_idx}
    parameters: ${diagnostics.plot.parameters}
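The comment added above ties both `rollout` and `video_rollout` to `dataloader.validation_rollout`: the callback later asserts that each validation batch holds `multi_step` input steps plus the longest requested rollout. A minimal sketch of that requirement (the helper name and the numbers are illustrative only, not part of this diff):

```python
def required_validation_timesteps(rollout: list[int], video_rollout: int, multi_step: int) -> int:
    """Minimum number of time steps a validation batch must contain."""
    max_rollout = max(rollout) if rollout else 0
    max_rollout = max(max_rollout, video_rollout)
    return multi_step + max_rollout


# e.g. multi_step=2 and dataloader.validation_rollout=12: any plot or video
# rollout of at most 12 steps fits into the validation batch (2 + 12 = 14).
assert required_validation_timesteps(rollout=[12], video_rollout=12, multi_step=2) == 14
```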
diff --git a/src/anemoi/training/diagnostics/callbacks/plot.py b/src/anemoi/training/diagnostics/callbacks/plot.py
index 171eb840..b13d8727 100644
--- a/src/anemoi/training/diagnostics/callbacks/plot.py
+++ b/src/anemoi/training/diagnostics/callbacks/plot.py
@@ -24,6 +24,7 @@ from pathlib import Path
 from typing import TYPE_CHECKING
 
+import matplotlib.animation as animation
 import matplotlib.patches as mpatches
 import matplotlib.pyplot as plt
 import numpy as np
@@ -31,6 +32,7 @@ from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.utilities import rank_zero_only
 
+from anemoi.training.diagnostics.plots import get_scatter_frame
 from anemoi.training.diagnostics.plots import init_plot_settings
 from anemoi.training.diagnostics.plots import plot_graph_edge_features
 from anemoi.training.diagnostics.plots import plot_graph_node_features
@@ -119,6 +121,35 @@ def _output_figure(
 
         plt.close(fig)  # cleanup
 
+    @rank_zero_only
+    def _output_gif(
+        self,
+        logger: pl.loggers.base.LightningLoggerBase,
+        fig: plt.Figure,
+        anim: animation.ArtistAnimation,
+        epoch: int,
+        tag: str = "gnn",
+    ) -> None:
+        """Animation output: save to file and/or display in notebook."""
+        if self.save_basedir is not None:
+            save_path = Path(
+                self.save_basedir,
+                "plots",
+                f"{tag}_epoch{epoch:03d}.gif",
+            )
+
+            save_path.parent.mkdir(parents=True, exist_ok=True)
+            anim.save(save_path, writer="pillow", fps=8)
+
+            if self.config.diagnostics.log.wandb.enabled:
+                LOGGER.warning("Saving gif animations not tested for wandb.")
+
+            if self.config.diagnostics.log.mlflow.enabled:
+                run_id = logger.run_id
+                logger.experiment.log_artifact(run_id, str(save_path))
+
+        plt.close(fig)  # cleanup
+
     @rank_zero_only
     def _plot_with_error_catching(self, trainer: pl.Trainer, args: Any, kwargs: Any) -> None:
         """To execute the plot function but ensuring we catch any errors."""
@@ -261,7 +292,27 @@ def on_validation_epoch_end(
 
 
 class LongRolloutPlots(BasePlotCallback):
-    """Evaluates the model performance over a (longer) rollout window."""
+    """Evaluates the model performance over a (longer) rollout window.
+
+    This callback evaluates the model over an extended number of rollout steps
+    to observe its long-term behaviour.
+    Add the callback to the configuration file as follows:
+    ```
+    - _target_: anemoi.training.diagnostics.callbacks.plot.LongRolloutPlots
+      rollout:
+        - ${dataloader.validation_rollout}
+      video_rollout: ${dataloader.validation_rollout}
+      every_n_epochs: 1
+      sample_idx: ${diagnostics.plot.sample_idx}
+      parameters: ${diagnostics.plot.parameters}
+    ```
+    The rollout steps selected for plots and video must be less than or equal to dataloader.validation_rollout.
+    Increasing dataloader.validation_rollout does not change the rollout length used during training;
+    it only ensures that the validation batches contain enough time steps for the plots and the video.
+
+    Creating one animation of one variable for 56 rollout steps takes about one minute.
+    Recommended use for video generation: fork the run via fork_run_id for one additional epoch with videos enabled.
+    """
 
     def __init__(
         self,
@@ -269,10 +320,12 @@ def __init__(
         rollout: list[int],
         sample_idx: int,
         parameters: list[str],
+        video_rollout: int = 0,
         accumulation_levels_plot: list[float] | None = None,
         cmap_accumulation: list[str] | None = None,
         per_sample: int = 6,
         every_n_epochs: int = 1,
+        animation_interval: int = 400,
     ) -> None:
         """Initialise LongRolloutPlots callback.
@@ -286,6 +339,8 @@ def __init__(
             Sample to plot
         parameters : list[str]
             Parameters to plot
+        video_rollout : int, optional
+            Number of rollout steps for video, by default 0 (no video)
         accumulation_levels_plot : list[float] | None
             Accumulation levels to plot, by default None
         cmap_accumulation : list[str] | None
@@ -294,22 +349,39 @@ def __init__(
             Number of plots per sample, by default 6
         every_n_epochs : int, optional
             Epoch frequency to plot at, by default 1
+        animation_interval : int, optional
+            Delay between frames in the animation in milliseconds, by default 400
         """
         super().__init__(config)
         self.every_n_epochs = every_n_epochs
 
-        LOGGER.debug(
-            "Setting up callback for plots with long rollout: rollout = %d, frequency = every %d epoch ...",
-            rollout,
-            every_n_epochs,
-        )
         self.rollout = rollout
+        self.video_rollout = video_rollout
+        self.max_rollout = 0
+        if self.rollout:
+            self.max_rollout = max(self.rollout)
+        else:
+            self.rollout = []
+        if self.video_rollout:
+            self.max_rollout = max(self.max_rollout, self.video_rollout)
+
         self.sample_idx = sample_idx
         self.accumulation_levels_plot = accumulation_levels_plot
         self.cmap_accumulation = cmap_accumulation
         self.per_sample = per_sample
         self.parameters = parameters
+        self.animation_interval = animation_interval
+
+        LOGGER.info(
+            "Setting up callback for plots with long rollout: rollout for plots = %s, "
+            "rollout for video = %s, frequency = every %d epoch.",
+            self.rollout,
+            self.video_rollout,
+            every_n_epochs,
+        )
 
     @rank_zero_only
     def _plot(
@@ -322,12 +394,10 @@ def _plot(
         epoch: int,
     ) -> None:
         _ = output
-
         start_time = time.time()
-
         logger = trainer.logger
-        # Build dictionary of inidicies and parameters to be plotted
+        # Initialize required variables for plotting
         plot_parameters_dict = {
             pl_module.data_indices.model.output.name_to_index[name]: (
                 name,
             )
             for name in self.parameters
         }
-
         if self.post_processors is None:
-            # Copy to be used across all the training cycle
             self.post_processors = copy.deepcopy(pl_module.model.post_processors).cpu()
         if self.latlons is None:
             self.latlons = np.rad2deg(pl_module.latlons_data.clone().cpu().numpy())
-        local_rank = pl_module.local_rank
 
-        assert batch.shape[1] >= max(self.rollout) + pl_module.multi_step, (
+        assert batch.shape[1] >= self.max_rollout + pl_module.multi_step, (
             "Batch length not sufficient for requested validation rollout length! "
" f"Set `dataloader.validation_rollout` to at least {max(self.rollout)}" ) @@ -358,54 +425,175 @@ def _plot( ].cpu() data_0 = self.post_processors(input_tensor_0).numpy() - # start rollout + if self.video_rollout: + data_over_time = [] + # collect min and max values for each variable for the colorbar + vmin, vmax = (np.inf * np.ones(len(plot_parameters_dict)), -np.inf * np.ones(len(plot_parameters_dict))) + + # Plot for each rollout step# Plot for each rollout step with torch.no_grad(): for rollout_step, (_, _, y_pred) in enumerate( pl_module.rollout_step( batch, - rollout=max(self.rollout), + rollout=self.max_rollout, validation_mode=False, training_mode=False, ), ): - + # plot only if the current rollout step is in the list of rollout steps if (rollout_step + 1) in self.rollout: - # prepare true output tensor for plotting - input_tensor_rollout_step = input_batch[ - self.sample_idx, - pl_module.multi_step + rollout_step, # (pl_module.multi_step - 1) + (rollout_step + 1) - ..., - pl_module.data_indices.internal_data.output.full, - ].cpu() - data_rollout_step = self.post_processors(input_tensor_rollout_step).numpy() - - # prepare predicted output tensor for plotting - output_tensor = self.post_processors( - y_pred[self.sample_idx : self.sample_idx + 1, ...].cpu(), - ).numpy() - - fig = plot_predicted_multilevel_flat_sample( + self._plot_rollout_step( + pl_module, plot_parameters_dict, - self.per_sample, - self.latlons, - self.accumulation_levels_plot, - self.cmap_accumulation, - data_0.squeeze(), - data_rollout_step.squeeze(), - output_tensor[0, 0, :, :], # rolloutstep, first member + input_batch, + data_0, + rollout_step, + y_pred, + batch_idx, + epoch, + logger, ) - self._output_figure( - logger, - fig, - epoch=epoch, - tag=f"gnn_pred_val_sample_rstep{rollout_step + 1:03d}_batch{batch_idx:04d}_rank0", - exp_log_tag=f"val_pred_sample_rstep{rollout_step + 1:03d}_rank{local_rank:01d}", + if self.video_rollout and rollout_step < self.video_rollout: + data_over_time, vmin, vmax = self._store_video_frame_data( + data_over_time, + y_pred, + plot_parameters_dict, + vmin, + vmax, ) - LOGGER.info( - "Time taken to plot samples after longer rollout: %s seconds", - int(time.time() - start_time), + + # Generate and save video rollout animation if enabled + if self.video_rollout: + self._generate_video_rollout( + data_0, + data_over_time, + plot_parameters_dict, + vmin, + vmax, + self.video_rollout, + batch_idx, + epoch, + logger, + animation_interval=self.animation_interval, + ) + + LOGGER.info("Time taken to plot/animate samples for longer rollout: %d seconds", int(time.time() - start_time)) + + def _plot_rollout_step( + self, + pl_module: pl.LightningModule, + plot_parameters_dict: dict, + input_batch: torch.Tensor, + data_0: np.ndarray, + rollout_step: int, + y_pred: torch.Tensor, + batch_idx: int, + epoch: int, + logger: pl.loggers.base.LightningLoggerBase, + ) -> None: + """Plot the predicted output, input, true target and error plots for a given rollout step.""" + # prepare true output tensor for plotting + input_tensor_rollout_step = input_batch[ + self.sample_idx, + pl_module.multi_step + rollout_step, # (pl_module.multi_step - 1) + (rollout_step + 1) + ..., + pl_module.data_indices.internal_data.output.full, + ].cpu() + data_rollout_step = self.post_processors(input_tensor_rollout_step).numpy() + # predicted output tensor + output_tensor = self.post_processors(y_pred[self.sample_idx : self.sample_idx + 1, ...].cpu()).numpy() + + fig = plot_predicted_multilevel_flat_sample( + 
+            plot_parameters_dict,
+            self.per_sample,
+            self.latlons,
+            self.accumulation_levels_plot,
+            self.cmap_accumulation,
+            data_0.squeeze(),
+            data_rollout_step.squeeze(),
+            output_tensor[0, 0, :, :],  # rolloutstep, first member
+        )
+        self._output_figure(
+            logger,
+            fig,
+            epoch=epoch,
+            tag=f"gnn_pred_val_sample_rstep{rollout_step + 1:03d}_batch{batch_idx:04d}_rank0",
+            exp_log_tag=f"val_pred_sample_rstep{rollout_step + 1:03d}_rank{pl_module.local_rank:01d}",
+        )
+
+    def _store_video_frame_data(
+        self,
+        data_over_time: list,
+        y_pred: torch.Tensor,
+        plot_parameters_dict: dict,
+        vmin: np.ndarray,
+        vmax: np.ndarray,
+    ) -> tuple[list, np.ndarray, np.ndarray]:
+        """Store the data for each frame of the video."""
+        # prepare predicted output tensors for video
+        output_tensor = self.post_processors(y_pred[self.sample_idx : self.sample_idx + 1, ...].cpu()).numpy()
+        data_over_time.append(output_tensor[0, 0, :, np.array(list(plot_parameters_dict.keys()))])
+        # update min and max values for each variable for the colorbar
+        vmin[:] = np.minimum(vmin, np.nanmin(data_over_time[-1], axis=1))
+        vmax[:] = np.maximum(vmax, np.nanmax(data_over_time[-1], axis=1))
+        return data_over_time, vmin, vmax
+
+    def _generate_video_rollout(
+        self,
+        data_0: np.ndarray,
+        data_over_time: list,
+        plot_parameters_dict: dict,
+        vmin: np.ndarray,
+        vmax: np.ndarray,
+        rollout_step: int,
+        batch_idx: int,
+        epoch: int,
+        logger: pl.loggers.base.LightningLoggerBase,
+        animation_interval: int = 400,
+    ) -> None:
+        """Generate the video animation for the rollout."""
+        for idx, (variable_idx, (variable_name, _)) in enumerate(plot_parameters_dict.items()):
+            # Create the animation and list to store the frames (artists)
+            frames = []
+            # Prepare the figure
+            fig, ax = plt.subplots(figsize=(10, 6), dpi=72)
+            cmap = "twilight" if variable_name == "mwd" else "viridis"
+
+            # Create initial data and colorbar
+            ax, scatter_frame = get_scatter_frame(
+                ax,
+                data_0[0, :, variable_idx],
+                self.latlons,
+                cmap=cmap,
+                vmin=vmin[idx],
+                vmax=vmax[idx],
+            )
+            ax.set_title(f"{variable_name}")
+            fig.colorbar(scatter_frame, ax=ax)
+            frames.append([scatter_frame])
+
+            # Loop through the data and create the scatter plot for each frame
+            for frame_data in data_over_time:
+                ax, scatter_frame = get_scatter_frame(
+                    ax,
+                    frame_data[idx],
+                    self.latlons,
+                    cmap=cmap,
+                    vmin=vmin[idx],
+                    vmax=vmax[idx],
+                )
+                frames.append([scatter_frame])  # Each frame contains a list of artists (images)
+
+            # Create the animation using ArtistAnimation
+            anim = animation.ArtistAnimation(fig, frames, interval=animation_interval, blit=True)
+            self._output_gif(
+                logger,
+                fig,
+                anim,
+                epoch=epoch,
+                tag=f"gnn_pred_val_animation_{variable_name}_rstep{rollout_step:02d}_batch{batch_idx:04d}_rank0",
+            )
 
     @rank_zero_only
     def on_validation_batch_end(
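`_generate_video_rollout` and `_output_gif` follow the standard matplotlib `ArtistAnimation` pattern: collect one list of artists per frame, build the animation once, and save it with the pillow writer. A self-contained sketch of that pattern on synthetic scatter data (all values are illustrative; nothing here comes from the anemoi code base):

```python
import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
lon = rng.uniform(-np.pi, np.pi, 500)
lat = rng.uniform(-np.pi / 2, np.pi / 2, 500)

fig, ax = plt.subplots(figsize=(10, 6), dpi=72)
ax.set_xlim((-np.pi, np.pi))
ax.set_ylim((-np.pi / 2, np.pi / 2))

frames = []  # ArtistAnimation expects one list of artists per frame
for step in range(10):
    values = np.sin((step + 1) * lon) * np.cos(lat)
    # a fixed vmin/vmax keeps the colour scale comparable across frames
    scatter = ax.scatter(lon, lat, c=values, s=5, cmap="viridis", vmin=-1, vmax=1, rasterized=True)
    frames.append([scatter])

anim = animation.ArtistAnimation(fig, frames, interval=400, blit=True)
anim.save("rollout_animation.gif", writer="pillow", fps=8)
plt.close(fig)
```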
os.environ["_MLFLOW_HTTP_REQUEST_MAX_RETRIES_LIMIT"] = str(http_max_retries + 1) + # these are the default values, but set them explicitly in case they change + os.environ["MLFLOW_HTTP_REQUEST_BACKOFF_FACTOR"] = "2" + os.environ["MLFLOW_HTTP_REQUEST_BACKOFF_JITTER"] = "1" + from anemoi.training.diagnostics.mlflow.logger import AnemoiMLflowLogger resumed = config.training.run_id is not None forked = config.training.fork_run_id is not None save_dir = config.hardware.paths.logs.mlflow - tracking_uri = config.diagnostics.log.mlflow.tracking_uri + offline = config.diagnostics.log.mlflow.offline + if not offline: + tracking_uri = config.diagnostics.log.mlflow.tracking_uri + LOGGER.info("AnemoiMLFlow logging to %s", tracking_uri) + else: + tracking_uri = None if (resumed or forked) and (offline): # when resuming or forking offline - # tracking_uri = ${hardware.paths.logs.mlflow} @@ -54,7 +69,6 @@ def get_mlflow_logger(config: DictConfig) -> None: ) log_hyperparams = False - LOGGER.info("AnemoiMLFlow logging to %s", tracking_uri) logger = AnemoiMLflowLogger( experiment_name=config.diagnostics.log.mlflow.experiment_name, project_name=config.diagnostics.log.mlflow.project_name, diff --git a/src/anemoi/training/diagnostics/plots.py b/src/anemoi/training/diagnostics/plots.py index d397f05c..e0b44d1c 100644 --- a/src/anemoi/training/diagnostics/plots.py +++ b/src/anemoi/training/diagnostics/plots.py @@ -21,6 +21,7 @@ from anemoi.models.layers.mapper import GraphEdgeMixin from datashader.mpl_ext import dsshow from matplotlib.collections import LineCollection +from matplotlib.collections import PathCollection from matplotlib.colors import BoundaryNorm from matplotlib.colors import ListedColormap from matplotlib.colors import TwoSlopeNorm @@ -50,6 +51,13 @@ class LatLonData: data: np.ndarray +def equirectangular_projection(latlons: np.array) -> np.array: + pc = EquirectangularProjection() + lat, lon = latlons[:, 0], latlons[:, 1] + pc_lon, pc_lat = pc(lon, lat) + return pc_lat, pc_lon + + def init_plot_settings() -> None: """Initialize matplotlib plot settings.""" small_font_size = 8 @@ -159,9 +167,8 @@ def plot_power_spectrum( figsize = (n_plots_y * 4, n_plots_x * 3) fig, ax = plt.subplots(n_plots_x, n_plots_y, figsize=figsize, layout=LAYOUT) - pc = EquirectangularProjection() - lat, lon = latlons[:, 0], latlons[:, 1] - pc_lon, pc_lat = pc(lon, lat) + pc_lat, pc_lon = equirectangular_projection(latlons) + pc_lon = np.array(pc_lon) pc_lat = np.array(pc_lat) # Calculate delta_lon and delta_lat on the projected grid @@ -381,9 +388,7 @@ def plot_predicted_multilevel_flat_sample( figsize = (n_plots_y * 4, n_plots_x * 3) fig, ax = plt.subplots(n_plots_x, n_plots_y, figsize=figsize, layout=LAYOUT) - pc = EquirectangularProjection() - lat, lon = latlons[:, 0], latlons[:, 1] - pc_lon, pc_lat = pc(lon, lat) + pc_lat, pc_lon = equirectangular_projection(latlons) for plot_idx, (variable_idx, (variable_name, output_only)) in enumerate(parameters.items()): xt = x[..., variable_idx].squeeze() * int(output_only) @@ -724,6 +729,36 @@ def single_plot( fig.colorbar(psc, ax=ax) +def get_scatter_frame( + ax: plt.Axes, + data: np.ndarray, + latlons: np.ndarray, + cmap: str = "viridis", + vmin: int | None = None, + vmax: int | None = None, +) -> [plt.Axes, PathCollection]: + """Create a scatter plot for a single frame of an animation.""" + pc_lat, pc_lon = equirectangular_projection(latlons) + + scatter_frame = ax.scatter( + pc_lon, + pc_lat, + c=data, + cmap=cmap, + s=5, + alpha=1.0, + rasterized=True, + vmin=vmin, + 
diff --git a/src/anemoi/training/diagnostics/plots.py b/src/anemoi/training/diagnostics/plots.py
index d397f05c..e0b44d1c 100644
--- a/src/anemoi/training/diagnostics/plots.py
+++ b/src/anemoi/training/diagnostics/plots.py
@@ -21,6 +21,7 @@ from anemoi.models.layers.mapper import GraphEdgeMixin
 from datashader.mpl_ext import dsshow
 from matplotlib.collections import LineCollection
+from matplotlib.collections import PathCollection
 from matplotlib.colors import BoundaryNorm
 from matplotlib.colors import ListedColormap
 from matplotlib.colors import TwoSlopeNorm
@@ -50,6 +51,13 @@ class LatLonData:
     data: np.ndarray
 
 
+def equirectangular_projection(latlons: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
+    pc = EquirectangularProjection()
+    lat, lon = latlons[:, 0], latlons[:, 1]
+    pc_lon, pc_lat = pc(lon, lat)
+    return pc_lat, pc_lon
+
+
 def init_plot_settings() -> None:
     """Initialize matplotlib plot settings."""
     small_font_size = 8
@@ -159,9 +167,8 @@ def plot_power_spectrum(
     figsize = (n_plots_y * 4, n_plots_x * 3)
     fig, ax = plt.subplots(n_plots_x, n_plots_y, figsize=figsize, layout=LAYOUT)
 
-    pc = EquirectangularProjection()
-    lat, lon = latlons[:, 0], latlons[:, 1]
-    pc_lon, pc_lat = pc(lon, lat)
+    pc_lat, pc_lon = equirectangular_projection(latlons)
+
     pc_lon = np.array(pc_lon)
     pc_lat = np.array(pc_lat)
     # Calculate delta_lon and delta_lat on the projected grid
@@ -381,9 +388,7 @@ def plot_predicted_multilevel_flat_sample(
     figsize = (n_plots_y * 4, n_plots_x * 3)
     fig, ax = plt.subplots(n_plots_x, n_plots_y, figsize=figsize, layout=LAYOUT)
 
-    pc = EquirectangularProjection()
-    lat, lon = latlons[:, 0], latlons[:, 1]
-    pc_lon, pc_lat = pc(lon, lat)
+    pc_lat, pc_lon = equirectangular_projection(latlons)
 
     for plot_idx, (variable_idx, (variable_name, output_only)) in enumerate(parameters.items()):
         xt = x[..., variable_idx].squeeze() * int(output_only)
@@ -724,6 +729,36 @@ def single_plot(
         fig.colorbar(psc, ax=ax)
 
 
+def get_scatter_frame(
+    ax: plt.Axes,
+    data: np.ndarray,
+    latlons: np.ndarray,
+    cmap: str = "viridis",
+    vmin: float | None = None,
+    vmax: float | None = None,
+) -> tuple[plt.Axes, PathCollection]:
+    """Create a scatter plot for a single frame of an animation."""
+    pc_lat, pc_lon = equirectangular_projection(latlons)
+
+    scatter_frame = ax.scatter(
+        pc_lon,
+        pc_lat,
+        c=data,
+        cmap=cmap,
+        s=5,
+        alpha=1.0,
+        rasterized=True,
+        vmin=vmin,
+        vmax=vmax,
+    )
+    ax.set_xlim((-np.pi, np.pi))
+    ax.set_ylim((-np.pi / 2, np.pi / 2))
+    continents.plot_continents(ax)
+    ax.set_aspect("auto", adjustable=None)
+    _hide_axes_ticks(ax)
+    return ax, scatter_frame
+
+
 def edge_plot(
     fig: Figure,
     ax: plt.Axes,
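`_store_video_frame_data` keeps a running minimum and maximum per variable so that every frame rendered by `get_scatter_frame` shares one colour scale. The same accumulation pattern in isolation (array shapes and names are placeholders, not taken from the code above):

```python
import numpy as np

n_vars, n_points, n_frames = 3, 1000, 5
rng = np.random.default_rng(42)

vmin = np.inf * np.ones(n_vars)
vmax = -np.inf * np.ones(n_vars)

data_over_time = []
for _ in range(n_frames):
    frame = rng.normal(size=(n_vars, n_points))  # one row per plotted variable
    data_over_time.append(frame)
    # update the per-variable extrema, as _store_video_frame_data does
    vmin = np.minimum(vmin, np.nanmin(frame, axis=1))
    vmax = np.maximum(vmax, np.nanmax(frame, axis=1))

# vmin[i] / vmax[i] now bound variable i across all frames, so one colorbar
# per variable stays valid for the whole animation.
print(vmin, vmax)
```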