et_replay refactoring: removed hack code for FBGEMM, introduced skip-node option, MAST launcher update (#181)

Summary:
Pull Request resolved: #181

This DIFF includes the following et_replay refactoring:

1. Cleaned up the hack code for FBGEMM
    The previous implementation for FBGEMM-related ops relied on guessing the parameters of the original FBGEMM module from its forward and backward function calls. Since FBGEMM keeps evolving, most of that code is outdated; it is not a sustainable way to support these ops.
    The main reason these ops cannot be replayed is that they usually take index input tensors. If random data is used for these integer tensors, the replay usually runs into illegal memory accesses.
    To fix this, another DIFF (https://www.internalfb.com/diff/D62889784) will capture index tensors based on the user's selection; during replay, the captured index tensors are loaded for FBGEMM ops. This has been verified on the ICVR model: with the index tensor data, all FBGEMM ops can be replayed.
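    A minimal sketch of that replay-side behavior (the helper name and signature are hypothetical, not the code in this DIFF or in D62889784):

        import torch

        def make_index_tensor(shape, num_embeddings, captured=None):
            # Replay the captured indices verbatim when they are available, so
            # FBGEMM lookups stay inside the embedding table.
            if captured is not None:
                return torch.as_tensor(captured, dtype=torch.int64)
            # Random indices are only safe when kept within the table size;
            # unconstrained random data is what caused the illegal memory accesses.
            return torch.randint(0, num_embeddings, shape, dtype=torch.int64)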

2. Introduced new options --skip-node-file and --update-skip-node-file:
    If --skip-node-file is given, the JSON file that defines the nodes to skip is loaded and those ops are skipped during replay. --update-skip-node-file is a special run mode: it goes through all compute ops and, whenever an op fails to run, updates the skip-node file to include the failed op.
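    A minimal sketch of how the two modes fit together (the JSON layout, the node attribute, and the run_op callback are illustrative assumptions, not the exact implementation):

        import json

        def load_skip_nodes(path):
            # Assumed format: a JSON list of op names to skip during replay.
            with open(path) as f:
                return set(json.load(f))

        def update_skip_nodes(path, ops, run_op):
            # --update-skip-node-file mode: try every compute op, record failures.
            failed = set()
            for op in ops:
                try:
                    run_op(op)
                except Exception:
                    failed.add(op.name)
            with open(path, "w") as f:
                json.dump(sorted(failed), f, indent=2)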

3. The MAST launcher has been updated to create a new FBPKG for et_replay.

4. The DFS traverser that collects the ops has been simplified. If a node is an operator, its children are ignored. The only exception is c10::-related ops, since record_param_comms is a child of the c10:: op and comm_replay only uses record_param_comms.
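    The traversal is roughly equivalent to the sketch below (illustrative only; see the execution_trace.py changes in this DIFF for the real node model):

        from et_replay.execution_trace import NodeType

        def collect_ops(node, ops):
            # Depth-first walk: once an operator node is reached, its children
            # are skipped. c10d:: collectives are the exception, because their
            # record_param_comms child is the node comm_replay actually consumes.
            if node.type == NodeType.OPERATOR:
                ops.append(node)
                if not node.name.startswith("c10d::"):
                    return
            for child in node.children:
                collect_ops(child, ops)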

5. generate_io_tensor for CommsReplayManager in et_replay.py has been removed temporarily because it does not create input/output tensors correctly for all collectives. The current version uses comms_replay to allocate the tensors; generate_io_tensor can be brought back once it is ready.

6. Some other minor fixes, for example, using logger instead of print.
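    For instance, module-level logging now follows the standard pattern (the first three code lines below match the comm_replay.py change; the log call is just an example):

        import logging

        logger = logging.getLogger(__name__)
        logger.setLevel(logging.INFO)

        logger.info("replayed %d ops", 42)  # instead of print(...)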

Reviewed By: briancoutinho

Differential Revision: D61055957

fbshipit-source-id: 34ca74b221b3525b4e2a81df59df60f8924253c5
shengfukevin authored and facebook-github-bot committed Oct 11, 2024
1 parent e196340 commit 1ac7959
Showing 5 changed files with 483 additions and 436 deletions.
et_replay/et_replay_utils.py (28 changes: 9 additions & 19 deletions)
@@ -23,7 +23,7 @@
     "float": (torch.float32, torch.randn),
     "double": (torch.float64, torch.randn),
     "signed char": (torch.int8, torch.ones),
-    "unsigned char": (torch.int8, torch.ones),
+    "unsigned char": (torch.uint8, torch.ones),
     "c10::Half": (torch.half, torch.ones),
     "c10::BFloat16": (torch.bfloat16, torch.ones),
 }
@@ -38,7 +38,7 @@
     "float": ("torch.float32", "torch.randn"),
     "double": ("torch.float64", "torch.randn"),
     "signed char": ("torch.int8", "torch.ones"),
-    "unsigned char": ("torch.int8", "torch.ones"),
+    "unsigned char": ("torch.uint8", "torch.ones"),
     "c10::Half": ("torch.half", "torch.ones"),
     "c10::BFloat16": ("torch.bfloat16", "torch.ones"),
 }
@@ -59,15 +59,17 @@
 }


-def is_tensor_list(n, idx):
-    return isinstance(idx, int) and "GenericList[Tensor" in n.input_types[idx]
+def is_tensor_list(n, idx, is_input):
+    types_list = n.input_types if is_input else n.output_types
+    return isinstance(idx, int) and "GenericList[Tensor" in types_list[idx]


-def is_tensor(n, idx):
+def is_tensor(n, idx, is_input):
+    types_list = n.input_types if is_input else n.output_types
     return (
         isinstance(idx, int)
-        and "Tensor" in n.input_types[idx]
-        and "GenericList" not in n.input_types[idx]
+        and "Tensor" in types_list[idx]
+        and "GenericList" not in types_list[idx]
     )

@@ -166,22 +168,10 @@ def is_qualified(op):


 def get_input_tensors(n):
-    if is_fbgemm_forward(n):
-        idx_list = fbgemm_input_args_indices(n)
-        return zip(
-            [n.input_types[x] for x in idx_list],
-            [
-                tuple(n.inputs[x]) if isinstance(n.inputs[x], list) else n.inputs[x]
-                for x in idx_list
-            ],
-            [n.input_shapes[x] for x in idx_list],
-        )
     return n.get_input_tensors()


 def get_output_tensors(n):
-    if is_fbgemm_forward(n):
-        return list(zip(n.output_types, [tuple(x) for x in n.outputs], n.output_shapes))
     return n.get_output_tensors()
et_replay/execution_trace.py (50 changes: 11 additions & 39 deletions)
@@ -35,37 +35,6 @@ class NodeType(Enum):
     LABEL = 2


-# Label markers
-LABEL_MARKERS = [
-    "##",
-    "__",
-    "module::",
-    "DLRM ",
-    "DistributedDataParallel",
-    "Profiler",
-    "[pytorch|",
-    "forward",
-    "backward",
-    "Optimizer.zero_grad",
-    "[param",
-    "<forward op>",
-    "reduce-grads",
-    "multiply-grads",
-    "clip-grads",
-    "optimizer",
-    "gans_torchscript_ops::",
-    "split_with_sizes",
-    "chunk",
-    "All2All_Pooled_ReqBackward",
-    "All2All_Pooled_Req",
-    "All2All_Pooled_Wait",
-    "c10d::",
-    "TorchDynamo Cache Lookup",
-    "CompiledFunction",
-    "Torch-Compiled Region",
-]


 """
 TensorNode
@@ -180,7 +149,7 @@ def __init__(
         self.fw_parent_id: int = fw_parent_id
         self.seq_id: int = seq_id
         self.scope: int = scope
-        self.type: NodeType = self.detect_type(name, inputs, outputs)
+        self.type: NodeType = self.detect_type()
         # self.inputs: List[Any] = [tuple(i) if isinstance(i, list) else i for i in inputs]
         self.inputs: List[Any] = inputs
         self.input_types: List[str] = input_types
@@ -280,16 +249,7 @@ def get_parent_by_name(self, names) -> Optional[Node]:
                 return node
         return None

-    def detect_type(self, name: str, inputs: List[Any], outputs: List[Any]) -> NodeType:
+    def detect_type(self) -> NodeType:
         if (
-            any(name.startswith(x) for x in LABEL_MARKERS)
-            # and not outputs
+            # for collectives, ET records both c10d::collective_function and
+            # record_param_comms. Only record_param_comms is used for replay
+            self.name == "record_param_comms"
+            or (
+                self.op_schema != "" and not self.name.startswith("c10d::")
+            )  # for aten ops
+            or self.kernel_backend == "triton"  # for PT2 triton kernels
         ):
-            # if outputs:
-            # print(f"{name} has outputs, not expected.")
-            return NodeType.LABEL
-        else:
             return NodeType.OPERATOR
+        else:
+            return NodeType.LABEL

     def get_tensors(self, param_list: Iterable) -> List[tuple]:
         tensors = []
et_replay/tests/test_execution_trace.py (4 changes: 2 additions & 2 deletions)
@@ -30,7 +30,7 @@ def test_trace_load_resnet_1gpu_ptorch_1_0_3(self):
         )
         t, et = self._test_and_validate_trace(et_file)
         self.assertGreater(t.num_ops(), 1000)
-        self.assertEqual(t.num_comm_ops(), 12)
+        self.assertEqual(t.num_comm_ops(), 27)
         self.assertEqual(t.num_triton_ops(), 0)

     def test_trace_load_resnet_2gpu_ptorch_1_1_0(self):
@@ -39,7 +39,7 @@ def test_trace_load_resnet_2gpu_ptorch_1_1_0(self):
         )
         t, et = self._test_and_validate_trace(et_file)
         self.assertGreater(t.num_ops(), 1000)
-        self.assertEqual(t.num_comm_ops(), 12)
+        self.assertEqual(t.num_comm_ops(), 27)
         self.assertEqual(t.num_triton_ops(), 0)
et_replay/tools/comm_replay.py (3 changes: 2 additions & 1 deletion)
@@ -36,6 +36,7 @@
     pass

 logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)

 # sleep for 20ms to wait for next collective
 LOOP_TIMER_S = 0.02
@@ -640,7 +641,7 @@ def prepComms(
         curComm: commsArgs,
         commsParams: commsParamsHolderBase,
         regenerateTensors: bool = True,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> Tuple[torch.Tensor, Union[List[torch.Tensor], torch.Tensor]]:
         """
         Prepares the appropriate tensors for the current collective communication.