Skip to content
This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

Commit

Permalink
fix: update plot of trainable params
Browse files Browse the repository at this point in the history
  • Loading branch information
JPXKQX committed Dec 16, 2024
1 parent eabe659 commit e2e93f2
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 41 deletions.
60 changes: 44 additions & 16 deletions src/anemoi/training/diagnostics/callbacks/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import matplotlib.pyplot as plt
import numpy as np
import torch
from anemoi.models.layers.mapper import GraphEdgeMixin
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities import rank_zero_only

Expand All @@ -46,6 +47,7 @@
from typing import Any

import pytorch_lightning as pl
from anemoi.models.layers.graph import NamedNodesAttributes
from omegaconf import OmegaConf

LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -645,6 +647,26 @@ def __init__(self, config: OmegaConf, every_n_epochs: int | None = None) -> None
Override for frequency to plot at, by default None
"""
super().__init__(config, every_n_epochs=every_n_epochs)
self.q_extreme_limit = config.get("quantile_edges_to_represent", 0.05)

def get_node_trainable_tensors(self, node_attributes: NamedNodesAttributes) -> dict[str, torch.Tensor]:
return {
name: tt.trainable for name, tt in node_attributes.trainable_tensors.items() if tt.trainable is not None
}

def get_edge_trainable_modules(self, model) -> dict[tuple[str, str], torch.Tensor]:
trainable_modules = {
(model._graph_name_data, model._graph_name_hidden): model.encoder,
(model._graph_name_hidden, model._graph_name_data): model.decoder,
}

if isinstance(model.processor, GraphEdgeMixin):
trainable_modules[model._graph_name_hidden, model._graph_name_hidden] = model.processor

trainable_tensors = {
name: module for name, module in trainable_modules.items() if module.trainable.trainable is not None
}
return trainable_tensors

@rank_zero_only
def _plot(
Expand All @@ -656,25 +678,31 @@ def _plot(
_ = epoch
model = pl_module.model.module.model if hasattr(pl_module.model, "module") else pl_module.model.model

fig = plot_graph_node_features(model, datashader=self.datashader_plotting)
if len(node_trainable_tensors := self.get_node_trainable_tensors(model.node_attributes)):
fig = plot_graph_node_features(model, node_trainable_tensors, datashader=self.datashader_plotting)

self._output_figure(
trainer.logger,
fig,
epoch=trainer.current_epoch,
tag="node_trainable_params",
exp_log_tag="node_trainable_params",
)
self._output_figure(
trainer.logger,
fig,
epoch=trainer.current_epoch,
tag="node_trainable_params",
exp_log_tag="node_trainable_params",
)
else:
LOGGER.warning("There are no trainable node attributes to plot.")

fig = plot_graph_edge_features(model)
if len(edge_trainable_modules := self.get_edge_trainable_modules(model)):
fig = plot_graph_edge_features(model, edge_trainable_modules, q_extreme_limit=self.q_extreme_limit)

self._output_figure(
trainer.logger,
fig,
epoch=trainer.current_epoch,
tag="edge_trainable_params",
exp_log_tag="edge_trainable_params",
)
self._output_figure(
trainer.logger,
fig,
epoch=trainer.current_epoch,
tag="edge_trainable_params",
exp_log_tag="edge_trainable_params",
)
else:
LOGGER.warning("There are no trainable edge attributes to plot.")


class PlotLoss(BasePerBatchPlotCallback):
Expand Down
40 changes: 15 additions & 25 deletions src/anemoi/training/diagnostics/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import matplotlib.style as mplstyle
import numpy as np
import pandas as pd
from anemoi.models.layers.mapper import GraphEdgeMixin
from datashader.mpl_ext import dsshow
from matplotlib.collections import LineCollection
from matplotlib.collections import PathCollection
Expand All @@ -35,7 +34,7 @@

if TYPE_CHECKING:
from matplotlib.figure import Figure
from torch import nn
from torch import nn, Tensor

from dataclasses import dataclass

Expand Down Expand Up @@ -874,13 +873,17 @@ def edge_plot(
fig.colorbar(psc, ax=ax)


def plot_graph_node_features(model: nn.Module, datashader: bool = False) -> Figure:
def plot_graph_node_features(
model: nn.Module, trainable_tensors: dict[str, Tensor], datashader: bool = False,
) -> Figure:
"""Plot trainable graph node features.
Parameters
----------
model: AneomiModelEncProcDec
Model object
trainable_tensors: dict[str, torch.Tensor]
Node trainable tensors
datashader: bool, optional
Scatter plot, by default False
Expand All @@ -889,12 +892,8 @@ def plot_graph_node_features(model: nn.Module, datashader: bool = False) -> Figu
Figure
Figure object handle
"""
nrows = len(nodes_name := model._graph_data.node_types)
trainable_tensors = {name: model.node_attributes.trainable_tensors[name].trainable for name in nodes_name}
ncols = min(0 if tt is None else tt.shape[0] for tt in trainable_tensors.values())
if ncols == 0:
LOGGER.warning("There are no trainable node attributes to plot.")
return None
nrows = len(trainable_tensors)
ncols = max(tt.shape[1] for tt in trainable_tensors.values())

figsize = (ncols * 4, nrows * 3)
fig, ax = plt.subplots(nrows, ncols, figsize=figsize, layout=LAYOUT)
Expand All @@ -920,13 +919,17 @@ def plot_graph_node_features(model: nn.Module, datashader: bool = False) -> Figu
return fig


def plot_graph_edge_features(model: nn.Module, q_extreme_limit: float = 0.05) -> Figure:
def plot_graph_edge_features(
model: nn.Module, trainable_modules: dict[tuple[str, str], Tensor], q_extreme_limit: float = 0.05,
) -> Figure:
"""Plot trainable graph edge features.
Parameters
----------
model: AneomiModelEncProcDec
Model object
trainable_modules: dict[tuple[str, str], torch.Tensor]
Edge trainable tensors.
q_extreme_limit : float, optional
Plot top & bottom quantile of edges trainable values, by default 0.05 (5%).
Expand All @@ -935,29 +938,16 @@ def plot_graph_edge_features(model: nn.Module, q_extreme_limit: float = 0.05) ->
Figure
Figure object handle
"""
trainable_modules = {
(model._graph_name_data, model._graph_name_hidden): model.encoder,
(model._graph_name_hidden, model._graph_name_data): model.decoder,
}

if isinstance(model.processor, GraphEdgeMixin):
trainable_modules[model._graph_name_hidden, model._graph_name_hidden] = model.processor

trainable_tensors = {name: module.trainable.trainable for name, module in trainable_modules.items()}
ncols = min(0 if tt is None else tt.shape[1] for tt in trainable_tensors.values())
if ncols == 0:
LOGGER.warning("There are no trainable edge attributes to plot.")
return None

nrows = len(trainable_modules)
ncols = max(tt.trainable.trainable.shape[1] for tt in trainable_modules.values())
figsize = (ncols * 4, nrows * 3)
fig, ax = plt.subplots(nrows, ncols, figsize=figsize, layout=LAYOUT)

for row, ((src, dst), graph_mapper) in enumerate(trainable_modules.items()):
src_coords = model.node_attributes.get_coordinates(src).cpu().numpy()
dst_coords = model.node_attributes.get_coordinates(dst).cpu().numpy()
edge_index = graph_mapper.edge_index_base.cpu().numpy()
edge_features = trainable_tensors[src, dst].cpu().detach().numpy()
edge_features = graph_mapper.trainable.trainable.cpu().detach().numpy()

for i in range(ncols):
ax_ = ax[row, i] if ncols > 1 else ax[row]
Expand Down

0 comments on commit e2e93f2

Please sign in to comment.