diff --git a/src/anemoi/training/diagnostics/callbacks/plot.py b/src/anemoi/training/diagnostics/callbacks/plot.py index 02eba1cb..3e76fc9a 100644 --- a/src/anemoi/training/diagnostics/callbacks/plot.py +++ b/src/anemoi/training/diagnostics/callbacks/plot.py @@ -29,6 +29,7 @@ import matplotlib.pyplot as plt import numpy as np import torch +from anemoi.models.layers.mapper import GraphEdgeMixin from pytorch_lightning.callbacks import Callback from pytorch_lightning.utilities import rank_zero_only @@ -46,6 +47,7 @@ from typing import Any import pytorch_lightning as pl + from anemoi.models.layers.graph import NamedNodesAttributes from omegaconf import OmegaConf LOGGER = logging.getLogger(__name__) @@ -645,6 +647,28 @@ def __init__(self, config: OmegaConf, every_n_epochs: int | None = None) -> None Override for frequency to plot at, by default None """ super().__init__(config, every_n_epochs=every_n_epochs) + self.q_extreme_limit = config.get("quantile_edges_to_represent", 0.05) + + def get_node_trainable_tensors(self, node_attributes: NamedNodesAttributes) -> dict[str, torch.Tensor]: + """Map each name in ``node_attributes.trainable_tensors`` to its trainable tensor, skipping ``None``.""" + return { + name: tt.trainable for name, tt in node_attributes.trainable_tensors.items() if tt.trainable is not None + } + + def get_edge_trainable_modules(self, model: torch.nn.Module) -> dict[tuple[str, str], torch.nn.Module]: + """Map (source, target) graph names to the modules whose edge trainable tensor is set.""" + trainable_modules = { + (model._graph_name_data, model._graph_name_hidden): model.encoder, + (model._graph_name_hidden, model._graph_name_data): model.decoder, + } + + if isinstance(model.processor, GraphEdgeMixin): + trainable_modules[model._graph_name_hidden, model._graph_name_hidden] = model.processor + + modules_with_trainable = { + name: module for name, module in trainable_modules.items() if module.trainable.trainable is not None + } + return modules_with_trainable @rank_zero_only def _plot( @@ -656,25 +680,31 @@ def _plot( _ = epoch model = pl_module.model.module.model if hasattr(pl_module.model, "module") else pl_module.model.model - fig = plot_graph_node_features(model, 
datashader=self.datashader_plotting) + if len(node_trainable_tensors := self.get_node_trainable_tensors(model.node_attributes)): + fig = plot_graph_node_features(model, node_trainable_tensors, datashader=self.datashader_plotting) - self._output_figure( - trainer.logger, - fig, - epoch=trainer.current_epoch, - tag="node_trainable_params", - exp_log_tag="node_trainable_params", - ) + self._output_figure( + trainer.logger, + fig, + epoch=trainer.current_epoch, + tag="node_trainable_params", + exp_log_tag="node_trainable_params", + ) + else: + LOGGER.warning("There are no trainable node attributes to plot.") - fig = plot_graph_edge_features(model) + if len(edge_trainable_modules := self.get_edge_trainable_modules(model)): + fig = plot_graph_edge_features(model, edge_trainable_modules, q_extreme_limit=self.q_extreme_limit) - self._output_figure( - trainer.logger, - fig, - epoch=trainer.current_epoch, - tag="edge_trainable_params", - exp_log_tag="edge_trainable_params", - ) + self._output_figure( + trainer.logger, + fig, + epoch=trainer.current_epoch, + tag="edge_trainable_params", + exp_log_tag="edge_trainable_params", + ) + else: + LOGGER.warning("There are no trainable edge attributes to plot.") class PlotLoss(BasePerBatchPlotCallback): diff --git a/src/anemoi/training/diagnostics/plots.py b/src/anemoi/training/diagnostics/plots.py index 5944ea40..b2309dee 100644 --- a/src/anemoi/training/diagnostics/plots.py +++ b/src/anemoi/training/diagnostics/plots.py @@ -18,7 +18,6 @@ import matplotlib.style as mplstyle import numpy as np import pandas as pd -from anemoi.models.layers.mapper import GraphEdgeMixin from datashader.mpl_ext import dsshow from matplotlib.collections import LineCollection from matplotlib.collections import PathCollection @@ -35,7 +34,7 @@ if TYPE_CHECKING: from matplotlib.figure import Figure - from torch import nn + from torch import nn, Tensor from dataclasses import dataclass @@ -874,13 +873,17 @@ def edge_plot( fig.colorbar(psc, ax=ax) -def 
plot_graph_node_features(model: nn.Module, datashader: bool = False) -> Figure: +def plot_graph_node_features( + model: nn.Module, trainable_tensors: dict[str, Tensor], datashader: bool = False, +) -> Figure: """Plot trainable graph node features. Parameters ---------- model: AneomiModelEncProcDec Model object + trainable_tensors: dict[str, torch.Tensor] + Node trainable tensors datashader: bool, optional Scatter plot, by default False @@ -889,12 +892,8 @@ def plot_graph_node_features(model: nn.Module, datashader: bool = False) -> Figu Figure Figure object handle """ - nrows = len(nodes_name := model._graph_data.node_types) - trainable_tensors = {name: model.node_attributes.trainable_tensors[name].trainable for name in nodes_name} - ncols = min(0 if tt is None else tt.shape[0] for tt in trainable_tensors.values()) - if ncols == 0: - LOGGER.warning("There are no trainable node attributes to plot.") - return None + nrows = len(trainable_tensors) + ncols = max(tt.shape[1] for tt in trainable_tensors.values()) figsize = (ncols * 4, nrows * 3) fig, ax = plt.subplots(nrows, ncols, figsize=figsize, layout=LAYOUT) @@ -920,13 +919,17 @@ def plot_graph_node_features(model: nn.Module, datashader: bool = False) -> Figu return fig -def plot_graph_edge_features(model: nn.Module, q_extreme_limit: float = 0.05) -> Figure: +def plot_graph_edge_features( + model: nn.Module, trainable_modules: dict[tuple[str, str], nn.Module], q_extreme_limit: float = 0.05, +) -> Figure: """Plot trainable graph edge features. Parameters ---------- model: AneomiModelEncProcDec Model object + trainable_modules: dict[tuple[str, str], torch.nn.Module] + Modules exposing ``.trainable.trainable`` edge tensors, keyed by (source, target) names. q_extreme_limit : float, optional Plot top & bottom quantile of edges trainable values, by default 0.05 (5%). 
@@ -935,21 +938,8 @@ def plot_graph_edge_features(model: nn.Module, q_extreme_limit: float = 0.05) -> Figure Figure object handle """ - trainable_modules = { - (model._graph_name_data, model._graph_name_hidden): model.encoder, - (model._graph_name_hidden, model._graph_name_data): model.decoder, - } - - if isinstance(model.processor, GraphEdgeMixin): - trainable_modules[model._graph_name_hidden, model._graph_name_hidden] = model.processor - - trainable_tensors = {name: module.trainable.trainable for name, module in trainable_modules.items()} - ncols = min(0 if tt is None else tt.shape[1] for tt in trainable_tensors.values()) - if ncols == 0: - LOGGER.warning("There are no trainable edge attributes to plot.") - return None - nrows = len(trainable_modules) + ncols = max(tt.trainable.trainable.shape[1] for tt in trainable_modules.values()) figsize = (ncols * 4, nrows * 3) fig, ax = plt.subplots(nrows, ncols, figsize=figsize, layout=LAYOUT) @@ -957,7 +947,7 @@ def plot_graph_edge_features(model: nn.Module, q_extreme_limit: float = 0.05) -> src_coords = model.node_attributes.get_coordinates(src).cpu().numpy() dst_coords = model.node_attributes.get_coordinates(dst).cpu().numpy() edge_index = graph_mapper.edge_index_base.cpu().numpy() - edge_features = trainable_tensors[src, dst].cpu().detach().numpy() + edge_features = graph_mapper.trainable.trainable.cpu().detach().numpy() for i in range(ncols): ax_ = ax[row, i] if ncols > 1 else ax[row]