Skip to content

Commit

Permalink
Log guard latency (#145132)
Browse files Browse the repository at this point in the history
Summary:
X-link: pytorch/pytorch#145132
Approved by: https://github.com/ezyang
ghstack dependencies: #145509

Reviewed By: ZainRizvi

Differential Revision: D68685480

fbshipit-source-id: fe35b627407e32a580f78027562b092083043d99
  • Loading branch information
anijain2305 authored and facebook-github-bot committed Jan 27, 2025
1 parent 4b95e3a commit be908b3
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 23 deletions.
48 changes: 25 additions & 23 deletions userbenchmark/dynamo/dynamobench/_dynamo/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,30 +181,32 @@ def insert_nops(instructions: list[Any], code_options: Any) -> None:
instructions.insert(0, create_instruction("NOP"))
instructions.insert(0, create_instruction("NOP"))

if is_generator(frame.f_code):
return None

debug_checks(frame.f_code)
code = transform_code_object(frame.f_code, insert_nops)
graph = OutputGraph(
code_options={},
compiler_fn=None,
root_tx=None,
export=False,
export_constraints=None,
frame_state={"_id": 0},
# TODO: shouldn't this be f_locals/f_globals from frame?
local_scope=locals(),
global_scope=globals(),
f_code=frame.f_code,
torch_function_mode_stack=[],
)
metrics_context = torch._dynamo.utils.get_metrics_context()
with torch._dynamo.utils.dynamo_timed("debug_insert_nops"), metrics_context:
if is_generator(frame.f_code):
return None

debug_checks(frame.f_code)
code = transform_code_object(frame.f_code, insert_nops)
graph = OutputGraph(
code_options={},
compiler_fn=None,
root_tx=None,
export=False,
export_constraints=None,
frame_state={"_id": 0},
# TODO: shouldn't this be f_locals/f_globals from frame?
local_scope=locals(),
global_scope=globals(),
f_code=frame.f_code,
torch_function_mode_stack=[],
)

return GuardedCode(
code,
CheckFunctionManager(frame.f_code, graph).guard_manager, # type: ignore[arg-type]
CompileId(frame_id=0, frame_compile_id=0),
)
return GuardedCode(
code,
CheckFunctionManager(frame.f_code, graph).guard_manager, # type: ignore[arg-type]
CompileId(frame_id=0, frame_compile_id=0),
)


class CompileCounter:
Expand Down
1 change: 1 addition & 0 deletions userbenchmark/dynamo/dynamobench/_dynamo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1186,6 +1186,7 @@ class CompilationMetrics:
tensorify_float_attempt: Optional[bool] = None
tensorify_float_success: Optional[bool] = None
tensorify_float_failure: Optional[set[str]] = None
guard_latency_us: Optional[float] = None

@classmethod
def create(cls, metrics: dict[str, Any]):
Expand Down

0 comments on commit be908b3

Please sign in to comment.