Skip to content
This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

Commit

Permalink
docs: add description to models schema
Browse files Browse the repository at this point in the history
  • Loading branch information
chebertpinard committed Dec 19, 2024
1 parent 937536f commit d7a8d93
Show file tree
Hide file tree
Showing 10 changed files with 282 additions and 169 deletions.
6 changes: 3 additions & 3 deletions src/anemoi/training/schemas/base_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@
from .diagnostics import DiagnosticsSchema # noqa: TC001
from .graphs.base_graph import BaseGraphSchema # noqa: TC001
from .hardware import HardwareSchema # noqa: TC001
from .models.gnn import GNNConfig # noqa: TC001
from .models.graph_transformer import GraphTransformerConfig # noqa: TC001
from .models.transformer import TransformerConfig # noqa: TC001
from .models.models import GNNConfig # noqa: TC001
from .models.models import GraphTransformerConfig # noqa: TC001
from .models.models import TransformerConfig # noqa: TC001
from .training import TrainingSchema # noqa: TC001


Expand Down
56 changes: 0 additions & 56 deletions src/anemoi/training/schemas/models/base_model.py

This file was deleted.

40 changes: 40 additions & 0 deletions src/anemoi/training/schemas/models/common_components.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# (C) Copyright 2024 ECMWF.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
#

from pydantic import BaseModel
from pydantic import Field
from pydantic import NonNegativeInt


class TransformerModelComponent(BaseModel):
activation: str = Field(default="GELU")
"Activation function to use for the transformer model component. Default to GELU."
convert_: str = Field("all", alias="_convert_")
"Target's parameters to convert to primitive containers. Other parameters will use OmegaConf. Default to all."
trainable_size: NonNegativeInt = Field(default=8)
"Size of trainable parameters vector. Default to 8."
num_chunks: NonNegativeInt = Field(default=1)
"Number of chunks to divide the layer into. Default to 1."
mlp_hidden_ratio: NonNegativeInt = Field(default=4)
"Ratio of mlp hidden dimension to embedding dimension. Default to 4."
num_heads: NonNegativeInt = Field(default=16)
"Number of attention heads. Default to 16."


class GNNModelComponent(BaseModel):
activation: str = Field(default="GELU")
"Activation function to use for the GNN model component. Default to GELU."
trainable_size: NonNegativeInt = Field(default=8)
"Size of trainable parameters vector. Default to 8."
num_chunks: NonNegativeInt = Field(default=1)
"Number of chunks to divide the layer into. Default to 1."
sub_graph_edge_attributes: list[str] = Field(default_factory=list)
"Edge attributes to consider in the model component features."
mlp_extra_layers: NonNegativeInt = Field(default=0)
"The number of extra hidden layers in MLP. Default to 0."
27 changes: 27 additions & 0 deletions src/anemoi/training/schemas/models/decoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# (C) Copyright 2024 ECMWF.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
#

from typing import Literal

from pydantic import Field

from .common_components import GNNModelComponent
from .common_components import TransformerModelComponent


class GraphTransformerDecoder(TransformerModelComponent):
target_: Literal["anemoi.models.layers.mapper.GraphTransformerBackwardMapper"] = Field(..., alias="_target_")
"Graph Transformer Decoder object from anemoi.models.layers.mapper."
sub_graph_edge_attributes: list[str] = Field(default=["edge_length", "edge_dirs"])
"Edge attributes to consider in the decoder features. Default to [edge_length, edge_dirs]"


class GNNDecoder(GNNModelComponent):
target_: Literal["anemoi.models.layers.mapper.GNNBackwardMapper"] = Field(..., alias="_target_")
"GNN decoder object from anemoi.models.layers.mapper."
27 changes: 27 additions & 0 deletions src/anemoi/training/schemas/models/encoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# (C) Copyright 2024 ECMWF.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
#

from typing import Literal

from pydantic import Field

from .common_components import GNNModelComponent
from .common_components import TransformerModelComponent


class GNNEncoder(GNNModelComponent):
target_: Literal["anemoi.models.layers.mapper.GNNForwardMapper"] = Field(..., alias="_target_")
"GNN encoder object from anemoi.models.layers.mapper."


class GraphTransformerEncoder(TransformerModelComponent):
target_: Literal["anemoi.models.layers.mapper.GraphTransformerForwardMapper"] = Field(..., alias="_target_")
"Graph Transfromer Encoder object from anemoi.models.layers.mapper."
sub_graph_edge_attributes: list[str] = Field(default=["edge_length", "edge_dirs"])
"Edge attributes to consider in the encoder features."
44 changes: 0 additions & 44 deletions src/anemoi/training/schemas/models/gnn.py

This file was deleted.

33 changes: 0 additions & 33 deletions src/anemoi/training/schemas/models/graph_transformer.py

This file was deleted.

