Skip to content

Commit

Permalink
missing schema for bounding option
Browse files Browse the repository at this point in the history
  • Loading branch information
sahahner committed Feb 18, 2025
1 parent 83d72e1 commit ed99fe1
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions training/src/anemoi/training/schemas/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,14 +84,14 @@ class HardtanhBoundingSchema(BaseModel):
"The maximum value for the HardTanh activation."


class NormalizedReluBounding(BaseModel):
class NormalizedReluBoundingSchema(BaseModel):
target_: Literal["anemoi.models.layers.bounding.NormalizedReluBounding"] = Field(..., alias="_target_")
variables: list[str]
min_val: list[float]
normalizer: list[str]

@model_validator(mode="after")
def check_num_normalizers_and_min_val_matches_num_variables(self) -> NormalizedReluBounding:
def check_num_normalizers_and_min_val_matches_num_variables(self) -> NormalizedReluBoundingSchema:
error_msg = f"""{self.__class__} requires that number of normalizers ({len(self.normalizer)}) or
match the number of variables ({len(self.variables)})"""
assert len(self.normalizer) == len(self.variables), error_msg
Expand All @@ -102,7 +102,7 @@ def check_num_normalizers_and_min_val_matches_num_variables(self) -> NormalizedR


Bounding = Annotated[
Union[ReluBoundingSchema, FractionBoundingSchema, HardtanhBoundingSchema, NormalizedReluBounding],
Union[ReluBoundingSchema, FractionBoundingSchema, HardtanhBoundingSchema, NormalizedReluBoundingSchema],
Field(discriminator="target_"),
]

Expand Down

0 comments on commit ed99fe1

Please sign in to comment.