Skip to content

Commit

Permalink
fix truncated ranks string parsing in profiler trace analysis
Browse files Browse the repository at this point in the history
  • Loading branch information
sanshang-nv committed Jan 22, 2025
1 parent 7af31e6 commit 57f619e
Showing 1 changed file with 14 additions and 2 deletions.
16 changes: 14 additions & 2 deletions et_replay/comm/profiler_trace_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,18 @@ def _get_dict_value(d, k, err_msg):
raise ValueError(err_msg)
return d.get(k)

def _parse_ranks(ranks_str: str, ranks_count: int):
ranks = ast.literal_eval(ranks_str)
if ranks[-1] is Ellipsis:
# https://github.com/pytorch/pytorch/blob/main/torch/csrc/autograd/profiler_kineto.cpp#L508
# the first param of the struct `SaveNcclMetaConfig` controls the truncation
# if ranks string is too long, it will be truncated in kineto trace.
# (e.g., "[64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76,
# 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, ...]"),
stride = ranks[1] - ranks[0]
ranks = [ranks[0] + i * stride for i in range(ranks_count)]
return ranks

def _calculate_event_data_size(evt):
return max(evt['args']['In msg nelems'], evt['args']['Out msg nelems']) * _dtype_size_map[evt['args']['dtype']]

Expand Down Expand Up @@ -155,7 +167,7 @@ def calculate_sbw(trace_data):

def pick_iter_e2e_time_(trace_data, tl):
tl.extend([evt['dur'] for evt in trace_data['traceEvents'] if evt.get('cat', '') == 'user_annotation' and evt['name'].startswith('ProfilerStep#')])

def pick_comm_bw_(trace_data, comm_bw_data):
rank = trace_data['distributedInfo']['rank']
nccl_events = [i for i in trace_data['traceEvents'] if i.get('cat', '') == 'kernel' \
Expand All @@ -166,7 +178,7 @@ def pick_comm_bw_(trace_data, comm_bw_data):
data_size = _calculate_event_data_size(evt)
ranks_count = evt['args']['Group size']

ranks = ast.literal_eval(evt['args']['Process Group Ranks'])
ranks = _parse_ranks(evt['args']['Process Group Ranks'], ranks_count)
pg_id = int(evt['args']['Process Group Name'])
pg = tuple([*ranks, pg_id]) if rank == min(ranks) else None

Expand Down

0 comments on commit 57f619e

Please sign in to comment.