This repository has been archived by the owner on Dec 20, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
docs: add description to models schema
- Loading branch information
1 parent
937536f
commit d7a8d93
Showing
10 changed files
with
282 additions
and
169 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
# (C) Copyright 2024 ECMWF. | ||
# | ||
# This software is licensed under the terms of the Apache Licence Version 2.0 | ||
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. | ||
# In applying this licence, ECMWF does not waive the privileges and immunities | ||
# granted to it by virtue of its status as an intergovernmental organisation | ||
# nor does it submit to any jurisdiction. | ||
# | ||
|
||
from pydantic import BaseModel | ||
from pydantic import Field | ||
from pydantic import NonNegativeInt | ||
|
||
|
||
class TransformerModelComponent(BaseModel): | ||
activation: str = Field(default="GELU") | ||
"Activation function to use for the transformer model component. Default to GELU." | ||
convert_: str = Field("all", alias="_convert_") | ||
"Target's parameters to convert to primitive containers. Other parameters will use OmegaConf. Default to all." | ||
trainable_size: NonNegativeInt = Field(default=8) | ||
"Size of trainable parameters vector. Default to 8." | ||
num_chunks: NonNegativeInt = Field(default=1) | ||
"Number of chunks to divide the layer into. Default to 1." | ||
mlp_hidden_ratio: NonNegativeInt = Field(default=4) | ||
"Ratio of mlp hidden dimension to embedding dimension. Default to 4." | ||
num_heads: NonNegativeInt = Field(default=16) | ||
"Number of attention heads. Default to 16." | ||
|
||
|
||
class GNNModelComponent(BaseModel): | ||
activation: str = Field(default="GELU") | ||
"Activation function to use for the GNN model component. Default to GELU." | ||
trainable_size: NonNegativeInt = Field(default=8) | ||
"Size of trainable parameters vector. Default to 8." | ||
num_chunks: NonNegativeInt = Field(default=1) | ||
"Number of chunks to divide the layer into. Default to 1." | ||
sub_graph_edge_attributes: list[str] = Field(default_factory=list) | ||
"Edge attributes to consider in the model component features." | ||
mlp_extra_layers: NonNegativeInt = Field(default=0) | ||
"The number of extra hidden layers in MLP. Default to 0." |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
# (C) Copyright 2024 ECMWF. | ||
# | ||
# This software is licensed under the terms of the Apache Licence Version 2.0 | ||
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. | ||
# In applying this licence, ECMWF does not waive the privileges and immunities | ||
# granted to it by virtue of its status as an intergovernmental organisation | ||
# nor does it submit to any jurisdiction. | ||
# | ||
|
||
from typing import Literal | ||
|
||
from pydantic import Field | ||
|
||
from .common_components import GNNModelComponent | ||
from .common_components import TransformerModelComponent | ||
|
||
|
||
class GraphTransformerDecoder(TransformerModelComponent): | ||
target_: Literal["anemoi.models.layers.mapper.GraphTransformerBackwardMapper"] = Field(..., alias="_target_") | ||
"Graph Transformer Decoder object from anemoi.models.layers.mapper." | ||
sub_graph_edge_attributes: list[str] = Field(default=["edge_length", "edge_dirs"]) | ||
"Edge attributes to consider in the decoder features. Default to [edge_length, edge_dirs]" | ||
|
||
|
||
class GNNDecoder(GNNModelComponent): | ||
target_: Literal["anemoi.models.layers.mapper.GNNBackwardMapper"] = Field(..., alias="_target_") | ||
"GNN decoder object from anemoi.models.layers.mapper." |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
# (C) Copyright 2024 ECMWF. | ||
# | ||
# This software is licensed under the terms of the Apache Licence Version 2.0 | ||
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. | ||
# In applying this licence, ECMWF does not waive the privileges and immunities | ||
# granted to it by virtue of its status as an intergovernmental organisation | ||
# nor does it submit to any jurisdiction. | ||
# | ||
|
||
from typing import Literal | ||
|
||
from pydantic import Field | ||
|
||
from .common_components import GNNModelComponent | ||
from .common_components import TransformerModelComponent | ||
|
||
|
||
class GNNEncoder(GNNModelComponent): | ||
target_: Literal["anemoi.models.layers.mapper.GNNForwardMapper"] = Field(..., alias="_target_") | ||
"GNN encoder object from anemoi.models.layers.mapper." | ||
|
||
|
||
class GraphTransformerEncoder(TransformerModelComponent): | ||
target_: Literal["anemoi.models.layers.mapper.GraphTransformerForwardMapper"] = Field(..., alias="_target_") | ||
"Graph Transfromer Encoder object from anemoi.models.layers.mapper." | ||
sub_graph_edge_attributes: list[str] = Field(default=["edge_length", "edge_dirs"]) | ||
"Edge attributes to consider in the encoder features." |
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
# (C) Copyright 2024 ECMWF. | ||
# | ||
# This software is licensed under the terms of the Apache Licence Version 2.0 | ||
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. | ||
# In applying this licence, ECMWF does not waive the privileges and immunities | ||
# granted to it by virtue of its status as an intergovernmental organisation | ||
# nor does it submit to any jurisdiction. | ||
# | ||
|
||
from __future__ import annotations | ||
|
||
import logging | ||
from enum import Enum | ||
from typing import Any | ||
from typing import Literal | ||
|
||
from pydantic import BaseModel | ||
from pydantic import Field | ||
from pydantic import NonNegativeInt | ||
from pydantic import field_validator | ||
|
||
from .decoder import GNNDecoder | ||
from .decoder import GraphTransformerDecoder | ||
from .encoder import GNNEncoder | ||
from .encoder import GraphTransformerEncoder | ||
from .processor import GNNProcessor | ||
from .processor import GraphTransformerProcessor | ||
from .processor import TransformerProcessor | ||
|
||
LOGGER = logging.getLogger(__name__) | ||
|
||
|
||
class AllowedModels(str, Enum): | ||
ANEMOI_MODEL_ENC_PROC_DEC = "anemoi.models.models.encoder_processor_decoder.AnemoiModelEncProcDec" | ||
|
||
|
||
class Model(BaseModel): | ||
target_: AllowedModels = Field(..., alias="_target_") | ||
"Model object defined in anemoi.models.model." | ||
convert_: str = Field("all", alias="_convert_") | ||
"Target's parameters to convert to primitive containers. Other parameters will use OmegaConf. Default to all." | ||
|
||
|
||
class TrainableParameters(BaseModel): | ||
data: NonNegativeInt = Field(default=8) | ||
"Size of the learnable data node tensor. Default to 8." | ||
hidden: NonNegativeInt = Field(default=8) | ||
"Size of the learnable hidden node tensor. Default to 8." | ||
|
||
|
||
class ReluBoundingSchema(BaseModel): | ||
target_: Literal["anemoi.models.layers.bounding.ReluBounding"] = Field(..., alias="_target_") | ||
"Relu bounding object defined in anemoi.models.layers.bounding." | ||
variables: list[str] | ||
"List of variables to bound using the Relu method." | ||
|
||
|
||
class FractionBoundingSchema(BaseModel): | ||
target_: Literal["anemoi.models.layers.bounding.FractionBounding"] = Field(..., alias="_target_") | ||
"Fraction bounding object defined in anemoi.models.layers.bounding." | ||
variables: list[str] | ||
"List of variables to bound using the hard tanh fraction method." | ||
min_val: float | ||
"The minimum value for the HardTanh activation. Correspond to the minimum fraction of the total_var." | ||
max_val: float | ||
"The maximum value for the HardTanh activation. Correspond to the maximum fraction of the total_var." | ||
total_var: str | ||
"Variable from which the secondary variables are derived. \ | ||
For example, convective precipitation should be a fraction of total precipitation." | ||
|
||
|
||
class HardtanhBoundingSchema(BaseModel): | ||
target_: Literal["anemoi.models.layers.bounding.HardtanhBounding"] = Field(..., alias="_target_") | ||
"Hard tanh bounding method function from anemoi.models.layers.bounding." | ||
variables: list[str] | ||
"List of variables to bound using the hard tanh method." | ||
min_val: float | ||
"The minimum value for the HardTanh activation." | ||
max_val: float | ||
"The maximum value for the HardTanh activation." | ||
|
||
|
||
defined_boundings = [ | ||
"anemoi.models.layers.bounding.HardtanhBounding", | ||
"anemoi.models.layers.bounding.FractionBounding", | ||
"anemoi.models.layers.bounding.ReluBounding", | ||
] | ||
|
||
|
||
class BaseModelConfig(BaseModel): | ||
num_channels: NonNegativeInt = Field(default=512) | ||
"Feature tensor size in the hidden space." | ||
model: Model = Field(default_factory=Model) | ||
"Model schema." | ||
trainable_parameters: TrainableParameters = Field(default_factory=TrainableParameters) | ||
"Learnable node and edge parameters." | ||
bounding: list[ReluBoundingSchema | HardtanhBoundingSchema | FractionBoundingSchema | Any] | ||
"List of bounding configuration applied in order to the specified variables." | ||
|
||
@field_validator("bounding") | ||
@classmethod | ||
def validate_bounding_schema_exist(cls, boundings: list) -> list: | ||
for bounding in boundings: | ||
if bounding["_target_"] not in defined_boundings: | ||
LOGGER.warning("%s bounding schema is not defined in anemoi.", bounding["_target_"]) | ||
return boundings | ||
|
||
|
||
class GNNConfig(BaseModelConfig): | ||
processor: GNNProcessor = Field(default_factory=GNNProcessor) | ||
"GNN processor schema." | ||
encoder: GNNEncoder = Field(default_factory=GNNEncoder) | ||
"GNN encoder schema." | ||
decoder: GNNDecoder = Field(default_factory=GNNDecoder) | ||
"GNN decoder schema." | ||
|
||
|
||
class GraphTransformerConfig(BaseModelConfig): | ||
processor: GraphTransformerProcessor = Field(default_factory=GraphTransformerProcessor) | ||
"Graph transformer processor schema." | ||
encoder: GraphTransformerEncoder = Field(default_factory=GraphTransformerEncoder) | ||
"Graph transformer encoder schema." | ||
decoder: GraphTransformerDecoder = Field(default_factory=GraphTransformerDecoder) | ||
"Graph transformer decoder schema." | ||
|
||
|
||
class TransformerConfig(BaseModelConfig): | ||
processor: TransformerProcessor = Field(default_factory=TransformerProcessor) | ||
"Transformer processor schema." | ||
encoder: GraphTransformerEncoder = Field(default_factory=GraphTransformerEncoder) | ||
"Graph transformer encoder schema." | ||
decoder: GraphTransformerDecoder = Field(default_factory=GraphTransformerDecoder) | ||
"Graph transformer decoder schema." |
Oops, something went wrong.