diff --git a/et_replay/execution_trace.py b/et_replay/execution_trace.py
index c71d0185..d8eec397 100644
--- a/et_replay/execution_trace.py
+++ b/et_replay/execution_trace.py
@@ -128,6 +128,7 @@ def __init__(
         comm_args: _CommArgs | None = None,
         input_strides: list[Any] | None = None,
         output_strides: list[Any] | None = None,
+        tensor_range: str | None = None,
     ):
         self.name: str = name
         self.parent_id: int = parent_id
@@ -156,6 +157,11 @@ def __init__(
         self.output_shapes: list[Any] = output_shapes
         self.output_strides: list[Any] | None = output_strides
         self.commArgs: _CommArgs | None = comm_args
+        self.tensor_range = json.loads(tensor_range) if tensor_range else None
+        if self.tensor_range is not None:
+            self.tensor_range = {
+                int(index): min_max for index, min_max in self.tensor_range.items()
+            }
 
     def get_inputs(self) -> Iterable:
         return zip(self.input_types, self.inputs, self.input_shapes)
@@ -293,6 +299,12 @@ def get_input_tensor_strides(self) -> list[tuple] | None:
         else:
             return self.get_tensor_strides(self.get_inputs(), self.input_strides)
 
+    def get_input_tensor_range(self, tensor_index) -> tuple | None:
+        if self.tensor_range is None or tensor_index not in self.tensor_range:
+            return None
+        else:
+            return self.tensor_range[tensor_index]
+
     def get_output_tensor_strides(self) -> list[tuple] | None:
         if self.output_strides is None:
             return None
@@ -404,6 +416,7 @@ def schema_chakra(self) -> tuple[int, int, int]:
             "tid": int,
             "kernel_backend": str,
             "kernel_file": str,
+            "tensor_range": str,
         }
 
     @classmethod
@@ -477,6 +490,7 @@ def _create_node_v1_0_2_chakra_0_0_4(pid, x: dict[str, Any]) -> Node:
             tid,
             kernel_backend,
             kernel_file,
+            _,
         ) = ExecutionTrace._read_attrs(x)
 
         comm_attrs = (
@@ -520,6 +534,7 @@ def _create_node_v1_1_1_chakra_0_0_4(pid, x: dict[str, Any]) -> Node:
             tid,
             kernel_backend,
             kernel_file,
+            tensor_range,
         ) = ExecutionTrace._read_attrs(x)
 
         comm_attrs = (
@@ -551,6 +566,7 @@ def _create_node_v1_1_1_chakra_0_0_4(pid, x: dict[str, Any]) -> Node:
             comm_attrs,
             x["inputs"]["strides"],
             x["outputs"]["strides"],
+            tensor_range,
         )
 
     def get_nodes(self, clean: bool = False):
diff --git a/et_replay/tools/et_replay.py b/et_replay/tools/et_replay.py
index ca7b4a5b..f680ffe1 100644
--- a/et_replay/tools/et_replay.py
+++ b/et_replay/tools/et_replay.py
@@ -662,6 +662,7 @@ def add_comm_tensor_registry(tensor_strides, tensors):
                             shape,
                             dtype,
                             strides,
+                            node.get_input_tensor_range(idx),
                         )
                         self.tensor_registry_permanent[replay_t_id] = tensor
                     except KeyError:
@@ -724,6 +725,7 @@ def allocate_comp_tensors(self, node):  # noqa: C901
                         shape,
                         dtype,
                         strides,
+                        node.get_input_tensor_range(idx),
                     )
                     self.tensor_registry_permanent[replay_t_id] = tensor
 
@@ -1178,11 +1180,40 @@ def get_tensor_from_storage(
         shape,
         data_type,
         strides,
+        tensor_range,
     ):
         assert storage_id in self.tensor_storage_map
         tensor_data = self.tensor_storage_map[storage_id]
         if device not in tensor_data[1]:
-            if data_type in [torch.half, torch.float32, torch.float64, torch.bfloat16]:
+            if (
+                data_type
+                in [
+                    torch.int8,
+                    torch.int16,
+                    torch.int32,
+                    torch.int64,
+                    torch.uint8,
+                    torch.uint16,
+                    torch.uint32,
+                    torch.uint64,
+                    torch.int,
+                    torch.long,
+                ]
+                and tensor_range is not None
+            ):
+                storage_tensor = torch.randint(
+                    tensor_range[0],
+                    tensor_range[1] + 1,
+                    (tensor_data[0] // elem_bytes,),
+                    dtype=data_type,
+                    device=device,
+                )
+            elif data_type in [
+                torch.half,
+                torch.float32,
+                torch.float64,
+                torch.bfloat16,
+            ]:
                 storage_tensor = torch.rand(
                     (tensor_data[0] // elem_bytes), dtype=data_type, device=device
                 )
@@ -1369,7 +1400,6 @@ def run_op(self, node, iter, cnt):  # noqa: C901
         inputs, msg = self.get_data(node, True)
         if msg != "":
             return False, msg
-        # TODO: why need this hack?
         # Workaround to eliminate the "strides() called on undefined Tensor" error.
         if node.name == "aten::convolution_backward":
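For context, a minimal sketch of the round trip this diff assumes (not part of the change; the attr_json value below is made up): the trace stores "tensor_range" as a JSON string mapping an input's index to its [min, max] values, Node.__init__ casts the JSON string keys back to int, and get_tensor_from_storage passes max + 1 to torch.randint because randint's upper bound is exclusive.

import json

import torch

# Hypothetical "tensor_range" attribute as it might appear in a trace:
# input 0 holds integer values in [0, 1023] (e.g. embedding indices).
attr_json = '{"0": [0, 1023]}'

# Mirrors Node.__init__: JSON object keys are strings, so cast them to int.
tensor_range = {int(idx): mm for idx, mm in json.loads(attr_json).items()}

# Mirrors get_tensor_from_storage: torch.randint's high bound is exclusive,
# hence "+ 1" so the traced maximum stays reachable during replay.
lo, hi = tensor_range[0]
t = torch.randint(lo, hi + 1, (16,), dtype=torch.int64)
assert int(t.min()) >= lo and int(t.max()) <= hi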