update combined loss
JPXKQX committed Feb 28, 2025
1 parent d860cf8 commit 4f55c4f
Showing 1 changed file with 40 additions and 40 deletions.
80 changes: 40 additions & 40 deletions training/src/anemoi/training/losses/combined.py
@@ -18,7 +18,7 @@

 from anemoi.training.losses.base import BaseLoss
 from anemoi.training.losses.utils import ScaleTensor
-from anemoi.training.train.forecaster import GraphForecaster
+from anemoi.training.losses.loss import get_loss_function

 if TYPE_CHECKING:
     import torch
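
Context for the import swap above: the loss factory now comes from `anemoi.training.losses.loss` instead of through `GraphForecaster`. A minimal sketch of the new call site, assuming an illustrative config dict (the call shape mirrors `get_loss_function(loss, scalers={}, **dict(kwargs))` further down this diff):

```python
# Illustrative sketch only; the config dict is an assumption, not part of the commit.
from anemoi.training.losses.loss import get_loss_function

# Build a single loss from a config mapping, with no scalers attached yet.
mse = get_loss_function({"_target_": "anemoi.training.losses.MSELoss"}, scalers={})
```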
@@ -27,7 +27,7 @@

 class CombinedLoss(BaseLoss):
     """Combined Loss function."""

-    _initial_set_scalar: bool = False
+    _initial_set_scaler: bool = False

     def __init__(
         self,
@@ -44,13 +44,13 @@ def __init__(

         As the losses are designed for use within the context of the
         anemoi-training configuration, `losses` work best as a dictionary.
-        If `losses` is a `tuple[dict]`, the `scalars` key will be extracted
-        before being passed to `get_loss_function`, and the `scalars` defined
-        in each loss only applied to the respective loss. Thereby `scalars`
+        If `losses` is a `tuple[dict]`, the `scalers` key will be extracted
+        before being passed to `get_loss_function`, and the `scalers` defined
+        in each loss only applied to the respective loss. Thereby `scalers`
         added to this class will be routed correctly.
-        If `losses` is a `tuple[Callable]`, all `scalars` added to this class
+        If `losses` is a `tuple[Callable]`, all `scalers` added to this class
         will be added to all underlying losses.
-        And if `losses` is a `tuple[BaseLoss]`, no scalars added to
+        And if `losses` is a `tuple[BaseLoss]`, no scalers added to
         this class will be added to the underlying losses, as it is
         assumed that will be done by the parent function.
@@ -60,12 +60,12 @@ def __init__(

             if a `tuple[dict]`:
                 Tuple of losses to initialise with `get_loss_function`.
                 Allows for kwargs to be passed, and weightings controlled.
-                If a loss should only have some of the scalars, set `scalars` in the loss config.
-                If no scalars are set, all scalars added to this class will be included.
+                If a loss should only have some of the scalers, set `scalers` in the loss config.
+                If no scalers are set, all scalers added to this class will be included.
             if a `tuple[Callable]`:
-                Will be called with `kwargs`, and all scalars added to this class added.
+                Will be called with `kwargs`, and all scalers added to this class added.
             if a `tuple[BaseLoss]`:
-                Added to the loss function, and no scalars passed through.
+                Added to the loss function, and no scalers passed through.
         *extra_losses: dict[str, Any] | Callable | BaseLoss],
             Additional arg form of losses to include in the combined loss.
         loss_weights : optional, tuple[int, ...] | None
@@ -82,8 +82,8 @@ def __init__(
             {"__target__": "anemoi.training.losses.MSELoss"},
             loss_weights=(1.0,),
         )
-        CombinedLoss.add_scalar(name = 'scalar_1', ...)
-        # Only added to the `MSELoss` if specified in it's `scalars`.
+        CombinedLoss.add_scaler(name = 'scaler_1', ...)
+        # Only added to the `MSELoss` if specified in it's `scalers`.
         --------
         >>> CombinedLoss(
             losses = [anemoi.training.losses.MSELoss],
@@ -97,9 +97,9 @@ def __init__(
         losses:
             - _target_: anemoi.training.losses.MSELoss
             - _target_: anemoi.training.losses.MAELoss
-              scalars: ['*']
+              scalers: ['*']
         loss_weights: [1.0, 0.6]
-        # All scalars passed to this class will be added to each underlying loss
+        # All scalers passed to this class will be added to each underlying loss
         ```
         ```
@@ -117,7 +117,7 @@ def __init__(
         super().__init__(node_weights=None)

         self.losses: list[BaseLoss] = []
-        self._loss_scalar_specification: dict[int, list[str]] = {}
+        self._loss_scaler_specification: dict[int, list[str]] = {}

         losses = (*(losses or []), *extra_losses)
         if loss_weights is None:
@@ -129,13 +129,13 @@
         for i, loss in enumerate(losses):

             if isinstance(loss, (DictConfig, dict)):
-                self._loss_scalar_specification[i] = loss.pop("scalars", ["*"])
-                self.losses.append(GraphForecaster.get_loss_function(loss, scalars={}, **dict(kwargs)))
+                self._loss_scaler_specification[i] = loss.pop("scalers", ["*"])
+                self.losses.append(get_loss_function(loss, scalers={}, **dict(kwargs)))
             elif isinstance(loss, Callable):
-                self._loss_scalar_specification[i] = ["*"]
+                self._loss_scaler_specification[i] = ["*"]
                 self.losses.append(loss(**kwargs))
             else:
-                self._loss_scalar_specification[i] = []
+                self._loss_scaler_specification[i] = []
                 self.losses.append(loss)

         self.add_module(self.losses[-1].name + str(i), self.losses[-1])
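
The branches above dispatch on how each entry of `losses` is supplied. A hedged sketch of the three accepted forms, following the docstring examples in this diff; the exact constructor arguments (such as `node_weights`, taken from this file's own `super().__init__` call) are assumptions:

```python
# Sketch under stated assumptions; class paths follow the docstring examples above.
from anemoi.training.losses import MSELoss
from anemoi.training.losses.combined import CombinedLoss

# 1. dict / DictConfig: built via get_loss_function; a per-loss "scalers" list
#    restricts which scalers added to CombinedLoss are routed to this loss.
combined = CombinedLoss(
    losses=({"_target_": "anemoi.training.losses.MSELoss", "scalers": ["my_scaler"]},),
    loss_weights=(1.0,),
)

# 2. Callable: instantiated with **kwargs; its spec defaults to ["*"], so it
#    receives every scaler later added to CombinedLoss.
combined = CombinedLoss(losses=(MSELoss,), loss_weights=(1.0,))

# 3. BaseLoss instance: stored as-is with an empty spec; no scalers are routed.
combined = CombinedLoss(losses=(MSELoss(node_weights=None),), loss_weights=(1.0,))
```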
@@ -177,30 +177,30 @@ def name(self) -> str:
         return "combined_" + "_".join(getattr(loss, "name", loss.__class__.__name__.lower()) for loss in self.losses)

     @property
-    def scalar(self) -> ScaleTensor:
-        """Get union of underlying scalars."""
-        scalars = {}
+    def scaler(self) -> ScaleTensor:
+        """Get union of underlying scalers."""
+        scalers = {}
         for loss in self.losses:
-            scalars.update(loss.scalar.tensors)
-        return ScaleTensor(scalars)
+            scalers.update(loss.scaler.tensors)
+        return ScaleTensor(scalers)

-    @scalar.setter
-    def scalar(self, _: Any) -> None:
-        """Set underlying loss scalars."""
-        if not self._initial_set_scalar:  # Allow parent class to 'initialise' the scalar
-            self._initial_set_scalar = True
+    @scaler.setter
+    def scaler(self, _: Any) -> None:
+        """Set underlying loss scalers."""
+        if not self._initial_set_scaler:  # Allow parent class to 'initialise' the scaler
+            self._initial_set_scaler = True
             return
-        excep_msg = "Cannot set `CombinedLoss` scalar directly, use `add_scalar` or `update_scalar`."
+        excep_msg = "Cannot set `CombinedLoss` scaler directly, use `add_scaler` or `update_scaler`."
         raise AttributeError(excep_msg)

-    @functools.wraps(ScaleTensor.add_scalar, assigned=("__doc__", "__annotations__"))
-    def add_scalar(self, dimension: int | tuple[int], scalar: torch.Tensor, *, name: str | None = None) -> None:
-        for i, spec in self._loss_scalar_specification.items():
+    @functools.wraps(ScaleTensor.add_scaler, assigned=("__doc__", "__annotations__"))
+    def add_scaler(self, dimension: int | tuple[int], scaler: torch.Tensor, *, name: str | None = None) -> None:
+        for i, spec in self._loss_scaler_specification.items():
             if "*" in spec or name in spec:
-                self.losses[i].scalar.add_scalar(dimension, scalar, name=name)
+                self.losses[i].scaler.add_scaler(dimension, scaler, name=name)

-    @functools.wraps(ScaleTensor.update_scalar, assigned=("__doc__", "__annotations__"))
-    def update_scalar(self, name: str, scalar: torch.Tensor, *, override: bool = False) -> None:
-        for i, spec in self._loss_scalar_specification.items():
+    @functools.wraps(ScaleTensor.update_scaler, assigned=("__doc__", "__annotations__"))
+    def update_scaler(self, name: str, scaler: torch.Tensor, *, override: bool = False) -> None:
+        for i, spec in self._loss_scaler_specification.items():
             if "*" in spec or name in spec:
-                self.losses[i].scalar.update_scalar(name, scalar=scalar, override=override)
+                self.losses[i].scaler.update_scaler(name, scaler=scaler, override=override)
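
Taken together with the `_loss_scaler_specification` built in `__init__`, `add_scaler` and `update_scaler` fan a scaler out only to losses whose spec names it or carries the `"*"` wildcard. A short sketch; scaler names, shapes, and dimensions here are purely illustrative:

```python
# Sketch only: names, shapes, and dimensions are assumptions for illustration.
import torch

from anemoi.training.losses.combined import CombinedLoss

combined = CombinedLoss(
    losses=(
        {"_target_": "anemoi.training.losses.MSELoss", "scalers": ["variable"]},
        {"_target_": "anemoi.training.losses.MAELoss", "scalers": ["*"]},
    ),
    loss_weights=(1.0, 0.6),
)

# Reaches both losses: MSELoss names "variable", MAELoss matches via "*".
combined.add_scaler(dimension=-1, scaler=torch.ones(8), name="variable")

# Reaches MAELoss only: "node" is absent from MSELoss's scaler spec.
combined.add_scaler(dimension=-2, scaler=torch.rand(8), name="node")

# Replace an existing scaler in every loss whose spec matches it.
combined.update_scaler(name="variable", scaler=torch.full((8,), 0.5), override=True)
```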
