Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add mean squared log error #158

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions training/docs/modules/losses.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ The following loss functions are available by default:
- ``WeightedHuberLoss``: Latitude-weighted Huber loss.
- ``WeightedLogCoshLoss``: Latitude-weighted log-cosh loss.
- ``WeightedRMSELoss``: Latitude-weighted root-mean-squared-error.
- ``WeightedMSLELoss``: Latitude-weighted mean-squared-log-error.
- ``CombinedLoss``: Combined component weighted loss.

These are available in the ``anemoi.training.losses`` module, at
Expand Down
80 changes: 80 additions & 0 deletions training/src/anemoi/training/losses/msle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# (C) Copyright 2024 Anemoi contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.


from __future__ import annotations

import logging

import torch

from anemoi.training.losses.weightedloss import BaseWeightedLoss

LOGGER = logging.getLogger(__name__)


class WeightedMSLELoss(BaseWeightedLoss):
"""Node-weighted MSE loss."""

name = "wmse"

def __init__(
self,
node_weights: torch.Tensor,
ignore_nans: bool = False,
**kwargs,
) -> None:
"""Node- and feature weighted MSLE Loss.

Parameters
----------
node_weights : torch.Tensor of shape (N, )
Weight of each node in the loss function
ignore_nans : bool, optional
Allow nans in the loss and apply methods ignoring nans for measuring the loss, by default False

"""
super().__init__(
node_weights=node_weights,
ignore_nans=ignore_nans,
**kwargs,
)

def forward(
self,
pred: torch.Tensor,
target: torch.Tensor,
squash: bool = True,
scalar_indices: tuple[int, ...] | None = None,
without_scalars: list[str] | list[int] | None = None,
) -> torch.Tensor:
"""Calculates the lat-weighted MSLE loss.

Parameters
----------
pred : torch.Tensor
Prediction tensor, shape (bs, ensemble, lat*lon, n_outputs)
target : torch.Tensor
Target tensor, shape (bs, ensemble, lat*lon, n_outputs)
squash : bool, optional
Average last dimension, by default True
scalar_indices: tuple[int,...], optional
Indices to subset the calculated scalar with, by default None
without_scalars: list[str] | list[int] | None, optional
list of scalars to exclude from scaling. Can be list of names or dimensions to exclude.
By default None

Returns
-------
torch.Tensor
Weighted MSLE loss
"""
out = torch.square(torch.log(pred + 1) - torch.log(target + 1))
out = self.scale(out, scalar_indices, without_scalars=without_scalars)
return self.scale_by_node_weights(out, squash)
1 change: 1 addition & 0 deletions training/src/anemoi/training/schemas/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ class ImplementedLossesUsingBaseLossSchema(str, Enum):
mae = "anemoi.training.losses.mae.WeightedMAELoss"
logcosh = "anemoi.training.losses.logcosh.WeightedLogCoshLoss"
huber = "anemoi.training.losses.huber.WeightedHuberLoss"
msle = "anemoi.training.losses.msle.WeightedMSLELoss"


class BaseLossSchema(BaseModel):
Expand Down
30 changes: 30 additions & 0 deletions training/tests/train/test_msle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# (C) Copyright 2024 Anemoi contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import pytest
import torch
from torch.testing import assert_close

from anemoi.training.losses.msle import WeightedMSLELoss


@pytest.fixture
def basic_inputs() -> tuple:
pred = torch.tensor([[[[1.0, 2.0], [3.0, 4.0]]]])
target = torch.tensor([[[[1.0, 2.0], [3.0, 4.0]]]])
node_weights = torch.ones(2)
return pred, target, node_weights


def test_forward(basic_inputs: tuple) -> None:
pred, target, node_weights = basic_inputs
loss = WeightedMSLELoss(node_weights=node_weights)
computed_loss = loss(pred, target)
assert isinstance(computed_loss, torch.Tensor)
assert_close(computed_loss, torch.tensor(0.0))
Loading