Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Feb 21, 2025
1 parent 0834492 commit 61bc20f
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 6 deletions.
6 changes: 3 additions & 3 deletions training/docs/modules/losses.rst
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,9 @@ By default, only `all` is kept in the normalised space and scaled.

Additionally, you can define your own loss function by subclassing
``BaseLoss`` and implementing the ``forward`` method, or by subclassing
``BaseLoss`` and implementing the ``calculate_difference``
function. The latter abstracts the scaling, and node weighting, and
allows you to just specify the difference calculation.
``BaseLoss`` and implementing the ``calculate_difference`` function. The
latter abstracts the scaling, and node weighting, and allows you to just
specify the difference calculation.

.. code:: python
Expand Down
4 changes: 2 additions & 2 deletions training/src/anemoi/training/losses/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,13 +105,13 @@ def scale(

scaler = scaler.expand_as(x)
return x[subset_indices] * scaler[subset_indices]

def reduce(self, out: torch.Tensor, squash: bool = True) -> torch.Tensor:
if squash:
out = self.avg_function(out, dim=-1)

return self.sum_function(out, dim=(0, 1, 2))

def forward(
self,
pred: torch.Tensor,
Expand Down
2 changes: 1 addition & 1 deletion training/src/anemoi/training/losses/scalers/base_scaler.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def normalise(self, values: np.ndarray) -> np.ndarray:

def get_scaling(self) -> SCALER_DTYPE:
"""Get scaler.
Returns
-------
scale_dims : tuple[int]
Expand Down

0 comments on commit 61bc20f

Please sign in to comment.