Skip to content

Commit

Permalink
feat: minimal graph config supported
Browse files Browse the repository at this point in the history
  • Loading branch information
theissenhelen committed Feb 10, 2025
1 parent db0217b commit 6a507c4
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions training/src/anemoi/training/schemas/graphs/base_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from pydantic import BaseModel as PydanticBaseModel
from pydantic import Field
from pydantic import model_validator

from anemoi.training.schemas.utils import BaseModel

Expand Down Expand Up @@ -44,9 +45,9 @@ class EdgeSchema(BaseModel):


class BaseGraphSchema(PydanticBaseModel):
nodes: dict[str, NodeSchema]
nodes: dict[str, NodeSchema] | None = Field(default=None)
"Nodes schema for all types of nodes (ex. data, hidden)."
edges: list[EdgeSchema]
edges: list[EdgeSchema] | None = Field(default=None)
"List of edges schema."
overwrite: bool = Field(example=True)
"whether to overwrite existing graph file. Default to True."
Expand All @@ -56,3 +57,10 @@ class BaseGraphSchema(PydanticBaseModel):
hidden: str = Field(example="hidden")
"Key name for the hidden nodes. Default to 'hidden'."
# TODO(Helen): Needs to be adjusted for more complex graph setups

@model_validator(mode="after")
def check_if_nodes_edges_present_if_overwrite(self) -> BaseGraphSchema:
if self.overwrite and ("nodes" not in self.model_fields_set or "edges" not in self.model_fields_set):
msg = "If overwrite is True, nodes and edges must be provided."
raise ValueError(msg)
return self

0 comments on commit 6a507c4

Please sign in to comment.