From 9abad18c87b88230f8a4c6bce41cce48dfa91e01 Mon Sep 17 00:00:00 2001
From: Cathal OBrien
Date: Mon, 27 Jan 2025 14:11:58 +0000
Subject: [PATCH 1/5] multibackend alltoall wrapper

---
 .../anemoi/models/distributed/transformer.py | 29 +++++++++++++++++--
 1 file changed, 27 insertions(+), 2 deletions(-)

diff --git a/models/src/anemoi/models/distributed/transformer.py b/models/src/anemoi/models/distributed/transformer.py
index 78691bba..52c3e758 100644
--- a/models/src/anemoi/models/distributed/transformer.py
+++ b/models/src/anemoi/models/distributed/transformer.py
@@ -18,6 +18,31 @@
 from anemoi.models.distributed.utils import get_memory_format


+def _alltoallwrapper(output_list: list, input_list: list, group: ProcessGroup):
+    """
+    Wrapper function for all_to_all across the NCCL, MPI and Gloo backends.
+    There is no all_to_all primitive for the Gloo backend. In that case each
+    process broadcasts its tensor asynchronously.
+
+    Returns nothing but modifies output_list in-place.
+
+    """
+    comm_size = dist.get_world_size(group=group)
+
+    if dist.is_mpi_available() or dist.is_nccl_available():
+        dist.all_to_all(output_list, input_list, group=group)
+    else:
+        reqs = []
+        for src in range(0, comm_size):
+            if src == dist.get_rank(group=group):
+                output_list[src] = input_list[0]
+            req = dist.broadcast(output_list[src], src, group=group, async_op=True)
+            reqs.append(req)
+
+        for req in reqs:
+            req.wait()
+
+
 def _headsalltoall(input_: Tensor, shapes: list, group: Optional[ProcessGroup] = None) -> Tensor:
     """Apply all_to_all along the head dimension.

@@ -50,7 +75,7 @@ def _headsalltoall(input_: Tensor, shapes: list, group: Optional[ProcessGroup] =
         for rank in range(comm_size)
     ]

-    dist.all_to_all(output_list, input_list, group=group)
+    _alltoallwrapper(output_list, input_list, group=group)

     # Note: torch.cat already creates a contiguous tensor.
     return torch.cat(output_list, dim=-2).contiguous(memory_format=input_format)
@@ -76,7 +101,7 @@ def _seqalltoall(input_: Tensor, shapes: list, group: Optional[ProcessGroup] = N

     output_list = [torch.empty_like(input_list[comm_rank]) for _ in range(comm_size)]

-    dist.all_to_all(output_list, input_list, group=group)
+    _alltoallwrapper(output_list, input_list, group=group)

     # Note: torch.cat already creates a contiguous tensor.
     return torch.cat(output_list, dim=-3).contiguous(memory_format=input_format)

From 5096bbcba2cd5e5eb3447020e7f4862171878156 Mon Sep 17 00:00:00 2001
From: Cathal OBrien
Date: Tue, 28 Jan 2025 12:03:42 +0000
Subject: [PATCH 2/5] change how backend is selected

---
 models/src/anemoi/models/distributed/transformer.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/models/src/anemoi/models/distributed/transformer.py b/models/src/anemoi/models/distributed/transformer.py
index 52c3e758..7ff0362d 100644
--- a/models/src/anemoi/models/distributed/transformer.py
+++ b/models/src/anemoi/models/distributed/transformer.py
@@ -29,9 +29,7 @@ def _alltoallwrapper(output_list: list, input_list: list, group: ProcessGroup):
     """
     comm_size = dist.get_world_size(group=group)

-    if dist.is_mpi_available() or dist.is_nccl_available():
-        dist.all_to_all(output_list, input_list, group=group)
-    else:
+    if dist.get_backend(group) == "gloo":
         reqs = []
         for src in range(0, comm_size):
             if src == dist.get_rank(group=group):
@@ -41,6 +39,8 @@

         for req in reqs:
             req.wait()
+    else:
+        dist.all_to_all(output_list, input_list, group=group)


 def _headsalltoall(input_: Tensor, shapes: list, group: Optional[ProcessGroup] = None) -> Tensor:

From 30ea60e765571f35dbdb25f33c1c4c79db139def Mon Sep 17 00:00:00 2001
From: Cathal OBrien
Date: Tue, 28 Jan 2025 12:10:54 +0000
Subject: [PATCH 3/5] update changelog

---
 models/CHANGELOG.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/models/CHANGELOG.md b/models/CHANGELOG.md
index 7fb9bba8..b90a2ba3 100644
--- a/models/CHANGELOG.md
+++ b/models/CHANGELOG.md
@@ -20,6 +20,7 @@ Keep it human-readable, your future self will thank you!
 - Add remappers, e.g. link functions to apply during training to facilitate learning of variables with a difficult distribution [#88](https://github.com/ecmwf/anemoi-models/pull/88)
 - Add Normalized Relu Bounding for minimum bounding thresholds different than 0 [#64](https://github.com/ecmwf/anemoi-core/pull/64)
 - 'predict\_step' can take an optional model comm group. [#77](https://github.com/ecmwf/anemoi-core/pull/77)
+- Added model parallel support for the Transformer model when running on multiple CPU devices. [#95](https://github.com/ecmwf/anemoi-core/pull/95)

 ## [0.4.0](https://github.com/ecmwf/anemoi-models/compare/0.3.0...0.4.0) - Improvements to Model Design


From abc5bcf712c5dd6da8d3ff0bac75221e8bbef7de Mon Sep 17 00:00:00 2001
From: cathalobrien
Date: Tue, 4 Feb 2025 15:20:51 +0100
Subject: [PATCH 4/5] fixed alltoall alg

---
 .../anemoi/models/distributed/transformer.py | 25 ++++++++++++++-----
 1 file changed, 19 insertions(+), 6 deletions(-)

diff --git a/models/src/anemoi/models/distributed/transformer.py b/models/src/anemoi/models/distributed/transformer.py
index 7ff0362d..e13869ed 100644
--- a/models/src/anemoi/models/distributed/transformer.py
+++ b/models/src/anemoi/models/distributed/transformer.py
@@ -30,13 +30,26 @@ def _alltoallwrapper(output_list: list, input_list: list, group: ProcessGroup):
     comm_size = dist.get_world_size(group=group)

     if dist.get_backend(group) == "gloo":
-        reqs = []
-        for src in range(0, comm_size):
-            if src == dist.get_rank(group=group):
-                output_list[src] = input_list[0]
-            req = dist.broadcast(output_list[src], src, group=group, async_op=True)
-            reqs.append(req)
+        # Need to check the torch version here because the syntax for dist.send/recv changed in torch v2.6
+        torch_version = torch.__version__.split(".")
+        torch_major_version = int(torch_version[0])
+        torch_minor_version = int(torch_version[1])
+        print(f"{torch_version=}")
+        if (torch_major_version, torch_minor_version) < (2, 6):
+            raise NotImplementedError("Gloo all_to_all not implemented for torch < v2.6")
+
+        reqs = []
+        rank = dist.get_rank(group=group)
+        # Here we implement the linear shift algorithm from Hofmann and Ruenger, 2013
+        for i in range(0, comm_size):
+            j = (i - rank + comm_size) % comm_size
+            if j != rank:
+                # exchange data with rank j
+                reqs.append(dist.isend(input_list[j], group_dst=j, group=group))
+                reqs.append(dist.irecv(output_list[j], group_src=j, group=group))
+            else:
+                output_list[rank] = input_list[rank]

         for req in reqs:
             req.wait()
     else:

From 55de6a599a237ad0ee683c6958a50d0b8fa6a7ae Mon Sep 17 00:00:00 2001
From: cathalobrien
Date: Tue, 4 Feb 2025 15:21:47 +0100
Subject: [PATCH 5/5] forgot to remove print

---
 models/src/anemoi/models/distributed/transformer.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/models/src/anemoi/models/distributed/transformer.py b/models/src/anemoi/models/distributed/transformer.py
index e13869ed..bbdd6b98 100644
--- a/models/src/anemoi/models/distributed/transformer.py
+++ b/models/src/anemoi/models/distributed/transformer.py
@@ -35,7 +35,6 @@ def _alltoallwrapper(output_list: list, input_list: list, group: ProcessGroup):
         torch_version = torch.__version__.split(".")
         torch_major_version = int(torch_version[0])
         torch_minor_version = int(torch_version[1])
-        print(f"{torch_version=}")
         if (torch_major_version, torch_minor_version) < (2, 6):
            raise NotImplementedError("Gloo all_to_all not implemented for torch < v2.6")
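
The Gloo fallback introduced in PATCH 4/5 relies on a linear-shift pairing: in round i, rank r exchanges with rank (i - r) mod p, so every pair of ranks meets exactly once and all point-to-point requests can be posted before any of them is waited on. The script below is an illustrative, standalone sanity check of that pattern and is not part of the patch series; the script name, the use of the default (world) process group, and the plain dst/src keywords (rather than the group_dst/group_src keywords that torch 2.6 adds and the patch relies on) are assumptions made for the example.

# gloo_alltoall_check.py (hypothetical filename, illustrative sketch only)
import torch
import torch.distributed as dist


def linear_shift_all_to_all(output_list: list, input_list: list) -> None:
    """Exchange input_list[j] with rank j; what rank j sends lands in output_list[j]."""
    rank = dist.get_rank()
    comm_size = dist.get_world_size()
    reqs = []
    for i in range(comm_size):
        j = (i - rank + comm_size) % comm_size  # partner in round i; the pairing is symmetric
        if j == rank:
            output_list[rank] = input_list[rank]  # the local chunk is simply kept
        else:
            reqs.append(dist.isend(input_list[j], dst=j))
            reqs.append(dist.irecv(output_list[j], src=j))
    for req in reqs:
        req.wait()


if __name__ == "__main__":
    # Launch with e.g.:  torchrun --nproc_per_node=4 gloo_alltoall_check.py
    dist.init_process_group(backend="gloo")
    rank = dist.get_rank()
    comm_size = dist.get_world_size()

    # Rank r fills every outgoing chunk with the value r, so after the exchange
    # output_list[j] must hold a tensor full of j on every rank.
    input_list = [torch.full((2, 2), float(rank)) for _ in range(comm_size)]
    output_list = [torch.empty(2, 2) for _ in range(comm_size)]

    linear_shift_all_to_all(output_list, input_list)

    assert all(torch.equal(out, torch.full((2, 2), float(j))) for j, out in enumerate(output_list))
    if rank == 0:
        print("gloo linear-shift all_to_all check passed")
    dist.destroy_process_group()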