
Commit 44f90df

jjlk committed Feb 28, 2025
1 parent f300160 commit 44f90df
Showing 1 changed file with 16 additions and 2 deletions.
18 changes: 16 additions & 2 deletions models/src/anemoi/models/models/encoder_processor_decoder.py
@@ -22,6 +22,7 @@

from anemoi.models.distributed.shapes import get_shape_shards
from anemoi.models.layers.graph import NamedNodesAttributes
from anemoi.models.layers.utils import load_layer_kernels
from anemoi.utils.config import DotDict

LOGGER = logging.getLogger(__name__)
@@ -35,6 +36,7 @@ def __init__(
*,
model_config: DotDict,
data_indices: dict,
statistics: dict,
graph_data: HeteroData,
) -> None:
"""Initializes the graph neural network.
@@ -49,14 +51,15 @@
Graph definition
"""
super().__init__()

model_config = DotDict(model_config)
self._graph_data = graph_data
self._graph_name_data = model_config.graph.data
self._graph_name_hidden = model_config.graph.hidden

self._calculate_shapes_and_indices(data_indices)
self._assert_matching_indices(data_indices)
self.data_indices = data_indices
self.statistics = statistics

self.multi_step = model_config.training.multistep_input
self.num_channels = model_config.model.num_channels
@@ -65,6 +68,9 @@

input_dim = self.multi_step * self.num_input_channels + self.node_attributes.attr_ndims[self._graph_name_data]

# read config.model.layer_kernels to get the implementation for certain layers
self.layer_kernels = load_layer_kernels(model_config.get("model.layer_kernels", {}))

# Encoder data -> hidden
self.encoder = instantiate(
model_config.model.encoder,
@@ -74,6 +80,7 @@
sub_graph=self._graph_data[(self._graph_name_data, "to", self._graph_name_hidden)],
src_grid_size=self.node_attributes.num_nodes[self._graph_name_data],
dst_grid_size=self.node_attributes.num_nodes[self._graph_name_hidden],
layer_kernels=self.layer_kernels,
)

# Processor hidden -> hidden
@@ -83,6 +90,7 @@
sub_graph=self._graph_data[(self._graph_name_hidden, "to", self._graph_name_hidden)],
src_grid_size=self.node_attributes.num_nodes[self._graph_name_hidden],
dst_grid_size=self.node_attributes.num_nodes[self._graph_name_hidden],
layer_kernels=self.layer_kernels,
)

# Decoder hidden -> data
@@ -95,12 +103,18 @@
sub_graph=self._graph_data[(self._graph_name_hidden, "to", self._graph_name_data)],
src_grid_size=self.node_attributes.num_nodes[self._graph_name_hidden],
dst_grid_size=self.node_attributes.num_nodes[self._graph_name_data],
layer_kernels=self.layer_kernels,
)

# Instantiation of model output bounding functions (e.g., to ensure outputs like TP are positive definite)
self.boundings = nn.ModuleList(
[
instantiate(cfg, name_to_index=self.data_indices.internal_model.output.name_to_index)
instantiate(
cfg,
name_to_index=self.data_indices.internal_model.output.name_to_index,
statistics=self.statistics,
name_to_index_stats=self.data_indices.data.input.name_to_index,
)
for cfg in getattr(model_config.model, "bounding", [])
]
)
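Note on the new layer_kernels hook: the commit reads model.layer_kernels from the model config and threads the result through the encoder, processor and decoder. The sketch below shows one plausible shape for such a loader, assuming Hydra-style _target_ entries; the helper name load_layer_kernels_sketch and the torch.nn fallbacks are illustrative assumptions, not the actual anemoi.models.layers.utils.load_layer_kernels implementation.

# Illustrative sketch only (assumed names and defaults): one plausible way a
# load_layer_kernels-style helper could turn the model.layer_kernels config
# section into layer factories for the encoder/processor/decoder.
from typing import Optional

import torch
from hydra.utils import instantiate


def load_layer_kernels_sketch(kernel_config: Optional[dict] = None) -> dict:
    """Map layer-kind names (e.g. "Linear", "LayerNorm") to callables that build them."""
    kernel_config = kernel_config or {}
    # Assumed fallbacks when a kind is not overridden in the config.
    kernels = {
        "Linear": torch.nn.Linear,
        "LayerNorm": torch.nn.LayerNorm,
    }
    for name, cfg in kernel_config.items():
        # Hydra-style entry, e.g. {"_target_": "torch.nn.LayerNorm", "eps": 1e-3};
        # _partial_=True returns a factory instead of an instance.
        kernels[name] = instantiate(cfg, _partial_=True)
    return kernels


# Example: override LayerNorm's eps via the config, keep the default Linear.
kernels = load_layer_kernels_sketch({"LayerNorm": {"_target_": "torch.nn.LayerNorm", "eps": 1e-3}})
norm = kernels["LayerNorm"](normalized_shape=16)  # builds LayerNorm(16, eps=1e-3)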

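Note on the bounding change: each bounding module is now instantiated with the dataset statistics and a separate name_to_index_stats mapping (indices into the statistics arrays, taken from data_indices.data.input, which need not match the model-output indices). Below is a hedged sketch of a bounding module compatible with that call signature; the class name, the "minimum"/"maximum" statistics keys and the clamping rule are assumptions for illustration, not the anemoi-models bounding implementation.

# Hedged sketch of a bounding module matching the new keyword arguments above.
import torch
from torch import nn


class StatisticsClampBounding(nn.Module):
    """Clamp selected output variables to [min, max] taken from dataset statistics (sketch)."""

    def __init__(
        self,
        *,
        variables: list,
        name_to_index: dict,        # variable name -> index in the model output tensor
        statistics: dict,           # assumed to hold "minimum"/"maximum" arrays
        name_to_index_stats: dict,  # variable name -> index into the statistics arrays
    ) -> None:
        super().__init__()
        stat_idx = [name_to_index_stats[v] for v in variables]
        self.register_buffer("out_idx", torch.tensor([name_to_index[v] for v in variables]))
        self.register_buffer("minimum", torch.tensor([statistics["minimum"][i] for i in stat_idx]))
        self.register_buffer("maximum", torch.tensor([statistics["maximum"][i] for i in stat_idx]))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (..., n_output_variables); clamp only the selected channels.
        x = x.clone()
        x[..., self.out_idx] = torch.clamp(x[..., self.out_idx], self.minimum, self.maximum)
        return x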