133 changes: 133 additions & 0 deletions src/anemoi/training/schemas/models/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
# (C) Copyright 2024 ECMWF.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
#

from __future__ import annotations

import logging
from enum import Enum
from typing import Any
from typing import Literal

from pydantic import BaseModel
from pydantic import Field
from pydantic import NonNegativeInt
from pydantic import field_validator

from .decoder import GNNDecoder
from .decoder import GraphTransformerDecoder
from .encoder import GNNEncoder
from .encoder import GraphTransformerEncoder
from .processor import GNNProcessor
from .processor import GraphTransformerProcessor
from .processor import TransformerProcessor

LOGGER = logging.getLogger(__name__)


class AllowedModels(str, Enum):
ANEMOI_MODEL_ENC_PROC_DEC = "anemoi.models.models.encoder_processor_decoder.AnemoiModelEncProcDec"


class Model(BaseModel):
target_: AllowedModels = Field(..., alias="_target_")
"Model object defined in anemoi.models.model."
convert_: str = Field("all", alias="_convert_")
"Target's parameters to convert to primitive containers. Other parameters will use OmegaConf. Default to all."


class TrainableParameters(BaseModel):
data: NonNegativeInt = Field(default=8)
"Size of the learnable data node tensor. Default to 8."
hidden: NonNegativeInt = Field(default=8)
"Size of the learnable hidden node tensor. Default to 8."


class ReluBoundingSchema(BaseModel):
target_: Literal["anemoi.models.layers.bounding.ReluBounding"] = Field(..., alias="_target_")
"Relu bounding object defined in anemoi.models.layers.bounding."
variables: list[str]
"List of variables to bound using the Relu method."


class FractionBoundingSchema(BaseModel):
target_: Literal["anemoi.models.layers.bounding.FractionBounding"] = Field(..., alias="_target_")
"Fraction bounding object defined in anemoi.models.layers.bounding."
variables: list[str]
"List of variables to bound using the hard tanh fraction method."
min_val: float
"The minimum value for the HardTanh activation. Correspond to the minimum fraction of the total_var."
max_val: float
"The maximum value for the HardTanh activation. Correspond to the maximum fraction of the total_var."
total_var: str
"Variable from which the secondary variables are derived. \
For example, convective precipitation should be a fraction of total precipitation."


class HardtanhBoundingSchema(BaseModel):
target_: Literal["anemoi.models.layers.bounding.HardtanhBounding"] = Field(..., alias="_target_")
"Hard tanh bounding method function from anemoi.models.layers.bounding."
variables: list[str]
"List of variables to bound using the hard tanh method."
min_val: float
"The minimum value for the HardTanh activation."
max_val: float
"The maximum value for the HardTanh activation."


defined_boundings = [
"anemoi.models.layers.bounding.HardtanhBounding",
"anemoi.models.layers.bounding.FractionBounding",
"anemoi.models.layers.bounding.ReluBounding",
]


class BaseModelConfig(BaseModel):
num_channels: NonNegativeInt = Field(default=512)
"Feature tensor size in the hidden space."
model: Model = Field(default_factory=Model)
"Model schema."
trainable_parameters: TrainableParameters = Field(default_factory=TrainableParameters)
"Learnable node and edge parameters."
bounding: list[ReluBoundingSchema | HardtanhBoundingSchema | FractionBoundingSchema | Any]
"List of bounding configuration applied in order to the specified variables."

@field_validator("bounding")
@classmethod
def validate_bounding_schema_exist(cls, boundings: list) -> list:
for bounding in boundings:
if bounding["_target_"] not in defined_boundings:
LOGGER.warning("%s bounding schema is not defined in anemoi.", bounding["_target_"])
return boundings


class GNNConfig(BaseModelConfig):
processor: GNNProcessor = Field(default_factory=GNNProcessor)
"GNN processor schema."
encoder: GNNEncoder = Field(default_factory=GNNEncoder)
"GNN encoder schema."
decoder: GNNDecoder = Field(default_factory=GNNDecoder)
"GNN decoder schema."


class GraphTransformerConfig(BaseModelConfig):
processor: GraphTransformerProcessor = Field(default_factory=GraphTransformerProcessor)
"Graph transformer processor schema."
encoder: GraphTransformerEncoder = Field(default_factory=GraphTransformerEncoder)
"Graph transformer encoder schema."
decoder: GraphTransformerDecoder = Field(default_factory=GraphTransformerDecoder)
"Graph transformer decoder schema."


class TransformerConfig(BaseModelConfig):
processor: TransformerProcessor = Field(default_factory=TransformerProcessor)
"Transformer processor schema."
encoder: GraphTransformerEncoder = Field(default_factory=GraphTransformerEncoder)
"Graph transformer encoder schema."
decoder: GraphTransformerDecoder = Field(default_factory=GraphTransformerDecoder)
"Graph transformer decoder schema."
Loading

0 comments on commit d7a8d93

Please sign in to comment.