Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Feb 21, 2025
1 parent 04ed7d8 commit fded990
Show file tree
Hide file tree
Showing 6 changed files with 7 additions and 8 deletions.
2 changes: 1 addition & 1 deletion training/src/anemoi/training/config/model/gnn.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
activation: GELU
num_channels: 512
cpu_offload: False
output_mask:
output_mask:
_target_: anemoi.training.utils.masks.NoOutputMask

model:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
activation: GELU
num_channels: 1024
cpu_offload: False
output_mask:
output_mask:
_target_: anemoi.training.utils.masks.NoOutputMask

model:
Expand Down
2 changes: 1 addition & 1 deletion training/src/anemoi/training/config/model/transformer.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
activation: GELU
num_channels: 1024
cpu_offload: False
output_mask:
output_mask:
_target_: anemoi.training.utils.masks.NoOutputMask

model:
Expand Down
3 changes: 2 additions & 1 deletion training/src/anemoi/training/schemas/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ def check_num_normalizers_and_min_val_matches_num_variables(self) -> NormalizedR
Field(discriminator="target_"),
]


class NoOutputMaskSchema(BaseModel):
target_: Literal["anemoi.training.utils.masks.NoOutputMask"] = Field(..., alias="_target_")

Expand All @@ -129,7 +130,7 @@ class ModelSchema(PydanticBaseModel):
"Learnable node and edge parameters."
bounding: list[Bounding]
"List of bounding configuration applied in order to the specified variables."
output_mask: OutputMaskSchemas # !TODO CHECK!
output_mask: OutputMaskSchemas # !TODO CHECK!
"Output mask"

processor: Union[GNNProcessorSchema, GraphTransformerProcessorSchema, TransformerProcessorSchema] = Field(
Expand Down
4 changes: 1 addition & 3 deletions training/src/anemoi/training/train/forecaster.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,19 @@

import pytorch_lightning as pl
import torch
from hydra.utils import instantiate
from timm.scheduler import CosineLRScheduler
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.utils.checkpoint import checkpoint

from anemoi.models.interface import AnemoiModelInterface
from anemoi.training.losses.base import BaseLoss
from hydra.utils import instantiate
from anemoi.training.losses.loss import get_loss_function
from anemoi.training.losses.loss import get_metric_ranges
from anemoi.training.losses.scalers.scaling import create_scalers
from anemoi.training.losses.utils import grad_scaler
from anemoi.training.schemas.base_schema import BaseSchema
from anemoi.training.schemas.base_schema import convert_to_omegaconf
from anemoi.training.utils.masks import Boolean1DMask
from anemoi.training.utils.masks import NoOutputMask

if TYPE_CHECKING:
from collections.abc import Generator
Expand Down
2 changes: 1 addition & 1 deletion training/src/anemoi/training/utils/masks.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@

class BaseMask:
"""Base class for masking model output."""

def __init__(self, *_args, **_kwargs) -> None:
"""Initialize base mask."""
pass

@property
def supporting_arrays(self) -> dict:
Expand Down

0 comments on commit fded990

Please sign in to comment.