From 1ac795966d230d8d3ec64e407d039e159d86c3a6 Mon Sep 17 00:00:00 2001
From: Sheng Fu
Date: Thu, 10 Oct 2024 22:37:00 -0700
Subject: [PATCH] et_replay refactoring: removed hack code for FBGEMM, introduced skip-node option, MAST launcher update (#181)

Summary:
Pull Request resolved: https://github.com/facebookresearch/param/pull/181

This DIFF includes the following et_replay refactoring:

1. Cleaned up the hack code for FBGEMM. The previous implementation of the FBGEMM-related ops relied on guessing the parameters of the original FBGEMM module from its forward and backward function calls. Since FBGEMM keeps evolving, most of that code is outdated, and guessing is not a sustainable way to support these ops. The main reason these ops cannot be replayed is that they usually take index input tensors: if random data is used for these integer tensors, the replay usually runs into illegal memory access. To fix this, another DIFF (https://www.internalfb.com/diff/D62889784) will capture the index tensors based on the user's selection; during replay, the index tensors are then loaded for the FBGEMM ops. It has been proven on the ICVR model that, with the index tensor data, we can replay all FBGEMM ops.

2. Introduced the new options --skip-node-file and --update-skip-node-file. If --skip-node-file is provided, the JSON file that defines the nodes to skip is loaded and those ops are skipped (see the sketch below). --update-skip-node-file is a special run mode: it goes through all compute ops and, if an op fails to run, updates the skip-node file to include the failed op.

3. The MAST launcher has been updated to create a new FBPKG for et_replay.

4. The DFS traverser that collects the ops has been simplified. If a node is an operator, its children are ignored. The only exception is c10::-related ops, since record_param_comms is a child of a c10:: op and comm_replay only uses record_param_comms.

5. generate_io_tensors for CommsReplayManager in et_replay.py has been removed temporarily; it does not handle all collectives correctly when creating input/output tensors. The current version uses comms_replay to allocate the tensors. We can put it back once that function is ready.
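As an illustration of item 2, the skip-node file is a plain JSON object mapping an op name to the reason it was skipped, in the form written out by --update-skip-node-file. The op names and messages below are hypothetical examples, not taken from a real trace:

{
    "aten::some_failing_op": "Failed to build function",
    "aten::another_failing_op": "Run op exception Error: <exception>, node id: <id>, node name: aten::another_failing_op"
}

When producing this file with --update-skip-node-file, set CUDA_LAUNCH_BLOCKING=1 so that failures are attributed to the correct op.

6.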
Some other minor fixes, for example, use logger instead of print Reviewed By: briancoutinho Differential Revision: D61055957 fbshipit-source-id: 34ca74b221b3525b4e2a81df59df60f8924253c5 --- et_replay/et_replay_utils.py | 28 +- et_replay/execution_trace.py | 50 +- et_replay/tests/test_execution_trace.py | 4 +- et_replay/tools/comm_replay.py | 3 +- et_replay/tools/et_replay.py | 834 +++++++++++++----------- 5 files changed, 483 insertions(+), 436 deletions(-) diff --git a/et_replay/et_replay_utils.py b/et_replay/et_replay_utils.py index 5013bb62..b0cf732a 100644 --- a/et_replay/et_replay_utils.py +++ b/et_replay/et_replay_utils.py @@ -23,7 +23,7 @@ "float": (torch.float32, torch.randn), "double": (torch.float64, torch.randn), "signed char": (torch.int8, torch.ones), - "unsigned char": (torch.int8, torch.ones), + "unsigned char": (torch.uint8, torch.ones), "c10::Half": (torch.half, torch.ones), "c10::BFloat16": (torch.bfloat16, torch.ones), } @@ -38,7 +38,7 @@ "float": ("torch.float32", "torch.randn"), "double": ("torch.float64", "torch.randn"), "signed char": ("torch.int8", "torch.ones"), - "unsigned char": ("torch.int8", "torch.ones"), + "unsigned char": ("torch.uint8", "torch.ones"), "c10::Half": ("torch.half", "torch.ones"), "c10::BFloat16": ("torch.bfloat16", "torch.ones"), } @@ -59,15 +59,17 @@ } -def is_tensor_list(n, idx): - return isinstance(idx, int) and "GenericList[Tensor" in n.input_types[idx] +def is_tensor_list(n, idx, is_input): + types_list = n.input_types if is_input else n.output_types + return isinstance(idx, int) and "GenericList[Tensor" in types_list[idx] -def is_tensor(n, idx): +def is_tensor(n, idx, is_input): + types_list = n.input_types if is_input else n.output_types return ( isinstance(idx, int) - and "Tensor" in n.input_types[idx] - and "GenericList" not in n.input_types[idx] + and "Tensor" in types_list[idx] + and "GenericList" not in types_list[idx] ) @@ -166,22 +168,10 @@ def is_qualified(op): def get_input_tensors(n): - if is_fbgemm_forward(n): - idx_list = fbgemm_input_args_indices(n) - return zip( - [n.input_types[x] for x in idx_list], - [ - tuple(n.inputs[x]) if isinstance(n.inputs[x], list) else n.inputs[x] - for x in idx_list - ], - [n.input_shapes[x] for x in idx_list], - ) return n.get_input_tensors() def get_output_tensors(n): - if is_fbgemm_forward(n): - return list(zip(n.output_types, [tuple(x) for x in n.outputs], n.output_shapes)) return n.get_output_tensors() diff --git a/et_replay/execution_trace.py b/et_replay/execution_trace.py index ea2769b4..b890ef38 100644 --- a/et_replay/execution_trace.py +++ b/et_replay/execution_trace.py @@ -35,37 +35,6 @@ class NodeType(Enum): LABEL = 2 -# Label markers -LABEL_MARKERS = [ - "##", - "__", - "module::", - "DLRM ", - "DistributedDataParallel", - "Profiler", - "[pytorch|", - "forward", - "backward", - "Optimizer.zero_grad", - "[param", - "", - "reduce-grads", - "multiply-grads", - "clip-grads", - "optimizer", - "gans_torchscript_ops::", - "split_with_sizes", - "chunk", - "All2All_Pooled_ReqBackward", - "All2All_Pooled_Req", - "All2All_Pooled_Wait", - "c10d::", - "TorchDynamo Cache Lookup", - "CompiledFunction", - "Torch-Compiled Region", -] - - """ TensorNode @@ -180,7 +149,7 @@ def __init__( self.fw_parent_id: int = fw_parent_id self.seq_id: int = seq_id self.scope: int = scope - self.type: NodeType = self.detect_type(name, inputs, outputs) + self.type: NodeType = self.detect_type() # self.inputs: List[Any] = [tuple(i) if isinstance(i, list) else i for i in inputs] self.inputs: List[Any] = inputs 
self.input_types: List[str] = input_types @@ -280,16 +249,19 @@ def get_parent_by_name(self, names) -> Optional[Node]: return node return None - def detect_type(self, name: str, inputs: List[Any], outputs: List[Any]) -> NodeType: + def detect_type(self) -> NodeType: if ( - any(name.startswith(x) for x in LABEL_MARKERS) - # and not outputs + # for collectives, ET records both c10d::collective_function and + # record_param_comms. Only record_param_comms is used for replay + self.name == "record_param_comms" + or ( + self.op_schema != "" and not self.name.startswith("c10d::") + ) # for aten ops + or self.kernel_backend == "triton" # for PT2 triton kernels ): - # if outputs: - # print(f"{name} has outputs, not expected.") - return NodeType.LABEL - else: return NodeType.OPERATOR + else: + return NodeType.LABEL def get_tensors(self, param_list: Iterable) -> List[tuple]: tensors = [] diff --git a/et_replay/tests/test_execution_trace.py b/et_replay/tests/test_execution_trace.py index 8a31a6c1..ad77ea3e 100644 --- a/et_replay/tests/test_execution_trace.py +++ b/et_replay/tests/test_execution_trace.py @@ -30,7 +30,7 @@ def test_trace_load_resnet_1gpu_ptorch_1_0_3(self): ) t, et = self._test_and_validate_trace(et_file) self.assertGreater(t.num_ops(), 1000) - self.assertEqual(t.num_comm_ops(), 12) + self.assertEqual(t.num_comm_ops(), 27) self.assertEqual(t.num_triton_ops(), 0) def test_trace_load_resnet_2gpu_ptorch_1_1_0(self): @@ -39,7 +39,7 @@ def test_trace_load_resnet_2gpu_ptorch_1_1_0(self): ) t, et = self._test_and_validate_trace(et_file) self.assertGreater(t.num_ops(), 1000) - self.assertEqual(t.num_comm_ops(), 12) + self.assertEqual(t.num_comm_ops(), 27) self.assertEqual(t.num_triton_ops(), 0) diff --git a/et_replay/tools/comm_replay.py b/et_replay/tools/comm_replay.py index 4fbc328f..5706b6f9 100644 --- a/et_replay/tools/comm_replay.py +++ b/et_replay/tools/comm_replay.py @@ -36,6 +36,7 @@ pass logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) # sleep for 20ms to wait for next collective LOOP_TIMER_S = 0.02 @@ -640,7 +641,7 @@ def prepComms( curComm: commsArgs, commsParams: commsParamsHolderBase, regenerateTensors: bool = True, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> Tuple[torch.Tensor, Union[List[torch.Tensor], torch.Tensor]]: """ Prepares the appropriate tensors for the current collective communication. 
diff --git a/et_replay/tools/et_replay.py b/et_replay/tools/et_replay.py index 81eede84..d7409702 100644 --- a/et_replay/tools/et_replay.py +++ b/et_replay/tools/et_replay.py @@ -7,7 +7,7 @@ import time from collections import defaultdict from datetime import datetime -from typing import Dict, Tuple +from typing import Dict import numpy as np import torch @@ -19,7 +19,6 @@ commsParamsHolderBase, ) from et_replay.et_replay_utils import ( - build_fbgemm_func, build_torchscript_func, build_triton_func, fbgemm_input_args_indices, @@ -28,17 +27,14 @@ generate_suffix, get_input_tensors, get_output_tensors, - is_fbgemm_backward, is_fbgemm_forward, is_fbgemm_forward_unweighted, - is_qualified, is_tensor, is_tensor_list, - skip_op, TORCH_DTYPES_RNG, TORCH_DTYPES_RNG_str, ) -from et_replay.execution_trace import ExecutionTrace +from et_replay.execution_trace import ExecutionTrace, NodeType from et_replay.tools.comm_replay import commsTraceReplayBench from et_replay.utils import trace_handler from param_bench.train.compute.python.lib import pytorch as lib_pytorch @@ -51,6 +47,9 @@ from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid # noqa from torch.profiler import ExecutionTraceObserver +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + class CommsReplayManager(commsTraceReplayBench): # pyre-ignore[13]: def __init__(self): @@ -58,16 +57,21 @@ def __init__(self): self.comp_replay_manager = None + """ + TODO: some of the collectives need a list of tensors as input/output + invetigate how to support it def generate_io_tensors( self, curComm: commsArgs, commsParams: commsParamsHolderBase, regenerateTensors: bool, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> Tuple[torch.Tensor, Union[List[torch.Tensor], torch.Tensor]]: + return super().prepComm(curComm, commsParams, regenerateTensors) + node = self.comp_replay_manager.et.nodes[curComm.id] - input_tensors = self.comp_replay_manager.get_inputs(node) - output_tensors = self.comp_replay_manager.get_comm_outputs(node) + input_tensors, _ = self.comp_replay_manager.get_data(node, True) + output_tensors, _ = self.comp_replay_manager.get_data(node, False) ip_tensor = None op_tensor = None @@ -80,13 +84,14 @@ def extract_tensor_from_list(tensor_list): else: return tensor_list - if len(input_tensors) > 0: + if input_tensors is not None and len(input_tensors) > 0: ip_tensor = extract_tensor_from_list(input_tensors[0]) - if len(output_tensors) > 0: + if output_tensors is not None and len(output_tensors) > 0: op_tensor = extract_tensor_from_list(output_tensors[0]) return (ip_tensor, op_tensor) + """ class ExgrReplayManager: @@ -97,7 +102,6 @@ def __init__(self): self.profile_memory = False self.et = None self.et_profile = False - self.batch_size = 1 self.cuda_id = 0 self.debug = False self.compute_only = False @@ -151,12 +155,6 @@ def __init__(self): # Tensors that should be instantiated on cpu, e.g., input of aten::pin_memory and aten::to. self.cpu_tensor = set() - # Skip the node if their names contain any of the following strings. - self.skip_node_names = [ - "DataLoader", - "aten::set_", - ] - self.parallel_nodes_parents = [] # Ids of nodes that need to run in parallel. self.parallel_nodes_ids = [] @@ -164,33 +162,33 @@ def __init__(self): # This is used to pick out a single iteration when trace contains multiple iterations. # Basically this label should be captured at the beginning of each iteration so that one iteration # is between two consecutive label nodes. 
- self.label = "ProfilerStep#" + self.profile_step_label = "ProfilerStep#" try: from param_bench.et_replay.fb.internals import ( add_internal_parallel_nodes_parents, - add_internal_skip_nodes, ) except ImportError: - logging.info("FB internals not present") + logger.info("FB internals not present") else: - self.skip_node_names = add_internal_skip_nodes(self.skip_node_names) self.parallel_nodes_parents = add_internal_parallel_nodes_parents( self.parallel_nodes_parents ) + self.profile_step_node_ids = [] + # Only use for memory profile. self.current_allocated_mem = 0 self.current_reserved_mem = 0 self.op_allocated_mem = {} self.op_reserved_mem = {} - # Store the backward fbgemm ops generated in the forward. - self.fbgemm_backward_ops = [] - - # Debug use, record the nodes we skip. - self.actual_skip_nodes = [] - self.actual_skip_nodes_cnt = 0 + # actual_skip_nodes is a dictinary which records the skipped node name + # and the reason to skip it. + # Dict {name => skip reason} + self.actual_skip_nodes: Dict[str, str] = {} + self.initial_skip_node_count = 0 + self.n_skipped_nodes = 0 self.tensor_with_device = True # A tensor may appear on multiple devices but here we only store the first device for initialization @@ -213,8 +211,6 @@ def __init__(self): self.exec_time = [] self.setup_time = [] - self.operators_count = [] - self.tf32 = False # Tensors that related to aten::to, for those tensors we still need to override its value after @@ -230,7 +226,6 @@ def initBench(self): self.profile_replay = self.args.profile_replay self.profile_memory = self.args.profile_memory self.et_profile = self.args.et - self.batch_size = self.args.batch_size self.cuda_id = self.args.cuda self.debug = self.args.debug self.compute_only = self.args.compute @@ -248,7 +243,7 @@ def initBench(self): try: from param_bench.et_replay.fb.internals import read_remote_trace except ImportError: - logging.info("FB internals not present") + logger.info("FB internals not present") exit(1) else: et, self.trace_file = read_remote_trace(self.args.input) @@ -265,21 +260,34 @@ def initBench(self): self.dump_path += "benchmark.py" # Multiple traces. else: - print(f"{os.getpid()} is rank{self.comms_env_params['global_rank']}") + logger.info( + f"process {os.getpid()} is rank {self.comms_env_params['global_rank']}" + ) self.cuda_id = self.comms_env_params["local_rank"] self.cuda = f"cuda:{self.comms_env_params['local_rank']}" # Different processes should read different traces based on global_rank_id. 
if "://" in self.args.trace_path: try: - from param_bench.et_replay.fb.internals import read_remote_trace + from param_bench.et_replay.fb.internals import ( + read_remote_skip_node_file, + read_remote_trace, + ) except ImportError: - logging.info("FB internals not present") + logger.info("FB internals not present") exit(1) else: et, self.trace_file = read_remote_trace( f"{self.args.trace_path}/rank-{self.comms_env_params['global_rank']}.json" ) self.et = ExecutionTrace(json.load(et)) + + # check if the remote path has skip-node.json + skip_node_json = read_remote_skip_node_file( + f"{self.args.trace_path}" + ) + if skip_node_json: + self.actual_skip_nodes = json.load(skip_node_json) + self.initial_skip_node_count = len(self.actual_skip_nodes) else: self.trace_file = f"{self.args.trace_path}/rank-{self.comms_env_params['global_rank']}.json" with open(self.trace_file, "r") as f: @@ -292,6 +300,17 @@ def initBench(self): self.resource_dir = os.path.join( base_path, os.path.splitext(file_name)[-2] + "_resources" ) + + if self.args.skip_node_file: + try: + with open(self.args.skip_node_file, "r") as json_file: + self.actual_skip_nodes = json.load(json_file) + self.initial_skip_node_count = len(self.actual_skip_nodes) + except OSError: + logger.info( + f"Failed to load skip node file {self.args.skip_node_file}." + ) + self.kernel_map = {} self.async_compile = AsyncCompile() @@ -347,6 +366,17 @@ def reset_registry(self): gc.collect() torch.cuda.empty_cache() + def add_skipped_nodes(self, node, reason: str) -> None: + if node.name not in self.actual_skip_nodes: + self.actual_skip_nodes[node.name] = reason + + def is_skipped(self, node) -> bool: + if node.name in self.actual_skip_nodes: + self.n_skipped_nodes += 1 + return True + else: + return False + def extract_subgraph(self, root): """ return: all nodes in the subgraph, in the order of node ID (also execution) @@ -375,46 +405,51 @@ def analayze_node(node): continue self.input_tensor_ids.add(t_id) - func, output_count = self.build_func(node) - self.funcs[node.id] = (func, output_count) - - def dfs_traverse(root): - for child in root.children: - try: - if self.label and self.label in child.name: - self.sorted_nodes.append(child) + if node.name == "record_param_comms": + # Node "record_param_comms" is not able to have a func created by self.build_func + # but we still want to return success to keep it in self.sorted_nodes + return True, "" - if any(x in child.name for x in self.skip_node_names): - self.actual_skip_nodes.append(child.name) - self.actual_skip_nodes_cnt += 1 - continue + func, output_count = self.build_func(node) + if func is None: + return False, "Failed to build function" + else: + self.funcs[node.id] = (func, output_count) + return True, "" + + def dfs_traverse(node): + if self.profile_step_label in node.name: + self.profile_step_node_ids.append(node.id) + if node.type == NodeType.OPERATOR: + if not self.is_skipped(node): + self.sorted_nodes.append(node) + return - if is_qualified(child): - self.sorted_nodes.append(child) - dfs_traverse( - child - ) # temporaryly, the 'is_qualified' strategy needs to be refactored - else: - if skip_op(child): - self.actual_skip_nodes.append(child.name) - self.actual_skip_nodes_cnt += 1 - dfs_traverse(child) - except Exception as e: - print(f"Graph parse error: {e}, node id: {child.id}") - exit(1) + for child in node.children: + dfs_traverse(child) + self.n_skipped_nodes = 0 dfs_traverse(root) self.sorted_nodes = sorted(self.sorted_nodes, key=lambda x: x.id) - for i in 
range(len(self.sorted_nodes)): - if self.label and self.label in self.sorted_nodes[i].name: - self.operators_count.append(i) - if len(self.operators_count) > 1: - self.sorted_nodes = self.sorted_nodes[ - self.operators_count[0] + 1 : self.operators_count[1] + self.profile_step_node_ids = sorted(self.profile_step_node_ids) + + if len(self.profile_step_node_ids) > 1: + # Only execute the ops in the first step + start_id = self.profile_step_node_ids[0] + end_id = self.profile_step_node_ids[1] + self.sorted_nodes = [ + x for x in self.sorted_nodes if x.id > start_id and x.id < end_id ] - print("#Operators to execute: ", len(self.sorted_nodes)) + + logger.info(f"#Operators to execute: {len(self.sorted_nodes)}") + picked_nodes = [] for node in self.sorted_nodes: - analayze_node(node) + success, msg = analayze_node(node) + if success: + picked_nodes.append(node) + else: + self.add_skipped_nodes(node, msg) + self.sorted_nodes = picked_nodes # triton kernels are compiled in parallel, need to wait until # all kernels are compiled. @@ -423,6 +458,12 @@ def dfs_traverse(root): self.select_parallel_nodes() + if self.args.skip_node_file is not None and self.args.update_skip_node_file: + actual_skip_node_count = len(self.actual_skip_nodes) + if actual_skip_node_count > self.initial_skip_node_count: + with open(self.args.skip_node_file, "w") as outfile: + json.dump(self.actual_skip_nodes, outfile, indent=4) + def select_parallel_nodes(self): def is_parallel_parent(node): return node.name in self.parallel_nodes_parents @@ -571,7 +612,9 @@ def add_unique_tensor(node_name, node_id, t_id, shape, input, device=-1): def allocate_tensors(self): start_ns = time.time_ns() - + """ + TODO: before figure out how to allocate tensors for comm nodes, we use + comms replay to allocate tensors for comm nodes. 
if not (self.compute_only or self.args.separate): for node in self.sorted_nodes: if node.name == "record_param_comms": @@ -582,8 +625,14 @@ def allocate_tensors(self): for node in self.sorted_nodes: if node.name != "record_param_comms": self.allocate_comp_tensors(node) + """ - print(f"Tensor allocation time: {(time.time_ns() - start_ns) / 1000000.0} ms") + for node in self.sorted_nodes: + if node.name != "record_param_comms": + self.allocate_comp_tensors(node) + logger.info( + f"Tensor allocation time: {(time.time_ns() - start_ns) / 1000000.0} ms" + ) def allocate_comm_tensors(self, node): def add_comm_tensor_registry(tensor_strides, tensors): @@ -599,12 +648,9 @@ def add_comm_tensor_registry(tensor_strides, tensors): and replay_t_id in self.instantiate ): try: - if data_type == "Tensor(signed char)": - dtype, _ = TORCH_DTYPES_RNG["signed char"] - else: - dtype, _ = TORCH_DTYPES_RNG[ - data_type.lstrip("Tensor(").rstrip(")") - ] + dtype, _ = TORCH_DTYPES_RNG[ + data_type.lstrip("Tensor(").rstrip(")") + ] strides = None if tensor_strides is not None: @@ -621,7 +667,7 @@ def add_comm_tensor_registry(tensor_strides, tensors): self.tensor_registry_permanent[replay_t_id] = tensor except KeyError: if data_type != "Tensor(nullptr (uninitialized))": - print("KeyError: ", node.id, t_id, data_type) + logger.info(f"KeyError: {node.id}, {t_id}, {data_type}") self.tensor_registry_permanent[replay_t_id] = None add_comm_tensor_registry( @@ -632,55 +678,45 @@ def add_comm_tensor_registry(tensor_strides, tensors): ) def allocate_comp_tensors(self, node): # noqa: C901 - if is_fbgemm_forward(node): - if self.cpu: - input_args, _ = generate_fbgemm_tensors( - node, - "cpu", - self.args.rows, - self.args.pooling_factor, - self.args.alpha, - ) - else: - input_args, _ = generate_fbgemm_tensors( - node, - self.cuda, - self.args.rows, - self.args.pooling_factor, - self.args.alpha, - ) tensor_strides = node.get_input_tensor_strides() - for idx, (data_type, t_id, shape) in enumerate(get_input_tensors(node)): + input_tensors = get_input_tensors(node) + node.pre_load_tensors = [None] * len(input_tensors) + for idx, (data_type, t_id, shape) in enumerate(input_tensors): device = self.device if self.tensor_with_device: device = t_id[5] t_id = tuple(list(t_id)[:5]) replay_t_id = self.tensors_mapping[(node.id, t_id, True)] - if ( - t_id in self.input_tensor_ids - and replay_t_id not in self.tensor_registry_permanent.keys() - and ( - node.name == "aten::embedding_bag" - or "fbgemm::split_embedding_codegen_lookup" in node.name - or replay_t_id in self.instantiate - ) - ): + if t_id in self.input_tensor_ids: try: - if is_fbgemm_forward(node): - self.tensor_registry_permanent[replay_t_id] = input_args[idx] - if "fbgemm::split_embedding_codegen_lookup" in node.name: - self.unchangeable_intermediate_tensors.add(replay_t_id) - else: - if data_type == "Tensor(signed char)": - dtype, _ = TORCH_DTYPES_RNG["signed char"] - else: - dtype, _ = TORCH_DTYPES_RNG[ - data_type.lstrip("Tensor(").rstrip(")") - ] - - strides = None - if node.input_strides is not None: - strides = tensor_strides[idx] + dtype, _ = TORCH_DTYPES_RNG[data_type.lstrip("Tensor(").rstrip(")")] + + strides = None + if node.input_strides is not None: + strides = tensor_strides[idx] + + tensor = None + if dtype in ( + torch.int8, + torch.uint8, + torch.int16, + torch.uint16, + torch.int32, + torch.uint32, + torch.int64, + torch.uint64, + torch.long, + ): + tensor = self.get_tensor_from_file( + node.id, idx, device, shape, dtype, strides + ) + if tensor is not None: + 
node.pre_load_tensors[idx] = tensor + if ( + tensor is None + and replay_t_id in self.instantiate + and replay_t_id not in self.tensor_registry_permanent.keys() + ): tensor = self.get_tensor_from_storage( t_id[1], # storage_id t_id[2], # offset @@ -691,46 +727,13 @@ def allocate_comp_tensors(self, node): # noqa: C901 strides, ) self.tensor_registry_permanent[replay_t_id] = tensor - if node.name == "aten::embedding_bag": - self.unchangeable_intermediate_tensors.add(replay_t_id) - if node.name == "aten::pin_memory" and idx == 0: - self.cpu_tensor.add(replay_t_id) + except KeyError: if data_type != "Tensor(nullptr (uninitialized))": - print("KeyError: ", node.id, t_id, data_type) + logger.info(f"KeyError: {node.id}, {t_id}, {data_type}") self.tensor_registry_permanent[replay_t_id] = None - ###### - # Workaround to match offsets for embedding table - # Currently assume a uniform distribution. - if node.name == "aten::embedding_bag": - indices_tensor_shape = node.input_shapes[1][0] - offsets_tensor_shape = node.input_shapes[2][0] - nnz = indices_tensor_shape / offsets_tensor_shape - for i in range(offsets_tensor_shape): - if self.tensor_with_device: - self.tensor_registry_permanent[ - self.tensors_mapping[(node.id, tuple(node.inputs[2][:5]), True)] - ][i] = (i * nnz) - else: - self.tensor_registry_permanent[ - self.tensors_mapping[(node.id, tuple(node.inputs[2]), True)] - ][i] = (i * nnz) - ###### - def build_func(self, node): - if is_fbgemm_forward(node): - if self.cpu: - func, output_count = build_fbgemm_func(node, "cpu", self.args.rows) - else: - func, output_count = build_fbgemm_func(node, self.cuda, self.args.rows) - self.fbgemm_backward_ops.append((func.backward, node.id)) - return func.forward, output_count - elif is_fbgemm_backward(node): - assert self.fbgemm_backward_ops - backward_op, forward_id = self.fbgemm_backward_ops.pop(-1) - return backward_op, len(node.output_types) - if node.kernel_backend == "triton": if node.kernel_file in self.kernel_map: func = self.kernel_map[node.kernel_file] @@ -746,9 +749,6 @@ def build_func(self, node): else: func, output_count = build_torchscript_func(node) - if not func: - self.actual_skip_nodes.append(node.name) - self.actual_skip_nodes_cnt += 1 return func, output_count def generate_code(self): @@ -850,7 +850,7 @@ def _generate_tensor_allocation_str(): self.tensor_registry_permanent[replay_t_id] = 1 except KeyError: if dtype != "Tensor(nullptr (uninitialized))": - print("KeyError: ", node.id, t_id, dtype) + logger.info(f"KeyError: {node.id}, {t_id}, {dtype}") tensor_allocation_str += f"global tensor_{replay_t_id}\n" tensor_allocation_str += f"tensor_{replay_t_id} = None\n" self.tensor_registry_permanent[replay_t_id] = 1 @@ -887,7 +887,7 @@ def _generate_inputs_str(node): ): inputs += "[True, True, True], " continue - if is_tensor(node, idx): + if is_tensor(node, idx, True): if self.tensor_with_device: item = tuple(item[:5]) # Workaround to handle tensor with same id but different data types. 
@@ -985,7 +985,7 @@ def _parse_element_type(node, output_type, output_tensors, override): outputs = outputs[:-2] return outputs except Exception as e: - print("Generate outputs error: ", e, node.id) + logger.info(f"Generate outputs error: {e}, {node.id}") exit(1) def _generate_run_ops_str(override): @@ -1070,7 +1070,7 @@ def _generate_run_ops_str(override): if self.cpu: code_str += generate_prefix( - self.label, + self.profile_step_label, skip_nodes_str, self.trace_file, "cpu", @@ -1080,7 +1080,7 @@ def _generate_run_ops_str(override): ) else: code_str += generate_prefix( - self.label, + self.profile_step_label, skip_nodes_str, self.trace_file, self.cuda, @@ -1111,18 +1111,77 @@ def _generate_run_ops_str(override): ) if self.dump: - print(f"Intermediate benchmark file dumped to {self.dump_path}") + logger.info(f"Intermediate benchmark file dumped to {self.dump_path}") with open(self.dump_path, "w") as f: print(code_str, file=f) exec(code_str) + def get_tensor_from_file( + self, + node_id, + tensor_index, + device, + shape, + data_type, + strides, + ): + def to_numpy_data_type(data_type): + if data_type == torch.int8: + return np.int8 + elif data_type == torch.uint8: + return np.uint8 + elif data_type == torch.int16: + return np.int16 + elif data_type == torch.uint16: + return np.uint16 + elif data_type == torch.int32: + return np.int32 + elif data_type == torch.uint32: + return np.uint32 + elif data_type == torch.int64 or data_type == torch.long: + return np.int64 + elif data_type == torch.uint64: + return np.uint64 + else: + raise ValueError(f"Unsupported data type: {data_type}") + + device = torch.device(device) + + # check if the tensor data file exists + storage_fn = ( + self.resource_dir + + "/nid_" + + str(node_id) + + "_tid_" + + str(tensor_index) + + ".dat" + ) + if os.path.isfile(storage_fn): + np_x = np.fromfile(storage_fn, dtype=to_numpy_data_type(data_type)) + if len(shape) == 0: + np_x = np_x.item() + x = torch.tensor(np_x) + else: + np_x = np.reshape(np_x, shape) + x = torch.from_numpy(np_x) + if device != torch.device("cpu"): + x = x.cuda(device) + return x + else: + return None + def get_tensor_from_storage( - self, storage_id, data_offset, elem_bytes, device, shape, data_type, strides + self, + storage_id, + data_offset, + elem_bytes, + device, + shape, + data_type, + strides, ): assert storage_id in self.tensor_storage_map - tensor_data = self.tensor_storage_map[storage_id] - device = torch.device(device) if device not in tensor_data[1]: if data_type in [torch.half, torch.float32, torch.float64, torch.bfloat16]: storage_tensor = torch.rand( @@ -1132,10 +1191,13 @@ def get_tensor_from_storage( storage_tensor = torch.ones( (tensor_data[0] // elem_bytes), dtype=data_type, device=device ) + tensor_data[1][device] = storage_tensor else: storage_tensor = tensor_data[1][device] + x = torch.empty(0, dtype=data_type) + device = torch.device(device) if device != torch.device("cpu"): x = x.cuda(device) if strides is None: @@ -1154,82 +1216,72 @@ def get_tensor_from_storage( return x - # TODO: refactor the code in get_inputs and get_comm_outputs to one function. 
The code for - # fbgemm in get_inputs will be cleaned up soon - def get_inputs(self, node): + def get_data(self, node, is_input): try: - if is_fbgemm_forward(node): - idx_list = fbgemm_input_args_indices(node) - if self.tensor_with_device: - inputs = [ - self.tensor_registry[ - self.tensors_mapping[ - (node.id, tuple(node.inputs[idx][:5]), True) - ] - ] - for idx in idx_list - ] - else: - inputs = [ - self.tensor_registry[ - self.tensors_mapping[ - (node.id, tuple(node.inputs[idx]), True) - ] - ] - for idx in idx_list - ] - if is_fbgemm_forward_unweighted(node): - inputs.append(None) + if is_input: + data_in = node.inputs else: - inputs = [] - for idx, item in enumerate(node.inputs): - if is_tensor(node, idx): + data_in = node.outputs + data_out = [] + tensor_index = 0 + for idx, item in enumerate(data_in): + if is_tensor(node, idx, is_input): + if ( + is_input + and hasattr(node, "pre_load_tensors") + and node.pre_load_tensors[tensor_index] is not None + ): + data_out.append(node.pre_load_tensors[tensor_index]) + else: self.lookup_cnt += 1 if self.tensor_with_device: item = tuple(item[:5]) - inputs.append( + data_out.append( self.tensor_registry[ self.tensors_mapping[(node.id, tuple(item), True)] ] ) - elif is_tensor_list(node, idx): - self.lookup_cnt += len(item) - if self.tensor_with_device: - inputs.append( - [ - self.tensor_registry[ - self.tensors_mapping[ - (node.id, tuple(t_id[:5]), True) - ] - ] - for t_id in item - ] - ) + + tensor_index += 1 + + elif is_tensor_list(node, idx, is_input): + self.lookup_cnt += len(item) + tensor_list = [] + for t_id in item: + if ( + is_input + and hasattr(node, "pre_load_tensors") + and node.pre_load_tensors[tensor_index] is not None + ): + tensor_list.append(node.pre_load_tensors[tensor_index]) else: - inputs.append( - [ - self.tensor_registry[ - self.tensors_mapping[ - (node.id, tuple(t_id), True) - ] - ] - for t_id in item + if self.tensor_with_device: + t_id = tuple(t_id[:5]) + else: + t_id = tuple(t_id) + tensor_list.append( + self.tensor_registry[ + self.tensors_mapping[(node.id, t_id, True)] ] ) - elif item == "" or item == "": - inputs.append(None) - elif item == "inf" or item == "-inf": - inputs.append(float(item)) - elif node.input_types[idx] == "Device" and "cuda" in item: - if self.cpu: - inputs.append("cpu") - else: - inputs.append(self.cuda) + + tensor_index += 1 + data_out.append(tensor_list) + + elif item == "" or item == "": + data_out.append(None) + elif item == "inf" or item == "-inf": + data_out.append(float(item)) + elif node.input_types[idx] == "Device" and "cuda" in item: + if self.cpu: + data_out.append("cpu") else: - inputs.append(item) - return inputs + data_out.append(self.cuda) + else: + data_out.append(item) + return data_out, "" except Exception as e: - print(f"Inputs error: {e} at node: {node.id}") + return None, f"Inputs error: {e}" def get_comm_outputs(self, node): try: @@ -1279,12 +1331,21 @@ def get_comm_outputs(self, node): outputs.append(item) return outputs except Exception as e: - print(f"Outputs error: {e} at node: {node.id}") + logger.info(f"Outputs error: {e} at node: {node.id}") def run_op(self, node, iter, cnt): # noqa: C901 if isinstance(node, commsArgs): - et_node = self.et.nodes[node.id] + if self.debug and iter >= self.numWarmupIters: + start_ns = time.time_ns() + before_execution = start_ns + self.commsBench.replaySingle(self.commsParams, node, cnt) + if self.debug and iter >= self.numWarmupIters: + after_execution = time.time_ns() + """ + TODO: before figure out how to allocate tensor for 
collectives, comms replay + will handle tensor allocation and deallocation. + et_node = self.et.nodes[node.id] for _, t_id, _ in get_input_tensors(et_node) + get_output_tensors(et_node): if self.tensor_with_device: t_id = tuple(list(t_id)[:5]) @@ -1294,34 +1355,27 @@ def run_op(self, node, iter, cnt): # noqa: C901 and replay_t_id not in self.instantiate ): del self.tensor_registry[replay_t_id] + """ + return True, "" else: - if node.name == "record_param_comms": - return + # This is a comms node and it is handled by commsBench.replaySingle + if node.name == "record_param_comms" or node.name.startswith("c10d::"): + return True, "" if self.debug and iter >= self.numWarmupIters: start_ns = time.time_ns() func, output_count = self.funcs[node.id] - if not func: - return - inputs = self.get_inputs(node) + inputs, msg = self.get_data(node, True) + if msg != "": + return False, msg + + # TODO: why need this hack? # Workaround to eliminate the "strides() called on undefined Tensor" error. if node.name == "aten::convolution_backward": inputs[-1] = [True, True, True] - # Workaround to handle tensor with same id but different data types (ads_cmf10x_single_iter_512_newest_eg.json). - if node.name == "aten::index_add_": - inputs[3] = inputs[3].to(torch.float64) - inputs[2] = inputs[2].to(torch.int) - if node.name == "aten::index_copy_": - if node.input_types[3] == "Tensor(double)": - inputs[3] = inputs[3].to(torch.float64) - if node.input_types[2] == "Tensor(long)": - inputs[2] = inputs[2].to(torch.int64) - if node.name == "aten::index_select": - inputs[2] = inputs[2].to(torch.int) - if self.debug and iter >= self.numWarmupIters: before_execution = time.time_ns() @@ -1342,65 +1396,52 @@ def run_op(self, node, iter, cnt): # noqa: C901 # Flatten any tensor lists # TODO: Simplify this if not tmp: - print(f"Not expect that {node.id} has no output.") + logger.info(f"Not expect that {node.id} has no output.") return for x in tmp: if isinstance(x, list) and isinstance(x[0], torch.Tensor): outputs.extend(x) elif isinstance(x, torch.Tensor): outputs.append(x) - except Exception as e: - print( - f"Run op exception Error: {e}, node id: {node.id}, func: {func}, inputs: {inputs}" - ) - exit(1) - if node.name == "aten::repeat_interleave": - current_len = node.input_shapes[0][0] - target_len = node.output_shapes[0][0] - if current_len < target_len: - dtype, _ = TORCH_DTYPES_RNG[ - node.output_types[0].lstrip("Tensor(").rstrip(")") - ] - tmp = ( - torch.zeros(target_len - current_len) - .to(dtype) - .cuda(self.device) - ) - outputs[0] = torch.cat((tmp, outputs[0])) + if self.debug and iter >= self.numWarmupIters: + after_execution = time.time_ns() - if self.debug and iter >= self.numWarmupIters: - after_execution = time.time_ns() - - for _, t_id, _ in get_input_tensors(node): - if self.tensor_with_device: - t_id = tuple(list(t_id)[:5]) - replay_t_id = self.tensors_mapping[(node.id, t_id, True)] - if ( - node.id >= self.replay_tensor_id_to_last_node_id_map[replay_t_id] - and replay_t_id not in self.instantiate - ): - del self.tensor_registry[replay_t_id] - - for (_, t_id, _), output in zip(get_output_tensors(node), outputs): - if self.tensor_with_device: - t_id = tuple(list(t_id)[:5]) - - if t_id in self.input_tensor_ids: - replay_t_id = self.tensors_mapping[(node.id, t_id, False)] + for _, t_id, _ in get_input_tensors(node): + if self.tensor_with_device: + t_id = tuple(list(t_id)[:5]) + replay_t_id = self.tensors_mapping[(node.id, t_id, True)] if ( - replay_t_id not in self.unchangeable_intermediate_tensors + node.id + 
>= self.replay_tensor_id_to_last_node_id_map[replay_t_id] and replay_t_id not in self.instantiate ): + del self.tensor_registry[replay_t_id] + + for (_, t_id, _), output in zip(get_output_tensors(node), outputs): + if self.tensor_with_device: + t_id = tuple(list(t_id)[:5]) + + if t_id in self.input_tensor_ids: + replay_t_id = self.tensors_mapping[(node.id, t_id, False)] if ( - node.id - < self.replay_tensor_id_to_last_node_id_map[replay_t_id] + replay_t_id not in self.unchangeable_intermediate_tensors + and replay_t_id not in self.instantiate ): - self.tensor_registry[replay_t_id] = output - else: - del output - else: - del output + if ( + node.id + < self.replay_tensor_id_to_last_node_id_map[replay_t_id] + ): + self.tensor_registry[replay_t_id] = output + else: + del output + else: + del output + + except Exception as e: + msg = f"Run op exception Error: {e}, node id: {node.id}, node name: {node.name}" + logger.error(msg) + return False, msg if self.profile_memory: self.op_allocated_mem[node] = ( @@ -1418,13 +1459,15 @@ def run_op(self, node, iter, cnt): # noqa: C901 ) self.exec_time.append(after_execution - before_execution) + return True, "" + def init_comms(self): comms_env_params = comms_utils.read_comms_env_vars() - print(comms_env_params, self.cuda) + logger.info(f"{comms_env_params}, {self.cuda}") self.commsBench = CommsReplayManager() self.commsBench.comp_replay_manager = self - self.commsBench.trace_file = self.trace_file + self.commsBench.trace_file = self.args.trace_path if "://" in self.trace_file: self.commsBench.use_remote_trace = True @@ -1451,6 +1494,48 @@ def init_comms(self): self.commsBench.initBench(self.commsParams, comms_args) self.commsBench.replayInit(self.commsParams) + def remove_op_with_runtime_error(self): + for cnt, node in enumerate(self.sorted_nodes): + success, msg = self.run_op(node, 0, cnt) + if success: + continue + + if ( + msg.find("RuntimeError: CUDA error") != -1 + or msg.find("torch.OutOfMemoryError") != -1 + ): + logger.info(f"Can not keep replaying due to {msg}") + self.add_skipped_nodes(node, msg) + break + + # The current node failed, if other nodes depend on the output of this + # node, these nodes will also fail. Allocate the output tensors so other nodes + # can keep playing. 
+ for data_type, t_id, shape in get_output_tensors(node): + if self.tensor_with_device: + t_id = tuple(list(t_id)[:5]) + if t_id not in self.input_tensor_ids: + continue + + if data_type == "Tensor(nullptr (uninitialized))": + t = None + else: + dtype, rng = TORCH_DTYPES_RNG[ + data_type.lstrip("Tensor(").rstrip(")") + ] + replay_t_id = self.tensors_mapping[(node.id, t_id, False)] + t = rng(shape).to(dtype) + if self.tensor_with_device: + if self.tensor_device[replay_t_id] != "cpu" and not self.cpu: + t.cuda(self.tensor_device[replay_t_id]) + else: + if not self.cpu: + t.cuda(self.device) + + self.tensor_registry[replay_t_id] = t + + self.add_skipped_nodes(node, msg) + def preprocess_graph(self): if not self.compute_only and not self.generator: self.init_comms() @@ -1465,7 +1550,7 @@ def preprocess_graph(self): find_subgraph = True break if not find_subgraph: - print(f"Cannot find subgraph with name {self.args.subgraph}.") + logger.info(f"Cannot find subgraph with name {self.args.subgraph}.") exit(1) else: root = nodes[1] # 1-base @@ -1480,7 +1565,7 @@ def preprocess_graph(self): for tensor in self.tensor_shapes: if len(self.tensor_shapes[tensor]) != 1: tensor_with_multiple_shape_count += len(self.tensor_shapes[tensor]) - print( + logger.info( f"Tensor count with same identifier but different shapes:{tensor_with_multiple_shape_count}, total tensor: {len(self.tensor_shapes)}" ) @@ -1497,19 +1582,40 @@ def benchTime(self): start_time = datetime.now() self.preprocess_graph() if self.generator: - return - print("Start execution: ") + return 0 + + if self.args.update_skip_node_file: + if os.environ.get("CUDA_LAUNCH_BLOCKING", "0") != "1": + logger.info( + "Please set CUDA_LAUNCH_BLOCKING=1 to get accurate skip node list." + ) + benchmark_result["execution finished"] = False + return benchmark_result + + self.remove_op_with_runtime_error() + actual_skip_node_count = len(self.actual_skip_nodes) + if actual_skip_node_count > self.initial_skip_node_count: + with open(self.args.skip_node_file, "w") as outfile: + json.dump(self.actual_skip_nodes, outfile, indent=4) + benchmark_result["execution finished"] = False + else: + benchmark_result["execution finished"] = True + return benchmark_result + + logger.info("Start execution... ") total_time = 0.0 event_1 = torch.cuda.Event(enable_timing=True) event_2 = torch.cuda.Event(enable_timing=True) - def run_op(event_1, event_2, iter): + def run_ops(event_1, event_2, iter): if not (self.compute_only or self.args.separate): self.commsBench.replayIter = iter event_1.record() for cnt, node in enumerate(self.sorted_nodes): - self.run_op(node, iter, cnt) + success, _ = self.run_op(node, iter, cnt) + if not success: + break event_2.record() if not (self.compute_only or self.args.separate): self.commsBench.resetComms() @@ -1528,7 +1634,7 @@ def run_op(event_1, event_2, iter): gc.collect() torch.cuda.empty_cache() - # Print real time qps every # iterations. + # log real time qps every # iterations. 
qps_print_interval = 10 prev_iter = self.numWarmupIters @@ -1538,6 +1644,7 @@ def run_iter(iter): nonlocal qps_print_interval nonlocal total_time + logger.info(f"iteration = {iter}") if self.et_profile: if iter == self.numWarmupIters: et.start() @@ -1547,23 +1654,15 @@ def run_iter(iter): if iter == prev_iter: start_ns = time.time_ns() if iter == prev_iter + qps_print_interval: - print( - "Current QPS: ", - int( - self.batch_size - * qps_print_interval - / ((time.time_ns() - start_ns) / 1000000000) - ), + logger.info( + f"Current QPS: {int(qps_print_interval / ((time.time_ns() - start_ns) / 1000000000))}" ) - print( - "Replay {} iterations time: {}ms".format( - qps_print_interval, - (time.time_ns() - start_ns) / 1000000.0, - ) + logger.info( + f"Replay {qps_print_interval} iterations time: {(time.time_ns() - start_ns) / 1000000.0} ms" ) prev_iter = iter start_ns = time.time_ns() - run_op(event_1, event_2, iter) + run_ops(event_1, event_2, iter) if iter >= self.numWarmupIters: total_time += event_1.elapsed_time(event_2) @@ -1574,12 +1673,12 @@ def run_iter(iter): if not (self.compute_only or self.args.separate): # since the comp replay will pick the 2nd iteration nodes, comm replay also needs - if len(self.operators_count) > 1: + if len(self.profile_step_node_ids) > 1: commNodes = [] for node in self.commsBench.comms_trace[: self.commsBench.max_msg_cnt]: if ( - node.id > self.operators_count[0] - and node.id < self.operators_count[1] + node.id > self.profile_step_node_ids[0] + and node.id < self.profile_step_node_ids[1] ): commNodes.append(node) else: @@ -1588,7 +1687,7 @@ def run_iter(iter): # commNodes is a list of commsArgs, since commsArgs also contains node id, it # can be mixed with sorted_nodes and re-sort them by id, this is an example after sort: # (Node 1000(comp op)) -> (Node 1001(name == "record_param_comms")) -> (commsArgs 1001) -> (Node 1002(comp op)) - # Function run_op(self, node, iter, cnt) will check the type of the input node, if it is a "Node" + # Function run_ops(self, node, iter, cnt) will check the type of the input node, if it is a "Node" # and its name is "record_param_comms", skip it; if it is a "commsArgs", use comm_replay to replay it # TODO: replace the "record_param_comms" node with commsArgs. 
self.sorted_nodes = self.sorted_nodes + commNodes @@ -1625,15 +1724,15 @@ def run_iter(iter): run_iter(iter) prof.step() benchmark_result["execution finished"] = True - print("Execution finished!") + logger.info("Execution finished!") else: for iter in range(self.numWarmupIters + self.numIters): run_iter(iter) benchmark_result["execution finished"] = True - print("Execution finished!") + logger.info("Execution finished!") if self.profile_memory: - print("Allocated GPU memory(B):") + logger.info("Allocated GPU memory(B):") for node in dict( sorted( self.op_allocated_mem.items(), @@ -1641,49 +1740,39 @@ def run_iter(iter): reverse=True, )[:100] ): - print(node.id, self.op_allocated_mem[node]) - print("Reserved GPU memory(B):") + logger.info(f"{node.id}, {self.op_allocated_mem[node]}") + logger.info("Reserved GPU memory(B):") for node in dict( sorted( self.op_reserved_mem.items(), key=lambda item: item[1], reverse=True )[:100] ): - print(node.id, self.op_reserved_mem[node]) - - print("Replay time per iteration: {:.2f} ms".format(total_time / self.numIters)) - - print( - "Operator coverage: {} / {} = {}".format( - len(self.sorted_nodes), - len(self.sorted_nodes) + self.actual_skip_nodes_cnt, - len(self.sorted_nodes) - / (len(self.sorted_nodes) + self.actual_skip_nodes_cnt), - ) + logger.info(f"{node.id}, {self.op_reserved_mem[node]}") + logger.info("Replay finished") + logger.info(f"Replay time per iteration: {total_time / self.numIters} ms") + logger.info( + f"Operator coverage: {len(self.sorted_nodes)} / {len(self.sorted_nodes) + self.n_skipped_nodes} = {len(self.sorted_nodes) / (len(self.sorted_nodes) + self.n_skipped_nodes)}" ) end_time = datetime.now() try: from param_bench.et_replay.fb.internals import generate_query_url except ImportError: - logging.info("FB internals not present") + logger.info("FB internals not present") else: generate_query_url(start_time, end_time, self.cuda_id) if self.debug: - print("Setup time: {}".format(sum(self.setup_time) / 1000000.0)) - print("Execution time: {}".format(sum(self.exec_time) / 1000000.0)) - - print("Input time: {}".format(self.input_total_time / 1000000.0)) - print("Output time: {}".format(self.output_total_time / 1000000.0)) - print("Lookup count: {}".format(self.lookup_cnt)) - print("Remap tensor list size: ", len(self.tensors_mapping)) - - print( - "Execution time: 50th:{}ms\t90th:{}ms\t95th:{}ms".format( - np.percentile(self.exec_time, 50) / 1000.0, - np.percentile(self.exec_time, 90) / 1000.0, - np.percentile(self.exec_time, 95) / 1000.0, - ) + logger.info(f"Setup time: {sum(self.setup_time) / 1000000.0}") + logger.info(f"Execution time: {sum(self.exec_time) / 1000000.0}") + + logger.info(f"Input time: {self.input_total_time / 1000000.0}") + logger.info(f"Output time: {self.output_total_time / 1000000.0}") + logger.info(f"Lookup count: {self.lookup_cnt}") + logger.info(f"Remap tensor list size: {len(self.tensors_mapping)}") + + logger.info( + f"Execution time: 50th:{np.percentile(self.exec_time, 50) / 1000.0}ms\t90th:{np.percentile(self.exec_time, 90) / 1000.0}ms\t95th:{np.percentile(self.exec_time, 95) / 1000.0}ms" ) if not (self.compute_only or self.args.separate): @@ -1723,12 +1812,6 @@ def readComputeArgs(self, check_args: bool = True): default=False, help="Capture execution trace for replay.", ) - parser.add_argument( - "--batch-size", - type=int, - default=1, - help="Batch size (number of queries) in one replay iteration, used to calculate QPS.", - ) parser.add_argument( "--cuda", type=int, @@ -1794,30 +1877,12 @@ def 
readComputeArgs(self, check_args: bool = True): default=0, help="Delayed time in ms for wait communication operators.", ) - parser.add_argument( - "--rows", - type=int, - default=1024, - help="Embedding tables rows.", - ) - parser.add_argument( - "--pooling-factor", - type=int, - default=1, - help="Pooling factor when looking up embedding tables.", - ) parser.add_argument( "--tf32", action="store_true", default=False, help="Enable tf32.", ) - parser.add_argument( - "--alpha", - type=float, - default=1, - help="alpha of fbgemm lookup indices zipf distribution.", - ) parser.add_argument( "--cpu", action="store_true", @@ -1830,6 +1895,19 @@ def readComputeArgs(self, check_args: bool = True): default=True, help="when a et_id is being replayed multiple times, setting this to false will use temsors from previous runs.", ) + parser.add_argument( + "--skip-node-file", + type=str, + required=False, + default="", + help="Path to the file that contains the list of nodes to skip.", + ) + parser.add_argument( + "--update-skip-node-file", + action="store_true", + default=False, + help="When true, the node skip list will be updated with the nodes that are skipped during replay.", + ) self.args, _ = parser.parse_known_args() # Check if both 'input' and 'trace_path' are not provided @@ -1848,7 +1926,13 @@ def main(): replay_manager = ExgrReplayManager() replay_manager.readComputeArgs() replay_manager.initBench() - replay_manager.benchTime() + benchmark_result = replay_manager.benchTime() + if benchmark_result["execution finished"]: + logger.info("Replay finished successfully.") + sys.exit(0) + else: + logger.info("Replay failed.") + sys.exit(-1) if __name__ == "__main__":