diff --git a/training/src/anemoi/training/config/model/gnn.yaml b/training/src/anemoi/training/config/model/gnn.yaml
index 005a9bc4..45cc3564 100644
--- a/training/src/anemoi/training/config/model/gnn.yaml
+++ b/training/src/anemoi/training/config/model/gnn.yaml
@@ -1,7 +1,7 @@
 activation: GELU
 num_channels: 512
 cpu_offload: False
-output_mask:
+output_mask:
   _target_: anemoi.training.utils.masks.NoOutputMask

 model:
diff --git a/training/src/anemoi/training/config/model/graphtransformer.yaml b/training/src/anemoi/training/config/model/graphtransformer.yaml
index 3b378144..9733cbd7 100644
--- a/training/src/anemoi/training/config/model/graphtransformer.yaml
+++ b/training/src/anemoi/training/config/model/graphtransformer.yaml
@@ -1,7 +1,7 @@
 activation: GELU
 num_channels: 1024
 cpu_offload: False
-output_mask:
+output_mask:
   _target_: anemoi.training.utils.masks.NoOutputMask

 model:
diff --git a/training/src/anemoi/training/config/model/transformer.yaml b/training/src/anemoi/training/config/model/transformer.yaml
index cf32e40e..ad1f989a 100644
--- a/training/src/anemoi/training/config/model/transformer.yaml
+++ b/training/src/anemoi/training/config/model/transformer.yaml
@@ -1,7 +1,7 @@
 activation: GELU
 num_channels: 1024
 cpu_offload: False
-output_mask:
+output_mask:
   _target_: anemoi.training.utils.masks.NoOutputMask

 model:
diff --git a/training/src/anemoi/training/schemas/models/models.py b/training/src/anemoi/training/schemas/models/models.py
index 9122f391..e631618d 100644
--- a/training/src/anemoi/training/schemas/models/models.py
+++ b/training/src/anemoi/training/schemas/models/models.py
@@ -106,6 +106,7 @@ def check_num_normalizers_and_min_val_matches_num_variables(self) -> NormalizedR
     Field(discriminator="target_"),
 ]

+
 class NoOutputMaskSchema(BaseModel):
     target_: Literal["anemoi.training.utils.masks.NoOutputMask"] = Field(..., alias="_target_")

@@ -129,7 +130,7 @@ class ModelSchema(PydanticBaseModel):
     "Learnable node and edge parameters."
     bounding: list[Bounding]
     "List of bounding configuration applied in order to the specified variables."
-    output_mask: OutputMaskSchemas # !TODO CHECK!
+    output_mask: OutputMaskSchemas # !TODO CHECK!
"Output mask" processor: Union[GNNProcessorSchema, GraphTransformerProcessorSchema, TransformerProcessorSchema] = Field( diff --git a/training/src/anemoi/training/train/forecaster.py b/training/src/anemoi/training/train/forecaster.py index 632000a7..b17b9aaf 100644 --- a/training/src/anemoi/training/train/forecaster.py +++ b/training/src/anemoi/training/train/forecaster.py @@ -13,21 +13,19 @@ import pytorch_lightning as pl import torch +from hydra.utils import instantiate from timm.scheduler import CosineLRScheduler from torch.distributed.optim import ZeroRedundancyOptimizer from torch.utils.checkpoint import checkpoint from anemoi.models.interface import AnemoiModelInterface from anemoi.training.losses.base import BaseLoss -from hydra.utils import instantiate from anemoi.training.losses.loss import get_loss_function from anemoi.training.losses.loss import get_metric_ranges from anemoi.training.losses.scalers.scaling import create_scalers from anemoi.training.losses.utils import grad_scaler from anemoi.training.schemas.base_schema import BaseSchema from anemoi.training.schemas.base_schema import convert_to_omegaconf -from anemoi.training.utils.masks import Boolean1DMask -from anemoi.training.utils.masks import NoOutputMask if TYPE_CHECKING: from collections.abc import Generator diff --git a/training/src/anemoi/training/utils/masks.py b/training/src/anemoi/training/utils/masks.py index a244fe95..a908c371 100644 --- a/training/src/anemoi/training/utils/masks.py +++ b/training/src/anemoi/training/utils/masks.py @@ -21,9 +21,9 @@ class BaseMask: """Base class for masking model output.""" + def __init__(self, *_args, **_kwargs) -> None: """Initialize base mask.""" - pass @property def supporting_arrays(self) -> dict: