Fixed the bug in process group initialization
Summary:
torch.distributed.new_group requires that all processes in the main group (i.e., all processes that are part of the distributed job) enter this function, even if they are not going to be members of the group. Additionally, groups should be created in the same order in all processes.
https://pytorch.org/docs/stable/_modules/torch/distributed/distributed_c10d.html#new_group
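
For reference, a minimal sketch of that contract (not part of this commit; the 4-rank job and the [0, 1] / [2, 3] split are made up for illustration):

    import torch.distributed as dist

    # Assumes dist.init_process_group() has already been called on every rank.
    # Every rank calls new_group() for every subgroup, in the same order,
    # even for subgroups it will not be a member of.
    def create_subgroups():
        pg_a = dist.new_group(ranks=[0, 1])  # ranks 2 and 3 must still make this call
        pg_b = dist.new_group(ranks=[2, 3])  # ranks 0 and 1 must still make this call
        return pg_a, pg_b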

The current implementation assumes that a process group id identifies the same group (the same set of ranks) on every rank. However, that is not the case. For example, in llama4:

Rank 0

{"pg_name": "0", "pg_desc": "default_pg", "backend_config": "cuda:nccl", "ranks": [], "group_size": 16, "group_count": 5},
{"pg_name": "1", "pg_desc": "DP", "backend_config": "cuda:nccl", "ranks": [0, 8], "group_size": 2, "group_count": 5},
{"pg_name": "2", "pg_desc": "MP", "backend_config": "cuda:nccl", "ranks": [0, 1, 2, 3, 4, 5, 6, 7], "group_size": 8, "group_count": 5},
{"pg_name": "3", "pg_desc": "TP", "backend_config": "cuda:nccl", "ranks": [0, 1, 2, 3, 4, 5, 6, 7], "group_size": 8, "group_count": 5},
{"pg_name": "4", "pg_desc": "PP", "backend_config": "cuda:nccl", "ranks": [0], "group_size": 1, "group_count": 5}


Rank 8

{"pg_name": "0", "pg_desc": "default_pg", "backend_config": "cuda:nccl", "ranks": [], "group_size": 16, "group_count": 5},
{"pg_name": "1", "pg_desc": "DP", "backend_config": "cuda:nccl", "ranks": [0, 8], "group_size": 2, "group_count": 5},
{"pg_name": "2", "pg_desc": "MP", "backend_config": "cuda:nccl", "ranks": [8, 9, 10, 11, 12, 13, 14, 15], "group_size": 8, "group_count": 5},
{"pg_name": "3", "pg_desc": "TP", "backend_config": "cuda:nccl", "ranks": [8, 9, 10, 11, 12, 13, 14, 15], "group_size": 8, "group_count": 5},
{"pg_name": "4", "pg_desc": "PP", "backend_config": "cuda:nccl", "ranks": [8], "group_size": 1, "group_count": 5}

You can see that for pg_id = 2 (and likewise 3 and 4), the ranks differ between rank 0 and rank 8 even though the pg_id is the same.

This diff fixes the issue by using the group's rank list as the key: for every unique list of group ranks, one new process group is created. The idea is that if two groups have the same sorted rank list, they are the same process group.

After a process group is created, the pg ids from the current rank's ET file are mapped to it. pg ids that appear only in other ranks' ET files are mapped to -1, since they are not used to run collectives on this rank.
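
A standalone sketch of that keying scheme (illustrative only; the tables are trimmed to two entries from the llama4 example above, and this is not the committed code):

    from collections import defaultdict

    local_table = {1: [0, 8], 2: [0, 1, 2, 3, 4, 5, 6, 7]}         # from this rank's (rank 0) ET file
    remote_table = {1: [0, 8], 2: [8, 9, 10, 11, 12, 13, 14, 15]}  # gathered from rank 8 via the store

    group_rank_to_pg_ids = defaultdict(list)
    for pg_id, ranks in local_table.items():
        group_rank_to_pg_ids[tuple(sorted(ranks))].append(pg_id)
    for ranks in remote_table.values():
        key = tuple(sorted(ranks))
        if key not in group_rank_to_pg_ids:
            group_rank_to_pg_ids[key] = [-1]  # created only so new_group() is entered everywhere

    # Result on rank 0: {(0, 8): [1], (0, 1, ..., 7): [2], (8, 9, ..., 15): [-1]}
    # Three process groups are created in the same order on every rank; only the
    # first two are later used to run collectives on rank 0.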

Differential Revision: D64345603
shengfukevin authored and facebook-github-bot committed Oct 14, 2024
1 parent 1ac7959 commit 3e77b96
Showing 2 changed files with 33 additions and 11 deletions.
43 changes: 32 additions & 11 deletions et_replay/comm/backend/pytorch_dist_backend.py
@@ -6,10 +6,11 @@
 import json
 import logging
 import os
+from collections import defaultdict
 
 from itertools import cycle
 from time import sleep
-from typing import List, Optional
+from typing import Dict, List, Optional, Tuple
 
 import numpy as np
 import torch
@@ -37,6 +38,7 @@
 has_ext_dist = False
 
 logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
 
 
 def _downcast(input, bitwidth):
@@ -1083,28 +1085,45 @@ def initialize_groups(self, backend="gloo"):
         world_size = self.get_world_size()
         global_rank = self.get_global_rank()
 
+        # map from group_rank to pgId; the pgId of the groups in the current rank is the pgId defined in
+        # ET, the pgId of the groups from other ranks is -1.
+        group_rank_to_pgId: Dict[Tuple[int, ...], List[int]] = defaultdict(list)
+        for pg_id, group_ranks in self.commsParams.groupRanks.items():
+            if group_ranks is None or len(group_ranks) == 0:
+                group_ranks = list(range(world_size))
+            group_ranks.sort()
+            rank_tuple = tuple(group_ranks)
+            if rank_tuple in group_rank_to_pgId:
+                group_rank_to_pgId[rank_tuple].append(pg_id)
+            else:
+                group_rank_to_pgId[rank_tuple] = [pg_id]
+
         # sync pgs across ranks to fix hang with multiple comm groups
-        # because new_group() functions requires that all processes in the main group enter,
+        # because the new_group() function requires that all processes in the default group call it,
         # even if they are not going to be members of the group.
-        # Assumption: pg_name is unique and consistent for all ranks
         sync_store = dist.PrefixStore("pg_sync_r", self.tcp_store)
         sync_store.set(str(global_rank), json.dumps(self.commsParams.groupRanks))
         torch.distributed.barrier()
-        group_ranks_sync = self.commsParams.groupRanks.copy()
         for i in range(self.get_world_size()):
             if i == global_rank:
                 continue
             json_data = sync_store.get(str(i))
 
             # convert pg_id in json_data to int
-            pg_id2group_ranks = {
+            pg_id_to_group_ranks = {
                 int(pg_id): rank for pg_id, rank in json.loads(json_data).items()
             }
 
-            group_ranks_sync.update(pg_id2group_ranks)
+            for _, group_ranks in pg_id_to_group_ranks.items():
+                if group_ranks is None or len(group_ranks) == 0:
+                    group_ranks = list(range(world_size))
+                group_ranks.sort()
+                rank_tuple = tuple(group_ranks)
+                if rank_tuple not in group_rank_to_pgId:
+                    group_rank_to_pgId[rank_tuple] = [-1]
 
-        # create additional groups
-        for pg_id, group_ranks in dict(sorted(group_ranks_sync.items())).items():
+        # create additional groups; sort so that pgs are created in the same order on all ranks
+        for group_ranks, pg_ids in dict(sorted(group_rank_to_pgId.items())).items():
             if (
                 len(group_ranks) > world_size
             ):  # this means that --auto-shrink is enabled, only use default pg
@@ -1115,11 +1134,13 @@ def initialize_groups(self, backend="gloo"):
             ):  # this is the default group, it has already been created
                 pg = self.get_default_group()
             else:
-                pg = self.get_new_pg(group_ranks=group_ranks, backend=backend)
+                pg = self.get_new_pg(group_ranks=list(group_ranks), backend=backend)
                 logger.info(
-                    f"initialized_group: create new group, pg_id = {pg_id}, group_ranks = {group_ranks}"
+                    f"initialized_group: create new group, pg_ids = {pg_ids}, group_ranks = {group_ranks}"
                 )
-            groups[pg_id] = pg
+            for pg_id in pg_ids:
+                if pg_id != -1:
+                    groups[pg_id] = pg
 
         # if additional groups are created, overwrite the default groups list
         if len(groups):
1 change: 1 addition & 0 deletions et_replay/tools/comm_replay.py
@@ -1057,6 +1057,7 @@ def replaySingle(
                 and curComm.dst_rank != self.backendFuncs.get_global_rank()
             )
         ):
+            logger.info(f"Skip collective {collName} id = {curComm.id}")
             return
 
         if groupRank >= 0:
