Draft refactor of et replay (#110)
Summary:
X-link: facebookresearch/HolisticTraceAnalysis#131

Pull Request resolved: #110

Add a new directory structure for et_replay to give this code better encapsulation. comms and compute are left unchanged; only the et_replay files move. The new layout is shown below, followed by a sketch of the resulting import-path changes.

```
[bcoutinho@devgpu038.ftw6 ~/fbsource/fbcode/param_bench/et_replay (d4b11e786)]$ tree .
.
├── lib
│   ├── et_replay_utils.py
│   ├── execution_trace.py
│   └── utils.py
├── README.md
├── tests
│   ├── inputs
│   │   ├── 1.0.3-chakra.0.0.4
│   │   │   └── resnet_1gpu_et.json.gz
│   │   ├── 1.1.0-chakra.0.0.4
│   │   │   └── resnet_2gpu_et.json.gz
│   │   ├── dlrm_kineto.tar.gz
│   │   ├── dlrm_pytorch_et.tar.gz
│   │   ├── __init__.py
│   │   ├── linear_et.json.gz
│   │   ├── linear_kineto.json.gz
│   │   ├── resnet_et.json.gz
│   │   └── resnet_kineto.json.gz
│   └── test_execution_trace.py
└── tools
    ├── et_replay.py
    └── validate_trace.py
```
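For downstream code, the visible effect of the move is an import-path change. A minimal before/after sketch; the module paths and names are taken from the diff hunks below, while grouping them into one snippet is illustrative:

```python
# Old locations (pre-refactor):
# from param_bench.train.compute.python.tools.execution_trace import ExecutionTrace, NodeType
# from param_bench.train.compute.python.tools.utility import trace_handler
# from param_bench.train.compute.python.tools.et_replay_utils import build_torchscript_func

# New locations (this diff):
from param_bench.et_replay.lib.execution_trace import ExecutionTrace, NodeType
from param_bench.et_replay.lib.utils import trace_handler
from param_bench.et_replay.lib.et_replay_utils import build_torchscript_func
```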

Reviewed By: shengfukevin

Differential Revision: D56960365

fbshipit-source-id: d2ef172bc6c4629d78222357e616df9bddaec81e
briancoutinho authored and facebook-github-bot committed May 7, 2024
1 parent 2b4cf3e commit 9b1946f
Showing 16 changed files with 18 additions and 28 deletions.
File renamed without changes.
```diff
@@ -3,9 +3,9 @@
 import torch
 from fbgemm_gpu.split_table_batched_embeddings_ops import PoolingMode, WeightDecayMode
+from param_bench.et_replay.lib.execution_trace import NodeType

 from param_bench.train.compute.python.lib.pytorch.config_util import create_op_args
-from param_bench.train.compute.python.tools.execution_trace import NodeType

 from param_bench.train.compute.python.workloads.pytorch.split_table_batched_embeddings_ops import (
     SplitTableBatchedEmbeddingBagsCodegenInputDataGenerator,
@@ -473,7 +473,7 @@ def generate_prefix(label, skip_nodes, et_input, cuda, compute_only, tf32, rows)
 import torch
 from param_bench.train.comms.pt import commsTraceReplay
-from param_bench.train.compute.python.tools.et_replay_utils import (
+from param_bench.et_replay.lib.et_replay_utils import (
     build_fbgemm_func,
     build_torchscript_func,
     generate_fbgemm_tensors,
@@ -482,8 +482,8 @@ def generate_prefix(label, skip_nodes, et_input, cuda, compute_only, tf32, rows)
     is_qualified,
 )
-from param_bench.train.compute.python.tools.execution_trace import ExecutionTrace
-from param_bench.train.compute.python.tools.utility import trace_handler
+from param_bench.et_replay.lib.execution_trace import ExecutionTrace
+from param_bench.et_replay.lib.utils import trace_handler

 print("PyTorch version: ", torch.__version__)
```
```diff
@@ -997,8 +997,7 @@ def main():
     execution_data: TextIO
     execution_trace: ExecutionTrace = ExecutionTrace(json.load(execution_data))
     execution_trace.set_iterations(args.step_annotation)
-    # nocommit remove
-    execution_trace = execution_trace.clone_one_iteration(2)
+    # execution_trace = execution_trace.clone_one_iteration(2)

     if args.list_op:
         execution_trace.print_op_stats(args.detail, args.json)
```
```diff
@@ -6,7 +6,7 @@
 import uuid
 from typing import Any, Dict

-from param_bench.train.compute.python.tools.execution_trace import ExecutionTrace
+from param_bench.et_replay.lib.execution_trace import ExecutionTrace


 def get_tmp_trace_filename() -> str:
```
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
```diff
@@ -12,11 +12,7 @@

 import numpy as np
 import torch
-from param_bench.train.comms.pt import comms_utils, commsTraceReplay
-
-from param_bench.train.compute.python.lib import pytorch as lib_pytorch
-from param_bench.train.compute.python.lib.init_helper import load_modules
-from param_bench.train.compute.python.tools.et_replay_utils import (
+from param_bench.et_replay.lib.et_replay_utils import (
     build_fbgemm_func,
     build_torchscript_func,
     build_triton_func,
@@ -40,12 +36,13 @@
     TORCH_DTYPES_RNG_str,
 )

-from param_bench.train.compute.python.tools.execution_trace import (
-    ExecutionTrace,
-    NodeType,
-)
+from param_bench.et_replay.lib.execution_trace import ExecutionTrace, NodeType

+from param_bench.et_replay.lib.utils import trace_handler
+from param_bench.train.comms.pt import comms_utils, commsTraceReplay

-from param_bench.train.compute.python.tools.utility import trace_handler
+from param_bench.train.compute.python.lib import pytorch as lib_pytorch
+from param_bench.train.compute.python.lib.init_helper import load_modules
 from param_bench.train.compute.python.workloads import pytorch as workloads_pytorch
 from torch._inductor.codecache import AsyncCompile, TritonFuture

@@ -129,7 +126,7 @@ def __init__(self):
         self.label = ""

         try:
-            from param_bench.train.compute.python.tools.fb.internals import (
+            from param_bench.et_replay.lib.fb.internals import (
                 add_internal_label,
                 add_internal_parallel_nodes_parents,
                 add_internal_skip_nodes,
@@ -212,9 +209,7 @@ def initBench(self):
         # Input et trace should be explicitly specified after --input.
         if "://" in self.args.input:
             try:
-                from param_bench.train.compute.python.tools.fb.internals import (
-                    read_remote_trace,
-                )
+                from param_bench.et_replay.lib.fb.internals import read_remote_trace
             except ImportError:
                 logging.info("FB internals not present")
                 exit(1)
@@ -239,9 +234,7 @@
         # Different processes should read different traces based on global_rank_id.
         if "://" in self.args.trace_path:
             try:
-                from param_bench.train.compute.python.tools.fb.internals import (
-                    read_remote_trace,
-                )
+                from param_bench.et_replay.lib.fb.internals import read_remote_trace
             except ImportError:
                 logging.info("FB internals not present")
                 exit(1)
@@ -1507,9 +1500,7 @@ def benchTime(self):
         end_time = datetime.now()

         try:
-            from param_bench.train.compute.python.tools.fb.internals import (
-                generate_query_url,
-            )
+            from param_bench.et_replay.lib.fb.internals import generate_query_url
         except ImportError:
             logging.info("FB internals not present")
         else:
```
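One idiom recurs throughout the hunks above: Facebook-internal helpers now live under `param_bench.et_replay.lib.fb.internals` and are imported lazily inside `try`/`except ImportError`, so the open-source build only fails when a feature that needs them (such as a remote trace URI) is actually requested. A minimal sketch of that pattern, assuming only what the diff shows; the wrapper function and `read_remote_trace`'s return value are illustrative:

```python
import logging
import sys


def read_trace(trace_path: str) -> str:
    """Illustrative wrapper around the lazy-import idiom from the diff."""
    if "://" in trace_path:
        # Remote traces need the internal-only reader.
        try:
            from param_bench.et_replay.lib.fb.internals import read_remote_trace
        except ImportError:
            logging.info("FB internals not present")
            sys.exit(1)
        return read_remote_trace(trace_path)  # return type assumed here
    # Local traces work in the open-source build.
    with open(trace_path) as f:
        return f.read()
```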
```diff
@@ -9,7 +9,7 @@
 import gzip
 import json

-from .execution_trace import ExecutionTrace
+from param_bench.et_replay.lib.execution_trace import ExecutionTrace


 class TraceValidator:
```
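With the move, the validator and the tests both resolve `ExecutionTrace` through the absolute `param_bench.et_replay.lib` path rather than a relative import. A minimal loading sketch; the constructor call mirrors the hunks above and the file name comes from `tests/inputs` in the tree, while the relative path prefix is an assumption:

```python
import gzip
import json

from param_bench.et_replay.lib.execution_trace import ExecutionTrace

# Test inputs ship gzipped; decode to text, parse the JSON payload, and
# wrap it exactly as the diff does: ExecutionTrace(json.load(...)).
with gzip.open("tests/inputs/linear_et.json.gz", "rt") as f:
    et = ExecutionTrace(json.load(f))
```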
