From a072dadac59888783874efb19ee7683167a22ee6 Mon Sep 17 00:00:00 2001 From: Sheng Fu Date: Wed, 11 Dec 2024 19:10:59 -0800 Subject: [PATCH] Record min/max of integral tensor in ET (#191) Summary: X-link: https://github.com/pytorch/pytorch/pull/143088 In et-replay, random data is used to run the operators. However, it does not work well for the op that uses index to access tensor. For example, embedding ops, which use the indices to look up the embedding table. If random data is used for these index ops, et-replay usually runs into invalid memory access issue. To fix it, ET provides an environment variable "ENABLE_PYTORCH_EXECUTION_TRACE_INTEGRAL_TENSOR_RANGE", if it is set, ET will capture the min/max value of the flattened integral tensor. Then in et_replay, the min/max is used to generate the random tensor within that range. It fixed invalid memory access issue. Differential Revision: D66666931 --- et_replay/execution_trace.py | 16 ++++++++++++++++ et_replay/tools/et_replay.py | 34 ++++++++++++++++++++++++++++++++-- 2 files changed, 48 insertions(+), 2 deletions(-) diff --git a/et_replay/execution_trace.py b/et_replay/execution_trace.py index c71d0185..d8eec397 100644 --- a/et_replay/execution_trace.py +++ b/et_replay/execution_trace.py @@ -128,6 +128,7 @@ def __init__( comm_args: _CommArgs | None = None, input_strides: list[Any] | None = None, output_strides: list[Any] | None = None, + tensor_range: str | None = None, ): self.name: str = name self.parent_id: int = parent_id @@ -156,6 +157,11 @@ def __init__( self.output_shapes: list[Any] = output_shapes self.output_strides: list[Any] | None = output_strides self.commArgs: _CommArgs | None = comm_args + self.tensor_range = json.loads(tensor_range) if tensor_range else None + if self.tensor_range is not None: + self.tensor_range = { + int(index): min_max for index, min_max in self.tensor_range.items() + } def get_inputs(self) -> Iterable: return zip(self.input_types, self.inputs, self.input_shapes) @@ -293,6 +299,12 @@ def get_input_tensor_strides(self) -> list[tuple] | None: else: return self.get_tensor_strides(self.get_inputs(), self.input_strides) + def get_input_tensor_range(self, tensor_index) -> tuple | None: + if self.tensor_range is None or tensor_index not in self.tensor_range: + return None + else: + return self.tensor_range[tensor_index] + def get_output_tensor_strides(self) -> list[tuple] | None: if self.output_strides is None: return None @@ -404,6 +416,7 @@ def schema_chakra(self) -> tuple[int, int, int]: "tid": int, "kernel_backend": str, "kernel_file": str, + "tensor_range": str, } @classmethod @@ -477,6 +490,7 @@ def _create_node_v1_0_2_chakra_0_0_4(pid, x: dict[str, Any]) -> Node: tid, kernel_backend, kernel_file, + _, ) = ExecutionTrace._read_attrs(x) comm_attrs = ( @@ -520,6 +534,7 @@ def _create_node_v1_1_1_chakra_0_0_4(pid, x: dict[str, Any]) -> Node: tid, kernel_backend, kernel_file, + tensor_range, ) = ExecutionTrace._read_attrs(x) comm_attrs = ( @@ -551,6 +566,7 @@ def _create_node_v1_1_1_chakra_0_0_4(pid, x: dict[str, Any]) -> Node: comm_attrs, x["inputs"]["strides"], x["outputs"]["strides"], + tensor_range, ) def get_nodes(self, clean: bool = False): diff --git a/et_replay/tools/et_replay.py b/et_replay/tools/et_replay.py index ca7b4a5b..f680ffe1 100644 --- a/et_replay/tools/et_replay.py +++ b/et_replay/tools/et_replay.py @@ -662,6 +662,7 @@ def add_comm_tensor_registry(tensor_strides, tensors): shape, dtype, strides, + node.get_input_tensor_range(idx), ) self.tensor_registry_permanent[replay_t_id] = tensor except KeyError: @@ -724,6 +725,7 @@ def allocate_comp_tensors(self, node): # noqa: C901 shape, dtype, strides, + node.get_input_tensor_range(idx), ) self.tensor_registry_permanent[replay_t_id] = tensor @@ -1178,11 +1180,40 @@ def get_tensor_from_storage( shape, data_type, strides, + tensor_range, ): assert storage_id in self.tensor_storage_map tensor_data = self.tensor_storage_map[storage_id] if device not in tensor_data[1]: - if data_type in [torch.half, torch.float32, torch.float64, torch.bfloat16]: + if ( + data_type + in [ + torch.int8, + torch.int16, + torch.int32, + torch.int64, + torch.uint8, + torch.uint16, + torch.uint32, + torch.uint64, + torch.int, + torch.long, + ] + and tensor_range is not None + ): + storage_tensor = torch.randint( + tensor_range[0], + tensor_range[1] + 1, + (tensor_data[0] // elem_bytes,), + dtype=data_type, + device=device, + ) + elif data_type in [ + torch.half, + torch.float32, + torch.float64, + torch.bfloat16, + ]: storage_tensor = torch.rand( (tensor_data[0] // elem_bytes), dtype=data_type, device=device ) @@ -1369,7 +1400,6 @@ def run_op(self, node, iter, cnt): # noqa: C901 inputs, msg = self.get_data(node, True) if msg != "": return False, msg - # TODO: why need this hack? # Workaround to eliminate the "strides() called on undefined Tensor" error. if node.name == "aten::convolution_backward":