fix(training): Rework Combined Loss (#103)
Fix CombinedLoss
HCookie authored Feb 20, 2025
1 parent ef1e76e commit b63f1aa
Showing 7 changed files with 300 additions and 57 deletions.
34 changes: 27 additions & 7 deletions training/docs/modules/losses.rst
@@ -140,16 +140,36 @@ losses above.
 .. code:: yaml

     training_loss:
-        __target__: anemoi.training.losses.combined.CombinedLoss
+        _target_: anemoi.training.losses.combined.CombinedLoss
         losses:
-            - __target__: anemoi.training.losses.mse.WeightedMSELoss
-            - __target__: anemoi.training.losses.mae.WeightedMAELoss
-        scalars: ['variable']
+            - _target_: anemoi.training.losses.mse.WeightedMSELoss
+            - _target_: anemoi.training.losses.mae.WeightedMAELoss
         loss_weights: [1.0,0.5]
+        scalars: ['variable']

-All extra kwargs passed to ``CombinedLoss`` are passed to each of the
-loss functions, and the loss weights are used to scale the individual
-losses before combining them.
+If ``scalars`` is not given in the underlying loss functions, all the
+scalars given to the ``CombinedLoss`` are used.
+
+If different scalars are required for each loss, the root level scalars
+of the ``CombinedLoss`` should contain all the scalars required by the
+individual losses. Then the scalars for each loss can be set in the
+individual loss config.
+
+All kwargs passed to ``CombinedLoss`` are passed to each of the loss
+functions, and the loss weights are used to scale the individual losses
+before combining them.
+
+.. code:: yaml
+
+    training_loss:
+        _target_: anemoi.training.losses.combined.CombinedLoss
+        losses:
+            - _target_: anemoi.training.losses.mse.WeightedMSELoss
+              scalars: ['variable']
+            - _target_: anemoi.training.losses.mae.WeightedMAELoss
+              scalars: ['loss_weights_mask']
+        loss_weights: [1.0, 1.0]
+        scalars: ['*']

 .. automodule:: anemoi.training.losses.combined
     :members:
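The routing documented above can also be exercised directly in Python. A minimal sketch, assuming the class paths used in this commit; node counts and tensor shapes are placeholders:

```python
# Hedged sketch of the per-loss scalar routing described in the docs above.
import torch

from anemoi.training.losses.combined import CombinedLoss

combined = CombinedLoss(
    {"_target_": "anemoi.training.losses.mse.WeightedMSELoss", "scalars": ["variable"]},
    {"_target_": "anemoi.training.losses.mae.WeightedMAELoss", "scalars": ["loss_weights_mask"]},
    loss_weights=(1.0, 1.0),
    node_weights=torch.ones(100),  # placeholder node weights
)

# Routed to the MSE term only: 'variable' is absent from the MAE spec.
combined.add_scalar(-1, torch.ones(5), name="variable")
# Routed to the MAE term only.
combined.add_scalar(0, torch.ones(1), name="loss_weights_mask")
```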
3 changes: 1 addition & 2 deletions training/src/anemoi/training/config/training/default.yaml
@@ -41,6 +41,7 @@ zero_optimizer: False
 # dynamic rescaling of the loss gradient
 # see https://arxiv.org/pdf/2306.06079.pdf, section 4.3.2
 # don't enable this by default until it's been tested and proven beneficial
+loss_gradient_scaling: False

 # loss function for the model
 training_loss:
@@ -54,8 +55,6 @@ training_loss:

 ignore_nans: False

-loss_gradient_scaling: False
-
 # Validation metrics calculation,
 # This may be a list, in which case all metrics will be calculated
 # and logged according to their name.
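The net effect is that `loss_gradient_scaling` now lives at the top level of the training config instead of under `training_loss`. A quick sketch with OmegaConf (a pared-down stand-in for the real config; keys abbreviated, values illustrative):

```python
# Sketch only: shows where `loss_gradient_scaling` is read from after this change.
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    """
    loss_gradient_scaling: False
    training_loss:
      _target_: anemoi.training.losses.mse.WeightedMSELoss
    ignore_nans: False
    """
)

assert "loss_gradient_scaling" not in cfg.training_loss
print(cfg.loss_gradient_scaling)  # -> False
```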
150 changes: 113 additions & 37 deletions training/src/anemoi/training/losses/combined.py
@@ -10,55 +10,86 @@
 from __future__ import annotations

 import functools
+from collections.abc import Callable
 from typing import TYPE_CHECKING
 from typing import Any

-import torch
+from omegaconf import DictConfig
+
+from anemoi.training.losses.utils import ScaleTensor
+from anemoi.training.losses.weightedloss import BaseWeightedLoss
 from anemoi.training.train.forecaster import GraphForecaster

 if TYPE_CHECKING:
-    from collections.abc import Callable
+    import torch


-class CombinedLoss(torch.nn.Module):
+class CombinedLoss(BaseWeightedLoss):
     """Combined Loss function."""

+    _initial_set_scalar: bool = False
+
     def __init__(
         self,
-        *extra_losses: dict[str, Any] | Callable,
-        losses: tuple[dict[str, Any] | Callable] | None = None,
-        loss_weights: tuple[int, ...],
+        *extra_losses: dict[str, Any] | Callable | BaseWeightedLoss,
+        loss_weights: tuple[int, ...] | None = None,
+        losses: tuple[dict[str, Any] | Callable | BaseWeightedLoss] | None = None,
         **kwargs,
     ):
"""Combined loss function.
Allows multiple losses to be combined into a single loss function,
and the components weighted.
If a sub loss function requires additional weightings or code created tensors,
that must be `included_` for this function, and then controlled by the underlying
loss function configuration.
As the losses are designed for use within the context of the
anemoi-training configuration, `losses` work best as a dictionary.
If `losses` is a `tuple[dict]`, the `scalars` key will be extracted
before being passed to `get_loss_function`, and the `scalars` defined
in each loss only applied to the respective loss. Thereby `scalars`
added to this class will be routed correctly.
If `losses` is a `tuple[Callable]`, all `scalars` added to this class
will be added to all underlying losses.
And if `losses` is a `tuple[WeightedLoss]`, no scalars added to
this class will be added to the underlying losses, as it is
assumed that will be done by the parent function.
Parameters
----------
losses: tuple[dict[str, Any]| Callable]
Tuple of losses to initialise with `GraphForecaster.get_loss_function`.
Allows for kwargs to be passed, and weighings controlled.
*extra_losses: dict[str, Any] | Callable
losses: tuple[dict[str, Any] | Callable | BaseWeightedLoss],
if a `tuple[dict]`:
Tuple of losses to initialise with `GraphForecaster.get_loss_function`.
Allows for kwargs to be passed, and weighings controlled.
If a loss should only have some of the scalars, set `scalars` in the loss config.
If no scalars are set, all scalars added to this class will be included.
if a `tuple[Callable]`:
Will be called with `kwargs`, and all scalars added to this class added.
if a `tuple[BaseWeightedLoss]`:
Added to the loss function, and no scalars passed through.
*extra_losses: dict[str, Any] | Callable | BaseWeightedLoss],
Additional arg form of losses to include in the combined loss.
loss_weights : tuple[int, ...]
loss_weights : optional, tuple[int, ...] | None
Weights of each loss function in the combined loss.
Must be the same length as the number of losses.
If None, all losses are weighted equally.
by default None.
kwargs: Any
Additional arguments to pass to the loss functions
Additional arguments to pass to the loss functions, if not Loss.
Examples
--------
>>> CombinedLoss(
{"__target__": "anemoi.training.losses.mse.WeightedMSELoss"},
{"_target_": "anemoi.training.losses.mse.WeightedMSELoss"},
loss_weights=(1.0,),
node_weights=node_weights
)
>>> CombinedLoss(
{"_target_": "anemoi.training.losses.mse.WeightedMSELoss", "scalars":['scalar_1']},
loss_weights=(1.0,),
node_weights=node_weights
)
CombinedLoss.add_scalar(name = 'scalar_1', ...)
# Only added to the `WeightedMSELoss` if specified in it's `scalars`.
--------
>>> CombinedLoss(
losses = [anemoi.training.losses.mse.WeightedMSELoss],
@@ -69,25 +100,53 @@ def __init__(
         ```
         training_loss:
-            __target__: anemoi.training.losses.combined.CombinedLoss
+            _target_: anemoi.training.losses.combined.CombinedLoss
             losses:
-                - __target__: anemoi.training.losses.mse.WeightedMSELoss
-                - __target__: anemoi.training.losses.mae.WeightedMAELoss
-            scalars: ['variable']
-            loss_weights: [1.0,0.5]
+                - _target_: anemoi.training.losses.mse.WeightedMSELoss
+                - _target_: anemoi.training.losses.mae.WeightedMAELoss
+            scalars: ['*']
+            loss_weights: [1.0, 0.6]
+            # All scalars passed to this class will be added to each underlying loss
         ```
+        ```
+        training_loss:
+            _target_: anemoi.training.losses.combined.CombinedLoss
+            losses:
+                - _target_: anemoi.training.losses.mse.WeightedMSELoss
+                  scalars: ['variable']
+                - _target_: anemoi.training.losses.mae.WeightedMAELoss
+                  scalars: ['loss_weights_mask']
+            scalars: ['*']
+            loss_weights: [1.0, 1.0]
+            # Only the specified scalars will be added to each loss
+        ```
         """
-        super().__init__()
+        super().__init__(node_weights=None)
+
+        self.losses: list[BaseWeightedLoss] = []
+        self._loss_scalar_specification: dict[int, list[str]] = {}

         losses = (*(losses or []), *extra_losses)
+        if loss_weights is None:
+            loss_weights = (1.0,) * len(losses)

         assert len(losses) == len(loss_weights), "Number of losses and weights must match"
+        assert len(losses) > 0, "At least one loss must be provided"

-        self.losses = [
-            GraphForecaster.get_loss_function(loss, **kwargs) if isinstance(loss, dict) else loss(**kwargs)
-            for loss in losses
-        ]
+        for i, loss in enumerate(losses):
+            if isinstance(loss, (DictConfig, dict)):
+                self._loss_scalar_specification[i] = loss.pop("scalars", ["*"])
+                self.losses.append(GraphForecaster.get_loss_function(loss, scalars={}, **dict(kwargs)))
+            elif isinstance(loss, Callable):
+                self._loss_scalar_specification[i] = ["*"]
+                self.losses.append(loss(**kwargs))
+            else:
+                self._loss_scalar_specification[i] = []
+                self.losses.append(loss)
+
+            self.add_module(self.losses[-1].name + str(i), self.losses[-1])
         self.loss_weights = loss_weights

     def forward(
@@ -125,14 +184,31 @@ def forward(
     def name(self) -> str:
         return "combined_" + "_".join(getattr(loss, "name", loss.__class__.__name__.lower()) for loss in self.losses)

-    def __getattr__(self, name: str) -> Callable:
-        """Allow access to underlying attributes of the loss functions."""
-        if not all(hasattr(loss, name) for loss in self.losses):
-            error_msg = f"Attribute {name} not found in all loss functions"
-            raise AttributeError(error_msg)
-
-        @functools.wraps(getattr(self.losses[0], name))
-        def hidden_func(*args, **kwargs) -> list[Any]:
-            return [getattr(loss, name)(*args, **kwargs) for loss in self.losses]
-
-        return hidden_func
+    @property
+    def scalar(self) -> ScaleTensor:
+        """Get union of underlying scalars."""
+        scalars = {}
+        for loss in self.losses:
+            scalars.update(loss.scalar.tensors)
+        return ScaleTensor(scalars)
+
+    @scalar.setter
+    def scalar(self, _: Any) -> None:
+        """Set underlying loss scalars."""
+        if not self._initial_set_scalar:  # Allow parent class to 'initialise' the scalar
+            self._initial_set_scalar = True
+            return
+        excep_msg = "Cannot set `CombinedLoss` scalar directly, use `add_scalar` or `update_scalar`."
+        raise AttributeError(excep_msg)
+
+    @functools.wraps(ScaleTensor.add_scalar, assigned=("__doc__", "__annotations__"))
+    def add_scalar(self, dimension: int | tuple[int], scalar: torch.Tensor, *, name: str | None = None) -> None:
+        for i, spec in self._loss_scalar_specification.items():
+            if "*" in spec or name in spec:
+                self.losses[i].scalar.add_scalar(dimension, scalar, name=name)
+
+    @functools.wraps(ScaleTensor.update_scalar, assigned=("__doc__", "__annotations__"))
+    def update_scalar(self, name: str, scalar: torch.Tensor, *, override: bool = False) -> None:
+        for i, spec in self._loss_scalar_specification.items():
+            if "*" in spec or name in spec:
+                self.losses[i].scalar.update_scalar(name, scalar=scalar, override=override)
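Putting the new pieces together: `scalar` is now a read-only union over the sub-losses, and additions go through `add_scalar`/`update_scalar`. A minimal sketch of the behaviour (class paths per this commit; shapes and weights are invented):

```python
# Illustrative only: exercises the scalar union and the guarded setter.
import torch

from anemoi.training.losses.combined import CombinedLoss

loss = CombinedLoss(
    {"_target_": "anemoi.training.losses.mse.WeightedMSELoss", "scalars": ["variable"]},
    {"_target_": "anemoi.training.losses.mae.WeightedMAELoss", "scalars": ["*"]},
    node_weights=torch.ones(100),  # loss_weights omitted -> equal weighting
)

# 'variable' matches the MSE spec and the MAE wildcard, so both receive it.
loss.add_scalar(-1, torch.ones(5), name="variable")
print(loss.scalar)  # union ScaleTensor built from the sub-losses

try:
    loss.scalar = None  # direct assignment is rejected after initialisation
except AttributeError as exc:
    print(exc)
```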
6 changes: 6 additions & 0 deletions training/src/anemoi/training/losses/weightedloss.py
@@ -44,6 +44,12 @@ def __init__(
         - self.node_weights: torch.Tensor of shape (N, )
         - self.scalar: ScaleTensor modified with `add_scalar` and `update_scalar`

+        These losses are designed for use within the context of
+        the anemoi-training configuration, where scalars are added
+        after initialisation. If being used outside of this
+        context, call `add_scalar` and `update_scalar` to add or
+        update the scale tensors.
+
         Parameters
         ----------
         node_weights : torch.Tensor of shape (N, )
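Per the added note, losses used outside the anemoi-training config must have their scale tensors attached by hand. A minimal sketch, assuming the class path from this repo's loss enum; the tensor layout and forward signature are assumptions:

```python
# Hedged sketch of standalone use, following the added docstring note.
import torch

from anemoi.training.losses.mse import WeightedMSELoss

loss_fn = WeightedMSELoss(node_weights=torch.ones(100))
# Attach a scalar over the variable dimension by hand.
loss_fn.add_scalar(-1, torch.full((5,), 0.5), name="variable")

# Assumed (batch, time, nodes, variables) layout, purely for illustration.
pred = torch.randn(2, 1, 100, 5)
target = torch.randn(2, 1, 100, 5)
print(loss_fn(pred, target))  # forward(pred, target) signature assumed
```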
32 changes: 25 additions & 7 deletions training/src/anemoi/training/schemas/training.py
@@ -12,6 +12,7 @@
 from enum import Enum
 from functools import partial
 from typing import Annotated
+from typing import Any
 from typing import Literal
 from typing import Union

@@ -20,6 +21,7 @@
 from pydantic import NonNegativeFloat
 from pydantic import NonNegativeInt
 from pydantic import PositiveInt
+from pydantic import field_validator
 from pydantic import model_validator

 from anemoi.training.schemas.utils import allowed_values
@@ -106,14 +108,15 @@ class PressureLevelScalerSchema(BaseModel):
"Slope of the scaling function."


PossibleScalars = Annotated[str, AfterValidator(partial(allowed_values, values=["variable", "loss_weights_mask"]))]
PossibleScalars = Annotated[str, AfterValidator(partial(allowed_values, values=["variable", "loss_weights_mask", "*"]))]


class ImplementedLossesUsingBaseLossSchema(str, Enum):
rmse = "anemoi.training.losses.rmse.WeightedRMSELoss"
mse = "anemoi.training.losses.mse.WeightedMSELoss"
mae = "anemoi.training.losses.mae.WeightedMAELoss"
logcosh = "anemoi.training.losses.logcosh.WeightedLogCoshLoss"
huber = "anemoi.training.losses.huber.WeightedHuberLoss"


class BaseLossSchema(BaseModel):
Expand All @@ -137,18 +140,33 @@ class WeightedMSELossLimitedAreaSchema(BaseLossSchema):
"Whether to compute the contribution to the MSE or not."


class CombinedLossSchema(BaseModel):
class CombinedLossSchema(BaseLossSchema):
target_: Literal["anemoi.training.losses.combined.CombinedLoss"] = Field(..., alias="_target_")
losses: list[BaseLossSchema] = Field(min_length=1)
loss_weights: list[Union[int, float]] = Field(min_length=1)
"Losses to combine, can be any of the normal losses."
loss_weights: Union[list[Union[int, float]], None] = None
"Weightings of losses, if not set, all losses are weighted equally."

@field_validator("losses", mode="before")
@classmethod
def add_empty_scalars(cls, losses: Any) -> Any:
"""Add empty scalars to loss functions, as scalars can be set at top level."""
from omegaconf.omegaconf import open_dict

for loss in losses:
if "scalars" not in loss:
with open_dict(loss):
loss["scalars"] = []
return losses

@model_validator(mode="after")
def check_length_of_weights_and_losses(self, values: dict) -> CombinedLossSchema:
losses, loss_weights = values["losses"], values["loss_weights"]
if len(losses) != len(loss_weights):
def check_length_of_weights_and_losses(self) -> CombinedLossSchema:
"""Check that the number of losses and weights match, or if not set, skip."""
losses, loss_weights = self.losses, self.loss_weights
if loss_weights is not None and len(losses) != len(loss_weights):
error_msg = "Number of losses and weights must match"
raise ValueError(error_msg)
return values
return self


LossSchemas = Union[BaseLossSchema, HuberLossSchema, WeightedMSELossLimitedAreaSchema, CombinedLossSchema]
Expand Down
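The validator pair is easiest to see in isolation. Below is a self-contained stand-in (not the anemoi schema itself) reproducing the same pattern: a before-validator fills in empty `scalars`, and the after-validator skips the length check when `loss_weights` is unset:

```python
# Simplified stand-in for CombinedLossSchema, runnable without anemoi.
from typing import Any, Union

from pydantic import BaseModel, field_validator, model_validator


class MiniCombinedLossSchema(BaseModel):
    losses: list[dict]
    loss_weights: Union[list[float], None] = None

    @field_validator("losses", mode="before")
    @classmethod
    def add_empty_scalars(cls, losses: Any) -> Any:
        # Mirrors the diff: default each loss to no scalars.
        for loss in losses:
            loss.setdefault("scalars", [])
        return losses

    @model_validator(mode="after")
    def check_length_of_weights_and_losses(self) -> "MiniCombinedLossSchema":
        if self.loss_weights is not None and len(self.losses) != len(self.loss_weights):
            raise ValueError("Number of losses and weights must match")
        return self


ok = MiniCombinedLossSchema(losses=[{"_target_": "mse"}, {"_target_": "mae"}])
print(ok.losses[0]["scalars"])  # -> []
print(ok.loss_weights)          # -> None, i.e. equal weighting downstream
```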
11 changes: 7 additions & 4 deletions training/src/anemoi/training/train/forecaster.py
@@ -211,8 +211,8 @@ def get_loss_function(
         ----------
         config : DictConfig
             Loss function configuration, should include `scalars` if scalars are to be added to the loss function.
-        scalars : Union[dict[str, tuple[Union[int, tuple[int, ...], torch.Tensor]]], None], optional
-            Scalars which can be added to the loss function. Defaults to None., by default None
+        scalars : dict[str, tuple[Union[int, tuple[int, ...], torch.Tensor]]] | None, optional
+            Scalars which can be added to the loss function, by default None.
         If a scalar is to be added to the loss, ensure it is in `scalars` in the loss config
         E.g.
             If `scalars: ['variable']` is set in the config, and `variable` in `scalars`
@@ -232,6 +232,8 @@
         ValueError
             If scalar is not found in valid scalars
         """
+        scalars = scalars or {}
+
         if isinstance(config, ListConfig):
             return torch.nn.ModuleList(
                 [
@@ -249,14 +251,15 @@
         scalars_to_include = loss_config.pop("scalars", [])

         # Instantiate the loss function with the loss_init_config
+        kwargs["_recursive_"] = kwargs.get("_recursive_", False)
         loss_function = instantiate(loss_config, **kwargs)

         if not isinstance(loss_function, BaseWeightedLoss):
-            error_msg = f"Loss must be a subclass of 'BaseWeightedLoss', not {type(loss_function)}"
+            error_msg = f"Loss must be a subclass of `BaseWeightedLoss`, not {type(loss_function)}"
             raise TypeError(error_msg)

         for key in scalars_to_include:
-            if key not in scalars or []:
+            if key not in scalars:
                 error_msg = f"Scalar {key!r} not found in valid scalars: {list(scalars.keys())}"
                 raise ValueError(error_msg)
             loss_function.add_scalar(*scalars[key], name=key)
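For reference, a minimal sketch of the updated calling convention; the config and scalar values are illustrative, and `get_loss_function` is invoked as `CombinedLoss.__init__` does above:

```python
# Hedged sketch of the updated contract: `scalars` may now be omitted.
import torch
from omegaconf import OmegaConf

from anemoi.training.train.forecaster import GraphForecaster

config = OmegaConf.create(
    {"_target_": "anemoi.training.losses.mse.WeightedMSELoss", "scalars": ["variable"]}
)

# Each entry is (dimension(s), tensor); it is unpacked into add_scalar(...).
scalars = {"variable": (-1, torch.ones(5))}

loss = GraphForecaster.get_loss_function(config, scalars=scalars, node_weights=torch.ones(100))

# `scalars=None` is now tolerated, provided the config requests no scalars:
plain = GraphForecaster.get_loss_function(
    OmegaConf.create({"_target_": "anemoi.training.losses.mse.WeightedMSELoss"}),
    node_weights=torch.ones(100),
)
```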
