Commit ce5e0ac: small refactor
JPXKQX committed Feb 20, 2025
1 parent 84b7b31
Showing 8 changed files with 39 additions and 42 deletions.
22 changes: 16 additions & 6 deletions training/src/anemoi/training/losses/scalers/base_scaler.py
@@ -20,6 +20,7 @@
from anemoi.models.data_indices.collection import IndexCollection

LOGGER = logging.getLogger(__name__)
+ SCALER_DTYPE = tuple[tuple[int], np.ndarray]


class ScaleDimABCMeta(ABCMeta):
@@ -74,16 +75,12 @@ def __init_subclass__(cls, **kwargs):
error_msg = f"Class {cls.__name__} must define 'scale_dims'"
raise TypeError(error_msg)

- @property
- def is_variable_dim_scaled(self) -> bool:
- return -1 in self.scale_dims or 3 in self.scale_dims

@property
def is_spatial_dim_scaled(self) -> bool:
- return -2 in self.scale_dims or 2 in self.scale_dims
+ return self.scale_dims is not None and (-2 in self.scale_dims or 2 in self.scale_dims)

@abstractmethod
- def get_scaling(self, **kwargs) -> np.ndarray:
+ def get_scaling_values(self, **kwargs) -> np.ndarray:
"""Abstract method to get loss scaling."""
...

@@ -100,6 +97,11 @@ def normalise(self, values: np.ndarray) -> np.ndarray:
error_msg = f"{self.norm} must be one of: None, unit-sum, l1, unit-mean."
raise ValueError(error_msg)

+ def get_scaling(self) -> SCALER_DTYPE:
+ scaler_values = self.get_scaling_values()
+ scaler_values = self.normalise(scaler_values)
+ return self.scale_dims, scaler_values
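For orientation, not part of the diff: concrete scalers now implement only get_scaling_values(), and the base class normalises the result and pairs it with scale_dims into a SCALER_DTYPE tuple. A minimal sketch, assuming a hypothetical ConstantScaler and an IndexCollection named data_indices:

import numpy as np

class ConstantScaler(BaseScaler):
    scale_dims = (-1,)  # scale along the variable dimension

    def get_scaling_values(self, **kwargs) -> np.ndarray:
        # Hypothetical raw weights for three output variables.
        return np.array([1.0, 1.0, 1.0])

scaler = ConstantScaler(data_indices, norm="unit-sum")
dims, values = scaler.get_scaling()
# dims == (-1,); with unit-sum normalisation the values should sum to 1.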


class BaseDelayedScaler(BaseScaler, metaclass=ScaleDimABCMeta):
"""Base class for delayed Scalers.
@@ -108,3 +110,11 @@ class BaseDelayedScaler(BaseScaler, metaclass=ScaleDimABCMeta):
computed during the first iteration of the training loop. These delayed scalers are suitable
for scalers requiring information from the `model.pre_processors`.
"""

+ @abstractmethod
+ def get_delayed_scaling_values(self, **kwargs) -> np.ndarray: ...
+
+ def get_delayed_scaling(self, **kwargs) -> SCALER_DTYPE:
+ scaler_values = self.get_delayed_scaling_values(**kwargs)
+ scaler_values = self.normalise(scaler_values)
+ return self.scale_dims, scaler_values
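A delayed scaler, by contrast, supplies placeholder values at build time and the real values once the model exists. A rough sketch under the same assumptions; the subclass and the build_mask_from helper are illustrative, not part of anemoi:

class ExampleDelayedScaler(BaseDelayedScaler):
    scale_dims = (2,)  # scale along the spatial dimension

    def get_scaling_values(self, **kwargs) -> np.ndarray:
        # Placeholder of ones until the model's pre-processors are available.
        return np.ones(tuple([1] * len(self.scale_dims)))

    def get_delayed_scaling_values(self, model, **kwargs) -> np.ndarray:
        # Hypothetical helper: derive 0/1 weights from the pre-processors.
        return build_mask_from(model.pre_processors)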
(another changed file; name not shown)
@@ -40,7 +40,10 @@ def __init__(self, data_indices: IndexCollection, norm: str | None = None, **kwargs):
super().__init__(data_indices, norm=norm)
del kwargs

- def get_scaling(self, model: AnemoiModelInterface) -> np.ndarray:
+ def get_scaling_values(self) -> np.ndarray:
+ return np.ones(tuple([1] * len(self.scale_dims)))
+
+ def get_delayed_scaling_values(self, model: AnemoiModelInterface) -> np.ndarray:
"""Get loss scaling.
Get a mask that multiplies NaN locations by zero.
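The rest of the method body is hidden below the fold. As a heavily hedged sketch of the idea only (extract_nan_mask is a hypothetical helper, not an anemoi API):

def get_delayed_scaling_values(self, model: AnemoiModelInterface) -> np.ndarray:
    # Hypothetical: a boolean array marking imputed NaN positions, converted
    # so NaN locations get weight 0 and all other locations weight 1.
    nan_mask = extract_nan_mask(model.pre_processors)
    return (~nan_mask).astype(np.float32)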
21 changes: 11 additions & 10 deletions training/src/anemoi/training/losses/scalers/node_attributes.py
@@ -15,12 +15,14 @@
import torch

from anemoi.training.losses.scalers.base_scaler import BaseScaler
+ from anemoi.training.utils.masks import NoOutputMask

if TYPE_CHECKING:
import numpy as np
from torch_geometric.data import HeteroData

from anemoi.models.data_indices.collection import IndexCollection
+ from anemoi.training.utils.masks import BaseMask

LOGGER = logging.getLogger(__name__)

@@ -36,7 +38,7 @@ def __init__(
graph_data: HeteroData,
nodes_name: str,
nodes_attribute_name: str | None = None,
- apply_output_mask: bool = False,
+ output_mask: type[BaseMask] = None,
inverse: bool = False,
norm: str | None = None,
**kwargs,
@@ -51,25 +53,24 @@
Name of the nodes in the graph.
nodes_attribute_name : str | None, optional
Name of the node attribute to use for scaling, by default None
- apply_output_mask : bool, optional
+ output_mask : type[BaseMask], optional
+ Output mask to apply to the scaling values, by default None (no masking).
norm : str, optional
Type of normalization to apply. Options are None, unit-sum, unit-mean and l1.
**kwargs : dict
Additional keyword arguments.
"""
- self.apply_output_mask = apply_output_mask
+ self.output_mask = output_mask if output_mask is not None else NoOutputMask()
self.nodes = graph_data[nodes_name]
self.nodes_attribute_name = nodes_attribute_name
self.inverse = inverse
super().__init__(data_indices, norm=norm)
del kwargs

- def get_scaling(self, **_kwargs) -> np.ndarray:
- if self.inverse:
- return ~self.nodes[self.nodes_attribute_name].squeeze().numpy()
-
- return self.nodes[self.nodes_attribute_name].squeeze().numpy()
+ def get_scaling_values(self) -> np.ndarray:
+ scaler_values = self.nodes[self.nodes_attribute_name].squeeze().numpy()
+ scaler_values = ~scaler_values if self.inverse else scaler_values
+ return self.output_mask.apply(scaler_values, dim=0, fill_value=0.0)
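As a usage illustration (the node-set and attribute names are hypothetical, not taken from this diff), the scaler is now handed an optional mask object instead of a boolean flag:

scaler = GraphNodeAttributeScaler(
    data_indices,
    graph_data=graph,                    # a torch_geometric HeteroData object
    nodes_name="data",                   # hypothetical node set
    nodes_attribute_name="area_weight",  # hypothetical node attribute
    output_mask=None,                    # falls back to NoOutputMask()
    norm="unit-sum",
)
dims, values = scaler.get_scaling()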


class ReweightedGraphNodeAttributeScaler(GraphNodeAttributeScaler):
@@ -123,6 +124,6 @@ def reweight_attribute_values(self, values: np.ndarray) -> np.ndarray:
)
return values

- def get_scaling(self, **kwargs) -> np.ndarray:
- attribute_values = super().get_scaling(**kwargs)
+ def get_scaling_values(self, **kwargs) -> np.ndarray:
+ attribute_values = super().get_scaling_values(**kwargs)
return self.reweight_attribute_values(attribute_values)
17 changes: 2 additions & 15 deletions training/src/anemoi/training/losses/scalers/scaling.py
@@ -12,7 +12,6 @@
import logging
from typing import TYPE_CHECKING

- import numpy as np
from hydra.utils import instantiate

from anemoi.training.losses.scalers.base_scaler import BaseDelayedScaler
@@ -21,17 +20,15 @@
import torch

from anemoi.models.data_indices.collection import IndexCollection
- from anemoi.training.utils.masks import BaseMask
+ from anemoi.training.losses.scalers.base_scaler import SCALER_DTYPE
from anemoi.utils.config import DotDict

LOGGER = logging.getLogger(__name__)
- SCALER_DTYPE = tuple[tuple[int], np.ndarray]


def create_scalers(
scalers_config: DotDict,
data_indices: IndexCollection,
- output_mask: BaseMask,
**kwargs,
) -> tuple[dict[str, SCALER_DTYPE], dict[str, SCALER_DTYPE]]:
scalers, delayed_scaler_builders = {}, {}
@@ -40,18 +37,8 @@ def create_scalers(

if isinstance(scaler_builder, BaseDelayedScaler):
delayed_scaler_builders[name] = scaler_builder
- scalers[name] = (scaler_builder.scale_dims, np.ones(tuple([1] * len(scaler_builder.scale_dims))))
- continue
-
- scaler_values = scaler_builder.get_scaling()
-
- # If a scaler needs to apply the output mask (LAM) after its creation,
- # it must include the apply_output_mask attribute.
- if scaler_builder.is_spatial_dim_scaled and getattr(scaler_builder, "apply_output_mask", False):
- scaler_values = output_mask.apply(scaler_values, dim=0, fill_value=0.0)
-
- scaler_values = scaler_builder.normalise(scaler_values)
- scalers[name] = (scaler_builder.scale_dims, scaler_values)
+ scalers[name] = scaler_builder.get_scaling()

print_final_variable_scaling(scalers, data_indices)

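To see the simplification from the caller's side, here is a hedged sketch of create_scalers in use; the config keys and weight values are assumptions, while the _target_ path follows from the variable.py diff below:

scalers_config = {
    "general_variable": {
        "_target_": "anemoi.training.losses.scalers.variable.GeneralVariableLossScaler",
        "weights": {"default": 1.0},  # hypothetical weights mapping
    },
}
scalers, delayed_scaler_builders = create_scalers(scalers_config, data_indices=data_indices)
# scalers maps each name to (scale_dims, normalised_values); delayed scalers
# are returned separately and resolved later via their builders.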
6 changes: 2 additions & 4 deletions training/src/anemoi/training/losses/scalers/variable.py
@@ -79,9 +79,7 @@ def get_variable_group(self, variable_name: str) -> tuple[str, str, int]:
Variable level, i.e. pressure level or model level
"""
- return self.extract_variable_group_and_level.get_group_and_level(
- variable_name,
- )
+ return self.extract_variable_group_and_level.get_group_and_level(variable_name)


class GeneralVariableLossScaler(BaseVariableLossScaler):
@@ -117,7 +115,7 @@ def __init__(
self.weights = weights
del kwargs

- def get_scaling(self, **_kwargs) -> np.ndarray:
+ def get_scaling_values(self, **_kwargs) -> np.ndarray:
"""Get loss scaling.
Retrieve the loss scaling for each variable from the config file.
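Per the docstring above, this scaler reads per-variable weights from the config. A hypothetical example of such a mapping (keys and numbers are illustrative only):

weights = {
    "default": 1.0,  # assumed fallback weight for unlisted variables
    "t2m": 2.0,      # upweight 2-metre temperature
    "10u": 0.5,      # downweight the 10-metre wind component
}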
(another changed file; name not shown)
@@ -80,7 +80,7 @@ def get_level_scaling(self, variable_level: int) -> float:
"""
...

- def get_scaling(self, **_kwargs) -> np.ndarray:
+ def get_scaling_values(self, **_kwargs) -> np.ndarray:
variable_level_scaling = np.ones((len(self.data_indices.internal_data.output.full),), dtype=np.float32)

LOGGER.info(
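This file and the similar one below only pick up the get_scaling_values rename; subclasses still provide get_level_scaling per level. A sketch of one such rule, where the base-class name and the linear pressure-level formula are illustrative, not this file's actual code:

class LinearVariableLevelScaler(BaseVariableLevelScaler):
    def get_level_scaling(self, variable_level: int) -> float:
        # Hypothetical: weight a pressure level proportionally, e.g. 850 -> 0.85.
        return variable_level / 1000.0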
(another changed file; name not shown)
@@ -64,7 +64,7 @@ def __init__(
@abstractmethod
def get_level_scaling(self, variable_level: int) -> float: ...

- def get_scaling(self, **_kwargs) -> np.ndarray:
+ def get_scaling_values(self, **_kwargs) -> np.ndarray:
variable_level_scaling = np.ones((len(self.data_indices.internal_data.output.full),), dtype=np.float32)

LOGGER.info("Variable Level Scaling: Applying %s scaling to prognostic variables", self.__class__.__name__)
6 changes: 2 additions & 4 deletions training/src/anemoi/training/train/forecaster.py
@@ -175,10 +175,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
def define_delayed_scalers(self) -> None:
"""Update delayed scalers such as the loss weights mask for imputed variables."""
for name, scaler_builder in self.delayed_scaler_builders.items():
- scaler_values = scaler_builder.get_scaling(model=self.model)
- scaler_values = scaler_builder.normalise(scaler_values)
- self.scalers[name] = (scaler_builder.scale_dims, scaler_values)
- self.loss.update_scaler(scaler=scaler_values, name=name)
+ self.scalers[name] = scaler_builder.get_delayed_scaling(model=self.model)
+ self.loss.update_scaler(scaler=self.scalers[name][1], name=name)
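The diff does not show where define_delayed_scalers is invoked; per the BaseDelayedScaler docstring it runs in the first training iteration. A hedged sketch of one possible call site, using a standard PyTorch Lightning hook (its use here is an assumption):

def on_train_batch_start(self, batch: torch.Tensor, batch_idx: int) -> None:
    if batch_idx == 0:
        # The model's pre-processors now exist, so delayed scalers can be resolved.
        self.define_delayed_scalers()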

def set_model_comm_group(
self,
