From ce5e0ac0cb52953ad260007d426e37df45698ecb Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Thu, 20 Feb 2025 12:00:44 +0000 Subject: [PATCH] small refactor --- .../training/losses/scalers/base_scaler.py | 22 ++++++++++++++----- .../losses/scalers/loss_weights_mask.py | 5 ++++- .../losses/scalers/node_attributes.py | 21 +++++++++--------- .../anemoi/training/losses/scalers/scaling.py | 17 ++------------ .../training/losses/scalers/variable.py | 6 ++--- .../training/losses/scalers/variable_level.py | 2 +- .../losses/scalers/variable_tendency.py | 2 +- .../src/anemoi/training/train/forecaster.py | 6 ++--- 8 files changed, 39 insertions(+), 42 deletions(-) diff --git a/training/src/anemoi/training/losses/scalers/base_scaler.py b/training/src/anemoi/training/losses/scalers/base_scaler.py index d1e9ce7b..79fc40c0 100644 --- a/training/src/anemoi/training/losses/scalers/base_scaler.py +++ b/training/src/anemoi/training/losses/scalers/base_scaler.py @@ -20,6 +20,7 @@ from anemoi.models.data_indices.collection import IndexCollection LOGGER = logging.getLogger(__name__) +SCALER_DTYPE = tuple[tuple[int, ...], np.ndarray] class ScaleDimABCMeta(ABCMeta): @@ -74,16 +75,12 @@ def __init_subclass__(cls, **kwargs): error_msg = f"Class {cls.__name__} must define 'scale_dims'" raise TypeError(error_msg) - @property - def is_variable_dim_scaled(self) -> bool: - return -1 in self.scale_dims or 3 in self.scale_dims - @property def is_spatial_dim_scaled(self) -> bool: - return -2 in self.scale_dims or 2 in self.scale_dims + return self.scale_dims is not None and (-2 in self.scale_dims or 2 in self.scale_dims) @abstractmethod - def get_scaling(self, **kwargs) -> np.ndarray: + def get_scaling_values(self, **kwargs) -> np.ndarray: """Abstract method to get loss scaling.""" ... @@ -100,6 +97,11 @@ def normalise(self, values: np.ndarray) -> np.ndarray: error_msg = f"{self.norm} must be one of: None, unit-sum, l1, unit-mean." 
raise ValueError(error_msg) + def get_scaling(self) -> SCALER_DTYPE: + scaler_values = self.get_scaling_values() + scaler_values = self.normalise(scaler_values) + return self.scale_dims, scaler_values + class BaseDelayedScaler(BaseScaler, metaclass=ScaleDimABCMeta): """Base class for delayed Scalers. @@ -108,3 +110,11 @@ class BaseDelayedScaler(BaseScaler, metaclass=ScaleDimABCMeta): computed during the first iteration of the training loop. This delayed scalers are suitable for scalers requiring information from the `model.pre_processors`. """ + + @abstractmethod + def get_delayed_scaling_values(self, **kwargs) -> np.ndarray: ... + + def get_delayed_scaling(self, **kwargs) -> SCALER_DTYPE: + scaler_values = self.get_delayed_scaling_values(**kwargs) + scaler_values = self.normalise(scaler_values) + return self.scale_dims, scaler_values diff --git a/training/src/anemoi/training/losses/scalers/loss_weights_mask.py b/training/src/anemoi/training/losses/scalers/loss_weights_mask.py index 4dd59713..8b22272e 100644 --- a/training/src/anemoi/training/losses/scalers/loss_weights_mask.py +++ b/training/src/anemoi/training/losses/scalers/loss_weights_mask.py @@ -40,7 +40,10 @@ def __init__(self, data_indices: IndexCollection, norm: str | None = None, **kwa super().__init__(data_indices, norm=norm) del kwargs - def get_scaling(self, model: AnemoiModelInterface) -> np.ndarray: + def get_scaling_values(self) -> np.ndarray: + return np.ones(tuple([1] * len(self.scale_dims))) + + def get_delayed_scaling_values(self, model: AnemoiModelInterface) -> np.ndarray: """Get loss scaling. Get mask multiplying NaN locations with zero. 
diff --git a/training/src/anemoi/training/losses/scalers/node_attributes.py b/training/src/anemoi/training/losses/scalers/node_attributes.py index 3b7a04f9..03ef40f2 100644 --- a/training/src/anemoi/training/losses/scalers/node_attributes.py +++ b/training/src/anemoi/training/losses/scalers/node_attributes.py @@ -15,12 +15,14 @@ import torch from anemoi.training.losses.scalers.base_scaler import BaseScaler +from anemoi.training.utils.masks import NoOutputMask if TYPE_CHECKING: import numpy as np from torch_geometric.data import HeteroData from anemoi.models.data_indices.collection import IndexCollection + from anemoi.training.utils.masks import BaseMask LOGGER = logging.getLogger(__name__) @@ -36,7 +38,7 @@ def __init__( graph_data: HeteroData, nodes_name: str, nodes_attribute_name: str | None = None, - apply_output_mask: bool = False, + output_mask: BaseMask | None = None, inverse: bool = False, norm: str | None = None, **kwargs, @@ -51,25 +53,24 @@ def __init__( Name of the nodes in the graph. nodes_attribute_name : str | None, optional Name of the node attribute to use for scaling, by default None - apply_output_mask : bool, optional + output_mask : BaseMask | None, optional Whether to apply output mask to the scaling, by default False norm : str, optional Type of normalization to apply. Options are None, unit-sum, unit-mean and l1. **kwargs : dict Additional keyword arguments. 
""" - self.apply_output_mask = apply_output_mask + self.output_mask = output_mask if output_mask is not None else NoOutputMask() self.nodes = graph_data[nodes_name] self.nodes_attribute_name = nodes_attribute_name self.inverse = inverse super().__init__(data_indices, norm=norm) del kwargs - def get_scaling(self, **_kwargs) -> np.ndarray: - if self.inverse: - return ~self.nodes[self.nodes_attribute_name].squeeze().numpy() - - return self.nodes[self.nodes_attribute_name].squeeze().numpy() + def get_scaling_values(self) -> np.ndarray: + scaler_values = self.nodes[self.nodes_attribute_name].squeeze().numpy() + scaler_values = ~scaler_values if self.inverse else scaler_values + return self.output_mask.apply(scaler_values, dim=0, fill_value=0.0) class ReweightedGraphNodeAttributeScaler(GraphNodeAttributeScaler): @@ -123,6 +124,6 @@ def reweight_attribute_values(self, values: np.ndarray) -> np.ndarray: ) return values - def get_scaling(self, **kwargs) -> np.ndarray: - attribute_values = super().get_scaling(**kwargs) + def get_scaling_values(self, **kwargs) -> np.ndarray: + attribute_values = super().get_scaling_values(**kwargs) return self.reweight_attribute_values(attribute_values) diff --git a/training/src/anemoi/training/losses/scalers/scaling.py b/training/src/anemoi/training/losses/scalers/scaling.py index e72ab4c7..dc15ad09 100644 --- a/training/src/anemoi/training/losses/scalers/scaling.py +++ b/training/src/anemoi/training/losses/scalers/scaling.py @@ -12,7 +12,6 @@ import logging from typing import TYPE_CHECKING -import numpy as np from hydra.utils import instantiate from anemoi.training.losses.scalers.base_scaler import BaseDelayedScaler @@ -21,17 +20,15 @@ import torch from anemoi.models.data_indices.collection import IndexCollection - from anemoi.training.utils.masks import BaseMask + from anemoi.training.losses.scalers.base_scaler import SCALER_DTYPE from anemoi.utils.config import DotDict LOGGER = logging.getLogger(__name__) -SCALER_DTYPE = tuple[tuple[int], 
np.ndarray] def create_scalers( scalers_config: DotDict, data_indices: IndexCollection, - output_mask: BaseMask, **kwargs, ) -> tuple[dict[str, SCALER_DTYPE], dict[str, SCALER_DTYPE]]: scalers, delayed_scaler_builders = {}, {} @@ -40,18 +37,8 @@ def create_scalers( if isinstance(scaler_builder, BaseDelayedScaler): delayed_scaler_builders[name] = scaler_builder - scalers[name] = (scaler_builder.scale_dims, np.ones(tuple([1] * len(scaler_builder.scale_dims)))) - continue - scaler_values = scaler_builder.get_scaling() - - # If a scaler needs to apply the output mask (LAM) after its creation, - # it must include the apply_output_mask attribue. - if scaler_builder.is_spatial_dim_scaled and getattr(scaler_builder, "apply_output_mask", False): - scaler_values = output_mask.apply(scaler_values, dim=0, fill_value=0.0) - - scaler_values = scaler_builder.normalise(scaler_values) - scalers[name] = (scaler_builder.scale_dims, scaler_values) + scalers[name] = scaler_builder.get_scaling() print_final_variable_scaling(scalers, data_indices) diff --git a/training/src/anemoi/training/losses/scalers/variable.py b/training/src/anemoi/training/losses/scalers/variable.py index 04b36af4..7bd3b9ec 100644 --- a/training/src/anemoi/training/losses/scalers/variable.py +++ b/training/src/anemoi/training/losses/scalers/variable.py @@ -79,9 +79,7 @@ def get_variable_group(self, variable_name: str) -> tuple[str, str, int]: Variable level, i.e. pressure level or model level """ - return self.extract_variable_group_and_level.get_group_and_level( - variable_name, - ) + return self.extract_variable_group_and_level.get_group_and_level(variable_name) class GeneralVariableLossScaler(BaseVariableLossScaler): @@ -117,7 +115,7 @@ def __init__( self.weights = weights del kwargs - def get_scaling(self, **_kwargs) -> np.ndarray: + def get_scaling_values(self, **_kwargs) -> np.ndarray: """Get loss scaling. Retrieve the loss scaling for each variable from the config file. 
diff --git a/training/src/anemoi/training/losses/scalers/variable_level.py b/training/src/anemoi/training/losses/scalers/variable_level.py index 2b05fb9d..d0d9f386 100644 --- a/training/src/anemoi/training/losses/scalers/variable_level.py +++ b/training/src/anemoi/training/losses/scalers/variable_level.py @@ -80,7 +80,7 @@ def get_level_scaling(self, variable_level: int) -> float: """ ... - def get_scaling(self, **_kwargs) -> np.ndarray: + def get_scaling_values(self, **_kwargs) -> np.ndarray: variable_level_scaling = np.ones((len(self.data_indices.internal_data.output.full),), dtype=np.float32) LOGGER.info( diff --git a/training/src/anemoi/training/losses/scalers/variable_tendency.py b/training/src/anemoi/training/losses/scalers/variable_tendency.py index 44997923..c9fc85f3 100644 --- a/training/src/anemoi/training/losses/scalers/variable_tendency.py +++ b/training/src/anemoi/training/losses/scalers/variable_tendency.py @@ -64,7 +64,7 @@ def __init__( @abstractmethod def get_level_scaling(self, variable_level: int) -> float: ... 
- def get_scaling(self, **_kwargs) -> np.ndarray: + def get_scaling_values(self, **_kwargs) -> np.ndarray: variable_level_scaling = np.ones((len(self.data_indices.internal_data.output.full),), dtype=np.float32) LOGGER.info("Variable Level Scaling: Applying %s scaling to prognostic variables", self.__class__.__name__) diff --git a/training/src/anemoi/training/train/forecaster.py b/training/src/anemoi/training/train/forecaster.py index 7c68ff03..f503e130 100644 --- a/training/src/anemoi/training/train/forecaster.py +++ b/training/src/anemoi/training/train/forecaster.py @@ -175,10 +175,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: def define_delayed_scalers(self) -> None: """Update delayed scalers such as the loss weights mask for imputed variables.""" for name, scaler_builder in self.delayed_scaler_builders.items(): - scaler_values = scaler_builder.get_scaling(model=self.model) - scaler_values = scaler_builder.normalise(scaler_values) - self.scalers[name] = (scaler_builder.scale_dims, scaler_values) - self.loss.update_scaler(scaler=scaler_values, name=name) + self.scalers[name] = scaler_builder.get_delayed_scaling(model=self.model) + self.loss.update_scaler(scaler=self.scalers[name][1], name=name) def set_model_comm_group( self,