Skip to content

Commit

Permalink
Fix the execution phase and stage recording issue and replace compare…
Browse files Browse the repository at this point in the history
…_with_golden with verify function
  • Loading branch information
chandrasekaranpradeep committed Mar 7, 2025
1 parent 754b717 commit 898dd24
Show file tree
Hide file tree
Showing 65 changed files with 255 additions and 282 deletions.
3 changes: 0 additions & 3 deletions forge/forge/compiled_graph_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from forge._C import run_mlir_compiler_to_cpp
from forge.tensor import Tensor, get_post_const_eval_tensors, to_pt_tensors, cast_unsupported_torch_dtype, AnyTensor
from forge.module import Module, PyTorchModule, AnyModule
from forge.execution_tracker import ExecutionPhase, record_execution_phase_and_stage


class CompileResults:
Expand Down Expand Up @@ -285,8 +284,6 @@ def __call__(self, *inputs: AnyTensor) -> List[torch.Tensor]:
# would capture the idx by reference, and all the lambdas would have the same idx value.
output.register_hook(lambda grad, idx=idx: self.tie_grad_fn(idx, grad))

record_execution_phase_and_stage(ExecutionPhase.EXECUTED_TTNN)

return model_outputs

def forward(self, *inputs: AnyTensor) -> List[torch.Tensor]:
Expand Down
8 changes: 1 addition & 7 deletions forge/forge/execution_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,13 +108,7 @@ class ExecutionStage(Enum):

# Stage under PASSED

# 1. Compared the golden and framework output (compare_with_golden in forge/forge/verify/compare.py).
COMPARE_WITH_GOLDEN = (
ExecutionPhase.PASSED,
"COMPARE_WITH_GOLDEN",
)

# 2. Performed verification (verify function in forge/forge/verify/verify.py).
# 1. Performed verification (verify function in forge/forge/verify/verify.py).
VERIFICATON = (ExecutionPhase.PASSED, "VERIFICATON")

def __init__(self, phase, stage_name):
Expand Down
1 change: 0 additions & 1 deletion forge/forge/python_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,6 @@ def write_header(self, include_pytest_imports=False):
if include_pytest_imports:
self.wl("")
self.wl("from forge import Tensor, compile")
self.wl("from forge.verify.compare import compare_with_golden")
self.wl("from forge.verify.verify import verify")
self.wl("from forge.verify.value_checkers import AutomaticValueChecker")
self.wl("from forge.verify.config import VerifyConfig")
Expand Down
8 changes: 0 additions & 8 deletions forge/forge/verify/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from scipy.spatial import distance

from forge.tensor import narrow_forge_tensor_to_pytorch
from forge.execution_tracker import ExecutionPhase, ExecutionStage, record_execution_phase_and_stage

# Compares golden and calculated tensors. Using allclose for scalar values, rogerstanimoto for bool tensors, pcc otherwise
def compare_with_golden(
Expand Down Expand Up @@ -45,13 +44,6 @@ def compare_with_golden(

result = all_close

# The verify function (in forge/forge/verify/verify.py) calls compare_with_golden for each output.
# If any call returns False, it signals a tensor mismatch, so we revert to the previous execution phase (EXECUTED_TTNN)
if result:
record_execution_phase_and_stage(ExecutionStage.COMPARE_WITH_GOLDEN)
else:
record_execution_phase_and_stage(ExecutionPhase.EXECUTED_TTNN)

return result


Expand Down
4 changes: 3 additions & 1 deletion forge/forge/verify/verify.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from forge.tools.run_net2pipe import net2pipe
from forge.compiled_graph_state import CompiledModel
from forge.verify.compare import compare_tensor_to_golden
from forge.execution_tracker import ExecutionStage, record_execution_phase_and_stage
from forge.execution_tracker import ExecutionPhase, ExecutionStage, record_execution_phase_and_stage


def _generate_random_losses(outputs, is_forge):
Expand Down Expand Up @@ -310,7 +310,9 @@ def verify(
# 1st step: run forward pass for the networks
fw_out = framework_model(*inputs)

record_execution_phase_and_stage(ExecutionPhase.COMPILE_MLIR)
co_out = compiled_model(*inputs)
record_execution_phase_and_stage(ExecutionPhase.EXECUTED_TTNN)

# 2nd step: apply preprocessing (push tensors to cpu, perform any reshape if necessary,
# cast from tensorflow tensors to pytorch tensors if needed)
Expand Down
5 changes: 2 additions & 3 deletions forge/test/benchmark/benchmark/models/mnist_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import torch
from torch import nn
import forge
from forge.verify.compare import compare_with_golden_pcc
from forge.verify.verify import verify

# Common constants
GIT_REPO_NAME = "tenstorrent/tt-forge-fe"
Expand Down Expand Up @@ -108,8 +108,7 @@ def test_mnist_linear(
co_out = compiled_model(*inputs)
end = time.time()

co_out = [co.to("cpu") for co in co_out]
assert [compare_with_golden_pcc(golden=fo, calculated=co) for fo, co in zip(fw_out, co_out)]
verify(inputs, framework_model, compiled_model)

short_hash = subprocess.check_output(["git", "rev-parse", "--short", "HEAD"]).decode("ascii").strip()
date = (
Expand Down
Loading

0 comments on commit 898dd24

Please sign in to comment.