diff --git a/et_replay/tools/et_replay.py b/et_replay/tools/et_replay.py index e857596f..ee35830a 100644 --- a/et_replay/tools/et_replay.py +++ b/et_replay/tools/et_replay.py @@ -12,7 +12,7 @@ import numpy as np import torch -from param_bench.et_replay.lib.et_replay_utils import ( +from ..lib.et_replay_utils import ( build_fbgemm_func, build_torchscript_func, build_triton_func, @@ -36,7 +36,7 @@ TORCH_DTYPES_RNG_str, ) -from param_bench.et_replay.lib.execution_trace import ExecutionTrace, NodeType +from et_replay.lib.execution_trace import ExecutionTrace, NodeType from param_bench.et_replay.lib.utils import trace_handler from param_bench.train.comms.pt import comms_utils, commsTraceReplay diff --git a/et_replay/tools/validate_trace.py b/et_replay/tools/validate_trace.py index f1b44b01..46af10bb 100644 --- a/et_replay/tools/validate_trace.py +++ b/et_replay/tools/validate_trace.py @@ -9,7 +9,7 @@ import gzip import json -from param_bench.et_replay.lib.execution_trace import ExecutionTrace +from et_replay.lib.execution_trace import ExecutionTrace class TraceValidator: