feat(models): multibackend all_to_all wrapper #95

Open · wants to merge 5 commits into main
1 change: 1 addition & 0 deletions models/CHANGELOG.md
@@ -20,6 +20,7 @@ Keep it human-readable, your future self will thank you!
- Add remappers, e.g. link functions to apply during training to facilitate learning of variables with a difficult distribution [#88](https://github.com/ecmwf/anemoi-models/pull/88)
- Add Normalized Relu Bounding for minimum bounding thresholds different than 0 [#64](https://github.com/ecmwf/anemoi-core/pull/64)
- 'predict\_step' can take an optional model comm group. [#77](https://github.com/ecmwf/anemoi-core/pull/77)
- Added model parallel support for the Transformer model when running on multiple CPU devices. [#95](https://github.com/ecmwf/anemoi-core/pull/95)

## [0.4.0](https://github.com/ecmwf/anemoi-models/compare/0.3.0...0.4.0) - Improvements to Model Design

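For context on the changelog entry above: model parallelism on CPU-only devices relies on torch.distributed's gloo backend, which is why this PR adds a gloo code path for all_to_all. A minimal sketch of what a CPU-only process-group setup might look like (the address, port, rank and world size below are illustrative assumptions, not taken from this PR, and would normally come from a launcher such as torchrun):

```python
import os

import torch.distributed as dist

# Illustrative only: initialise a CPU-only process group with the gloo backend,
# which is the case the new all_to_all wrapper targets.
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")

dist.init_process_group(
    backend="gloo",  # NCCL is unavailable on CPU-only nodes
    rank=int(os.environ.get("RANK", "0")),
    world_size=int(os.environ.get("WORLD_SIZE", "1")),
)

# A model communication group spanning all ranks, as used by the model-parallel code.
model_comm_group = dist.new_group(ranks=list(range(dist.get_world_size())))
```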
41 changes: 39 additions & 2 deletions models/src/anemoi/models/distributed/transformer.py
@@ -18,6 +18,43 @@
from anemoi.models.distributed.utils import get_memory_format


def _alltoallwrapper(output_list: list, input_list: list, group: ProcessGroup):
    """Wrapper for all_to_all across the NCCL, MPI and Gloo backends.

    There is no all_to_all primitive for the Gloo backend. In that case each
    process exchanges tensors with every other rank through asynchronous
    point-to-point sends and receives.

    Returns nothing but modifies output_list in-place.
    """
    comm_size = dist.get_world_size(group=group)

    if dist.get_backend(group) == "gloo":
        # Need to check the torch version here because the dist.send/recv syntax
        # (the group_dst/group_src arguments) changed in torch v2.6
        torch_version = torch.__version__.split(".")
        torch_major_version = int(torch_version[0])
        torch_minor_version = int(torch_version[1])
        if (torch_major_version, torch_minor_version) < (2, 6):
            raise NotImplementedError("Gloo all_to_all not implemented for torch < v2.6")

        reqs = []
        rank = dist.get_rank(group=group)
        # Here we implement the linear shift algorithm from Hofmann and Ruenger, 2013
        for i in range(comm_size):
            j = (i - rank + comm_size) % comm_size
            if j != rank:
                # exchange data with rank j
                reqs.append(dist.isend(input_list[j], group_dst=j, group=group))
                reqs.append(dist.irecv(output_list[j], group_src=j, group=group))
            else:
                output_list[rank] = input_list[rank]
        for req in reqs:
            req.wait()
    else:
        dist.all_to_all(output_list, input_list, group=group)


def _headsalltoall(input_: Tensor, shapes: list, group: Optional[ProcessGroup] = None) -> Tensor:
"""Apply all_to_all along the head dimension.

@@ -50,7 +87,7 @@ def _headsalltoall(input_: Tensor, shapes: list, group: Optional[ProcessGroup] = None) -> Tensor:
        for rank in range(comm_size)
    ]

-    dist.all_to_all(output_list, input_list, group=group)
+    _alltoallwrapper(output_list, input_list, group=group)

    # Note: torch.cat already creates a contiguous tensor.
    return torch.cat(output_list, dim=-2).contiguous(memory_format=input_format)
@@ -76,7 +113,7 @@ def _seqalltoall(input_: Tensor, shapes: list, group: Optional[ProcessGroup] = None) -> Tensor:

    output_list = [torch.empty_like(input_list[comm_rank]) for _ in range(comm_size)]

-    dist.all_to_all(output_list, input_list, group=group)
+    _alltoallwrapper(output_list, input_list, group=group)

    # Note: torch.cat already creates a contiguous tensor.
    return torch.cat(output_list, dim=-3).contiguous(memory_format=input_format)
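As a rough illustration of how the new wrapper is exercised by the two call sites above, the sketch below (not part of the diff; the tensor shapes and the use of the default world group are illustrative assumptions) splits a per-rank tensor into chunks, exchanges them, and reassembles the result, mirroring the pattern in `_headsalltoall` and `_seqalltoall`:

```python
import torch
import torch.distributed as dist

from anemoi.models.distributed.transformer import _alltoallwrapper


def demo_roundtrip() -> torch.Tensor:
    """Split a local tensor into per-rank chunks, exchange them, and reassemble.

    Assumes dist.init_process_group(...) has already been called, e.g. with the
    gloo backend on CPU, which is the case the wrapper's fallback path covers.
    """
    rank = dist.get_rank()
    comm_size = dist.get_world_size()

    # Each rank holds a (comm_size * 4, 8) tensor and sends one 4-row chunk to
    # every rank, analogous to what _headsalltoall does along the head dimension.
    local = torch.randn(comm_size * 4, 8)
    input_list = [chunk.contiguous() for chunk in torch.chunk(local, comm_size, dim=0)]
    output_list = [torch.empty_like(input_list[rank]) for _ in range(comm_size)]

    # dist.all_to_all on NCCL/MPI; asynchronous isend/irecv pairs on Gloo.
    _alltoallwrapper(output_list, input_list, group=dist.group.WORLD)

    return torch.cat(output_list, dim=0)
```

On NCCL or MPI groups this goes through dist.all_to_all; on a gloo group it falls back to the pairwise isend/irecv schedule implemented in the wrapper.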