diff --git a/et_replay/comm/backend/pytorch_dist_backend.py b/et_replay/comm/backend/pytorch_dist_backend.py
index 5f5d6b97..77049289 100644
--- a/et_replay/comm/backend/pytorch_dist_backend.py
+++ b/et_replay/comm/backend/pytorch_dist_backend.py
@@ -6,10 +6,11 @@
 import json
 import logging
 import os
+from collections import defaultdict
 from itertools import cycle
 from time import sleep
-from typing import List, Optional
+from typing import Dict, List, Optional, Tuple
 
 import numpy as np
 import torch
 
@@ -37,6 +38,7 @@
 has_ext_dist = False
 
 logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
 
 
 def _downcast(input, bitwidth):
@@ -1083,28 +1085,41 @@ def initialize_groups(self, backend="gloo"):
         world_size = self.get_world_size()
         global_rank = self.get_global_rank()
 
+        # Map each sorted group-rank tuple to its pg_ids: groups defined in the current
+        # rank's ET keep the pg_id from the trace; groups seen only on other ranks get -1.
+        group_rank_to_pgId: Dict[Tuple[int, ...], List[int]] = defaultdict(list)
+        for pg_id, group_ranks in self.commsParams.groupRanks.items():
+            if group_ranks is None or len(group_ranks) == 0:
+                group_ranks = list(range(world_size))
+            group_ranks.sort()
+            group_rank_to_pgId[tuple(group_ranks)].append(pg_id)
+
         # sync pgs across ranks to fix hang with multiple comm groups
-        # because new_group() functions requires that all processes in the main group enter,
+        # because new_group() requires that all processes in the default group call it,
         # even if they are not going to be members of the group.
-        # Assumption: pg_name is unique and consistent for all ranks
         sync_store = dist.PrefixStore("pg_sync_r", self.tcp_store)
         sync_store.set(str(global_rank), json.dumps(self.commsParams.groupRanks))
         torch.distributed.barrier()
-        group_ranks_sync = self.commsParams.groupRanks.copy()
         for i in range(self.get_world_size()):
             if i == global_rank:
                 continue
             json_data = sync_store.get(str(i))
             # convert pg_id in json_data to int
-            pg_id2group_ranks = {
+            pg_id_to_group_ranks = {
                 int(pg_id): rank for pg_id, rank in json.loads(json_data).items()
             }
-            group_ranks_sync.update(pg_id2group_ranks)
+            for group_ranks in pg_id_to_group_ranks.values():
+                if group_ranks is None or len(group_ranks) == 0:
+                    group_ranks = list(range(world_size))
+                group_ranks.sort()
+                rank_tuple = tuple(group_ranks)
+                if rank_tuple not in group_rank_to_pgId:
+                    group_rank_to_pgId[rank_tuple] = [-1]
 
-        # create additional groups
-        for pg_id, group_ranks in dict(sorted(group_ranks_sync.items())).items():
+        # create additional groups; sort so pgs are created in the same order on all ranks
+        for group_ranks, pg_ids in sorted(group_rank_to_pgId.items()):
             if (
                 len(group_ranks) > world_size
             ):  # this means that --auto-shrink is enabled, only use default pg
                 continue
             if (
@@ -1115,11 +1130,13 @@ def initialize_groups(self, backend="gloo"):
             ):  # this is the default group, it has already been created
                 pg = self.get_default_group()
             else:
-                pg = self.get_new_pg(group_ranks=group_ranks, backend=backend)
+                pg = self.get_new_pg(group_ranks=list(group_ranks), backend=backend)
                 logger.info(
-                    f"initialized_group: create new group, pg_id = {pg_id}, group_ranks = {group_ranks}"
+                    f"initialized_group: create new group, pg_ids = {pg_ids}, group_ranks = {group_ranks}"
                 )
-            groups[pg_id] = pg
+            for pg_id in pg_ids:
+                if pg_id != -1:
+                    groups[pg_id] = pg
 
         # if additional groups are created, overwrite the default groups list
         if len(groups):
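Note: the core of the pytorch_dist_backend.py change is the dedup bookkeeping above. Every pg_id whose group has the same member set collapses onto one sorted rank tuple, so `new_group()` is invoked once per unique rank set and in identical order on every rank. Below is a minimal standalone sketch of that bookkeeping; the sample `groupRanks` dict is hypothetical.

```python
from collections import defaultdict
from typing import Dict, List, Tuple

world_size = 4
# Hypothetical per-rank group definitions, keyed by pg_id as in the ET trace.
groupRanks: Dict[int, List[int]] = {
    0: [],      # empty list means "all ranks", i.e. the default group
    1: [1, 0],  # same member set as pg 2 once sorted
    2: [0, 1],
    3: [2, 3],
}

group_rank_to_pgId: Dict[Tuple[int, ...], List[int]] = defaultdict(list)
for pg_id, ranks in groupRanks.items():
    if not ranks:
        ranks = list(range(world_size))
    group_rank_to_pgId[tuple(sorted(ranks))].append(pg_id)

# Sorted iteration gives an identical creation order on every rank, which is
# what keeps the collective new_group() calls from deadlocking.
for rank_tuple, pg_ids in sorted(group_rank_to_pgId.items()):
    print(rank_tuple, "->", pg_ids)
# (0, 1) -> [1, 2]         one new_group() call backs both pg_ids
# (0, 1, 2, 3) -> [0]      default group; no new_group() needed
# (2, 3) -> [3]
```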
diff --git a/et_replay/tools/comm_replay.py b/et_replay/tools/comm_replay.py
index 5706b6f9..5eff3e37 100644
--- a/et_replay/tools/comm_replay.py
+++ b/et_replay/tools/comm_replay.py
@@ -1057,6 +1057,7 @@ def replaySingle(
                 and curComm.dst_rank != self.backendFuncs.get_global_rank()
             )
         ):
+            logger.info(f"Skip collective {collName} id = {curComm.id}")
             return
 
         if groupRank >= 0:
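For context on why the store-based sync in pytorch_dist_backend.py is needed at all: `torch.distributed.new_group()` is collective over the default group, so every rank must call it for every group, in the same order, even for groups it will never join. A minimal sketch of that contract, assuming a 4-rank job launched with torchrun:

```python
import torch.distributed as dist

# Rendezvous via the environment variables set by torchrun.
dist.init_process_group(backend="gloo")

# Every rank must execute both calls, in this order, even though each rank
# joins only one of the groups; skipping a call on any rank would hang the
# others inside new_group().
pg_front = dist.new_group(ranks=[0, 1])
pg_back = dist.new_group(ranks=[2, 3])

# Only member ranks get a usable handle to operate on.
if dist.get_rank() in (0, 1):
    dist.barrier(group=pg_front)
else:
    dist.barrier(group=pg_back)

dist.destroy_process_group()
```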