Skip to content

Commit

Permalink
feature(models): Add model comm group to predict_step (#77)
Browse files Browse the repository at this point in the history
* added optional model comm group to predict_step

* update changelog

* added **kwargs to predict step
  • Loading branch information
cathalobrien authored Jan 17, 2025
1 parent b0b69c9 commit db587fe
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 2 deletions.
1 change: 1 addition & 0 deletions models/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ Keep it human-readable, your future self will thank you!
- Reduced memory usage when using chunking in the mapper [#84](https://github.com/ecmwf/anemoi-models/pull/84)
- Added `supporting_arrays` argument, which contains arrays to store in checkpoints. [#97](https://github.com/ecmwf/anemoi-models/pull/97)
- Add remappers, e.g. link functions to apply during training to facilitate learning of variables with a difficult distribution [#88](https://github.com/ecmwf/anemoi-models/pull/88)
- 'predict\_step' can take an optional model comm group. [#77](https://github.com/ecmwf/anemoi-core/pull/77)

## [0.4.0](https://github.com/ecmwf/anemoi-models/compare/0.3.0...0.4.0) - Improvements to Model Design

Expand Down
8 changes: 6 additions & 2 deletions models/src/anemoi/models/interface/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@
# nor does it submit to any jurisdiction.

import uuid
from typing import Optional

import torch
from hydra.utils import instantiate
from torch.distributed.distributed_c10d import ProcessGroup
from torch_geometric.data import HeteroData

from anemoi.models.preprocessing import Processors
Expand Down Expand Up @@ -94,7 +96,9 @@ def _build_model(self) -> None:
# Use the forward method of the model directly
self.forward = self.model.forward

def predict_step(self, batch: torch.Tensor) -> torch.Tensor:
def predict_step(
self, batch: torch.Tensor, model_comm_group: Optional[ProcessGroup] = None, **kwargs
) -> torch.Tensor:
"""Prediction step for the model.
Parameters
Expand All @@ -118,6 +122,6 @@ def predict_step(self, batch: torch.Tensor) -> torch.Tensor:
# batch, timesteps, horizonal space, variables
x = batch[:, 0 : self.multi_step, None, ...] # add dummy ensemble dimension as 3rd index

y_hat = self(x)
y_hat = self(x, model_comm_group)

return self.post_processors(y_hat, in_place=False)

0 comments on commit db587fe

Please sign in to comment.