Record min/max of integral tensor in ET (facebookresearch#191)
Summary:

X-link: pytorch/pytorch#143088

In et-replay, random data is used to run the operators. However, this does not work well for ops that use indices to access tensors, such as embedding ops, which use the indices to look up the embedding table. If random data is used for these index ops, et-replay usually runs into invalid memory access errors.

To fix this, ET provides an environment variable, "ENABLE_PYTORCH_EXECUTION_TRACE_INTEGRAL_TENSOR_RANGE". If it is set, ET captures the min/max values of each flattened integral tensor. In et_replay, the min/max is then used to generate random tensors within that range, which eliminates the invalid memory accesses.
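For illustration, here is a minimal sketch of the two halves of the mechanism (not part of the commit; the table size, shapes, and the [0, 9999] range are hypothetical, while the environment variable name and the max + 1 exclusive bound match the code below):

```python
import os

import torch

# Capture side: opt in before collecting the execution trace (assumption:
# the variable only needs to be set in the process that records the trace).
os.environ["ENABLE_PYTORCH_EXECUTION_TRACE_INTEGRAL_TENSOR_RANGE"] = "1"

# Replay side: suppose the trace recorded min/max [0, 9999] for an
# embedding's index input. Sampling from that closed range (max + 1 because
# torch.randint's upper bound is exclusive) keeps every lookup in bounds.
min_val, max_val = 0, 9999
indices = torch.randint(min_val, max_val + 1, (128,), dtype=torch.int64)

embedding_table = torch.rand(10000, 64)
out = torch.nn.functional.embedding(indices, embedding_table)
```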

Reviewed By: sanrise

Differential Revision: D66666931
shengfukevin authored and facebook-github-bot committed Dec 17, 2024
1 parent ca00ca3 commit 01ad990
Showing 2 changed files with 48 additions and 2 deletions.
16 changes: 16 additions & 0 deletions et_replay/execution_trace.py
@@ -128,6 +128,7 @@ def __init__(
         comm_args: _CommArgs | None = None,
         input_strides: list[Any] | None = None,
         output_strides: list[Any] | None = None,
+        tensor_range: str | None = None,
     ):
         self.name: str = name
         self.parent_id: int = parent_id
@@ -156,6 +157,11 @@ def __init__(
         self.output_shapes: list[Any] = output_shapes
         self.output_strides: list[Any] | None = output_strides
         self.commArgs: _CommArgs | None = comm_args
+        self.tensor_range = json.loads(tensor_range) if tensor_range else None
+        if self.tensor_range is not None:
+            self.tensor_range = {
+                int(index): min_max for index, min_max in self.tensor_range.items()
+            }

     def get_inputs(self) -> Iterable:
         return zip(self.input_types, self.inputs, self.input_shapes)
@@ -293,6 +299,12 @@ def get_input_tensor_strides(self) -> list[tuple] | None:
         else:
             return self.get_tensor_strides(self.get_inputs(), self.input_strides)

+    def get_input_tensor_range(self, tensor_index) -> tuple | None:
+        if self.tensor_range is None or tensor_index not in self.tensor_range:
+            return None
+        else:
+            return self.tensor_range[tensor_index]
+
     def get_output_tensor_strides(self) -> list[tuple] | None:
         if self.output_strides is None:
             return None
@@ -404,6 +416,7 @@ def schema_chakra(self) -> tuple[int, int, int]:
         "tid": int,
         "kernel_backend": str,
         "kernel_file": str,
+        "tensor_range": str,
     }

     @classmethod
@@ -477,6 +490,7 @@ def _create_node_v1_0_2_chakra_0_0_4(pid, x: dict[str, Any]) -> Node:
             tid,
             kernel_backend,
             kernel_file,
+            _,
         ) = ExecutionTrace._read_attrs(x)

         comm_attrs = (
@@ -520,6 +534,7 @@ def _create_node_v1_1_1_chakra_0_0_4(pid, x: dict[str, Any]) -> Node:
             tid,
             kernel_backend,
             kernel_file,
+            tensor_range,
         ) = ExecutionTrace._read_attrs(x)

         comm_attrs = (
@@ -551,6 +566,7 @@ def _create_node_v1_1_1_chakra_0_0_4(pid, x: dict[str, Any]) -> Node:
             comm_attrs,
             x["inputs"]["strides"],
             x["outputs"]["strides"],
+            tensor_range,
         )

     def get_nodes(self, clean: bool = False):
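As a standalone sanity check of the parsing above (a sketch, not part of the commit; the attribute string is a made-up example), the JSON's string keys become integer input indices, and inputs without a recorded range yield None:

```python
import json

# Hypothetical "tensor_range" attribute value: keys are input-tensor indices,
# values are [min, max] over the flattened integral tensor.
raw = '{"1": [0, 9999]}'
tensor_range = {int(index): min_max for index, min_max in json.loads(raw).items()}

def get_input_tensor_range(tensor_index):
    # Mirrors Node.get_input_tensor_range: None when no range was recorded.
    if tensor_range is None or tensor_index not in tensor_range:
        return None
    return tensor_range[tensor_index]

print(get_input_tensor_range(1))  # [0, 9999]
print(get_input_tensor_range(0))  # None
```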
34 changes: 32 additions & 2 deletions et_replay/tools/et_replay.py
@@ -662,6 +662,7 @@ def add_comm_tensor_registry(tensor_strides, tensors):
                         shape,
                         dtype,
                         strides,
+                        node.get_input_tensor_range(idx),
                     )
                     self.tensor_registry_permanent[replay_t_id] = tensor
                 except KeyError:
@@ -724,6 +725,7 @@ def allocate_comp_tensors(self, node):  # noqa: C901
                     shape,
                     dtype,
                     strides,
+                    node.get_input_tensor_range(idx),
                 )
                 self.tensor_registry_permanent[replay_t_id] = tensor

@@ -1178,11 +1180,40 @@ def get_tensor_from_storage(
         shape,
         data_type,
         strides,
+        tensor_range,
     ):
         assert storage_id in self.tensor_storage_map
         tensor_data = self.tensor_storage_map[storage_id]
         if device not in tensor_data[1]:
-            if data_type in [torch.half, torch.float32, torch.float64, torch.bfloat16]:
+            if (
+                data_type
+                in [
+                    torch.int8,
+                    torch.int16,
+                    torch.int32,
+                    torch.int64,
+                    torch.uint8,
+                    torch.uint16,
+                    torch.uint32,
+                    torch.uint64,
+                    torch.int,
+                    torch.long,
+                ]
+                and tensor_range is not None
+            ):
+                storage_tensor = torch.randint(
+                    tensor_range[0],
+                    tensor_range[1] + 1,
+                    (tensor_data[0] // elem_bytes,),
+                    dtype=data_type,
+                    device=device,
+                )
+            elif data_type in [
+                torch.half,
+                torch.float32,
+                torch.float64,
+                torch.bfloat16,
+            ]:
                 storage_tensor = torch.rand(
                     (tensor_data[0] // elem_bytes), dtype=data_type, device=device
                 )
@@ -1369,7 +1400,6 @@ def run_op(self, node, iter, cnt):  # noqa: C901
         inputs, msg = self.get_data(node, True)
         if msg != "":
             return False, msg
-
         # TODO: why need this hack?
         # Workaround to eliminate the "strides() called on undefined Tensor" error.
         if node.name == "aten::convolution_backward":
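Putting the replay half together, here is a condensed, standalone sketch of the branch added to get_tensor_from_storage above (hypothetical sizes; make_storage_tensor, num_bytes, and elem_bytes are stand-ins for the real bookkeeping, and the integral dtype list is trimmed for brevity):

```python
import torch

def make_storage_tensor(data_type, num_bytes, elem_bytes, tensor_range, device="cpu"):
    # Condensed version of the branch above: integral dtypes with a recorded
    # [min, max] get in-range random integers; floating dtypes keep torch.rand.
    numel = num_bytes // elem_bytes
    integral = [torch.int8, torch.int16, torch.int32, torch.int64,
                torch.uint8, torch.int, torch.long]
    if data_type in integral and tensor_range is not None:
        return torch.randint(tensor_range[0], tensor_range[1] + 1,
                             (numel,), dtype=data_type, device=device)
    return torch.rand((numel,), dtype=data_type, device=device)

# 1024 bytes of int64 storage, replayed with values drawn from [0, 9999].
idx = make_storage_tensor(torch.int64, 1024, 8, (0, 9999))
```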
