Add verify_backward to enable testing backward ops #1383

Open · wants to merge 4 commits into base: main
143 changes: 141 additions & 2 deletions forge/forge/verify/verify.py
@@ -12,11 +12,18 @@
from loguru import logger
from forge.forgeglobal import align_up_tile
import torch
from torch import nn
import tensorflow as tf
from forge.tensor import to_pt_tensors
from forge.tensor import to_pt_tensors, TensorFromTrace

from ..tensor import Tensor, TensorShape, pad_pytorch_tensor_to_forge, narrow_forge_tensor_to_pytorch
from .config import DepricatedVerifyConfig, VerifyConfig, VerifyTensorMetadata, should_waive_gradient
from .config import (
DepricatedVerifyConfig,
VerifyConfig,
VerifyTensorMetadata,
should_waive_gradient,
AutomaticValueChecker,
Review comment (Member): Please just remove import AutomaticValueChecker if we don't use it here?

)
import forge._C.graph as pygraph
from forge.tools.run_net2pipe import net2pipe
from forge.compiled_graph_state import CompiledModel
@@ -42,6 +49,11 @@ def _generate_random_losses(outputs, is_forge):
return losses


def _squeeze_tensor(t: torch.Tensor):
t = t.squeeze()
return t


def _run_pytorch_backward(outputs, device, losses):
retain_graph = True
for i, o in enumerate(outputs):
@@ -263,6 +275,133 @@ def verify_golden(
assert False # Run ttnn golden


def verify_backward(
inputs: List[torch.Tensor],
output_grad: torch.Tensor,
framework_output: torch.Tensor,
compiled_output: torch.Tensor,
framework_model: torch.nn.Module,
compiled_model: CompiledModel,
original_model: torch.nn.Module = None,
Review comment (Member): Do we really need original_model?

verify_cfg: VerifyConfig = VerifyConfig(),
):
"""
Verifies the backward pass of a compiled model by comparing its gradients against those of a reference framework model.

Runs backward on both models with the same inputs and output gradient, then performs validation checks
based on the provided verification configuration. Checks can include gradient count matching,
dtype consistency, shape equivalence, and numeric value comparison.

Parameters:
inputs: List of tensor inputs
output_grad: Output gradient tensor
framework_output: Output tensor from the reference framework model
compiled_output: Output tensor from the compiled model
framework_model: Reference framework model
compiled_model: Compiled model to verify
original_model: Optional model whose parameter gradients are compared; defaults to framework_model
verify_cfg: Configuration object controlling which verification checks to perform
"""

if not verify_cfg.enabled:
logger.warning("Verification is disabled")
return

assert compiled_model.training(), "Compiled model must be compiled for training to run backward verification"

# Check if inputs are of the correct type
if not inputs:
raise ValueError("Input tensors must be provided")

req_grad = False
for input_tensor in inputs:
if not isinstance(input_tensor, torch.Tensor):
raise TypeError(f"Input tensor must be of type {torch.Tensor}, but got {type(input_tensor)}")
req_grad |= input_tensor.requires_grad

if not req_grad:
raise ValueError("One of the input tensors must require gradient")

if not isinstance(output_grad, torch.Tensor):
raise TypeError(f"Output gradient tensor must be of type {torch.Tensor}, but got {type(output_grad)}")

if not isinstance(framework_output, torch.Tensor):
raise TypeError(f"Framework output tensor must be of type {torch.Tensor}, but got {type(framework_output)}")
if not isinstance(compiled_output, torch.Tensor):
raise TypeError(f"Compiled output tensor must be of type {torch.Tensor}, but got {type(compiled_output)}")

if not isinstance(framework_model, torch.nn.Module):
raise TypeError(f"Framework model must be of type {torch.nn.Module}, but got {type(framework_model)}")
Review comment (Member), on lines +332 to +333: So for now we just support backward verification for torch?


if not isinstance(compiled_model, verify_cfg.compiled_model_types):
raise TypeError(
f"Compiled model must be of type {verify_cfg.compiled_model_types}, but got {type(compiled_model)}"
)

if original_model is None:
original_model = framework_model
elif not isinstance(original_model, torch.nn.Module):
raise TypeError(f"Original model must be of type {torch.nn.Module}, but got {type(original_model)}")

# 1st step: run backward pass for the networks and get gradients
[input.grad.zero_() for input in inputs if input.grad is not None]
original_model.zero_grad()
compiled_model.gradient_inputs = [output_grad]
co_gradient_outputs = compiled_model.backward()
co_grads = [
grad
for grad, name in zip(co_gradient_outputs, compiled_model.bwd_compiled_graph_state.ordered_output_names)
if not name.startswith("grad_acc_") # This is added to gradient of the parameter node
]
# HACK: This will fail if the first argument is not used first, second argument is not used second, etc.
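# (It looks like activation gradients come out of the backward graph in reverse order
# relative to the forward inputs, so reversing restores input order; mapping gradients
# by ordered_output_names would likely be more robust than relying on this ordering.)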
co_grads = list(reversed(co_grads))
co_grads += [param.grad.clone() for param in original_model.parameters() if param.requires_grad]

# Run backward on framework model
[input.grad.zero_() for input in inputs if input.grad is not None]
framework_model.zero_grad()
framework_output.backward(gradient=output_grad)
fw_grads = [input.grad.clone() for input in inputs if input.requires_grad]
fw_grads += [param.grad for param in framework_model.parameters() if param.requires_grad]

# 2nd step: apply preprocessing (push tensors to cpu, perform any reshape if necessary,
# cast from tensorflow tensors to pytorch tensors if needed)
fw_grads = to_pt_tensors(fw_grads)
co_grads = [co.to("cpu") for co in co_grads]

assert all(isinstance(co, torch.Tensor) for co in co_grads), "Compiled model gradients are not a list of torch.Tensor"

# 3rd step: verifications of outputs
# - size check
# - dtype check
# - shape check
# - compare with golden
if verify_cfg.verify_size:
if len(fw_grads) != len(co_grads):
raise ValueError(
f"Number of gradients from framework model and compiled model do not match: framework model has {len(fw_grads)} outputs, compiled model has {len(co_grads)} outputs"
)

for fw, co in zip(fw_grads, co_grads):
# Squeeze the tensors to remove any extra dimensions
# NOTE: narrow_forge_tensor_to_pytorch appears to do the same thing, but it seems to be unused
fw = _squeeze_tensor(fw)
co = _squeeze_tensor(co)
Review comment (Member), on lines +387 to +388: Can we just use regular squeeze instead?


if verify_cfg.verify_dtype:
if fw.dtype != co.dtype:
raise TypeError(f"Dtype mismatch: framework_model.dtype={fw.dtype}, compiled_model.dtype={co.dtype}")

if verify_cfg.verify_shape:
if fw.shape != co.shape:
raise TypeError(
f"Shape mismatch: framework_model.shape={fw.shape}, compiled_model.shape={co.shape}"
)

if verify_cfg.verify_values:
verify_cfg.value_checker.check(fw, co)


def verify(
inputs: List[Union[torch.Tensor, tf.Tensor, tf.Variable]],
framework_model: Union[torch.nn.Module, tf.Module, tf.keras.Model],
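For context, the following is a minimal usage sketch of the new verify_backward API, modeled on the test added later in this PR; the Linear module, its shapes, and the PCC threshold are illustrative assumptions, not part of the change.

import torch
import forge
from forge.verify.config import VerifyConfig
from forge.verify.value_checkers import AutomaticValueChecker
from forge.verify.verify import verify, verify_backward

# Hypothetical single-op model, used only to illustrate the call sequence.
framework_model = torch.nn.Linear(8, 4, bias=False)
framework_model.train()
inputs = [torch.rand(2, 8, requires_grad=True)]

# Compile for training so the backward graph is available to verify_backward.
compiled_model = forge.compile(framework_model, sample_inputs=inputs, training=True)

# Forward verification returns the framework and compiled outputs.
fw_out, co_out = verify(inputs, framework_model, compiled_model)

# Simulate the gradient flowing back from a loss and compare gradients on both sides.
output_grad = torch.rand_like(fw_out[0])
verify_backward(
    inputs,
    output_grad,
    fw_out[0],
    co_out[0],
    framework_model,
    compiled_model,
    verify_cfg=VerifyConfig(value_checker=AutomaticValueChecker(pcc=0.99)),
)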
63 changes: 47 additions & 16 deletions forge/test/mlir/llama/tests/test_specific_ops_llama32.py
@@ -9,7 +9,7 @@
import forge
from forge.verify.config import VerifyConfig
from forge.verify.value_checkers import AutomaticValueChecker
from forge.verify.verify import verify
from forge.verify.verify import verify, verify_backward


@pytest.mark.parametrize(
@@ -128,21 +128,30 @@ def forward(self, x):
verify(inputs, framework_model, compiled_model)


# @pytest.mark.parametrize("train", [False, True])
@pytest.mark.parametrize(
"shapes",
"shapes, train",
[
((11, 2048), (2048, 2048)),
((1, 32, 1), (1, 1, 11)),
((11, 2048), (2048, 512)),
((32, 11, 64), (32, 64, 11)),
((32, 11, 11), (32, 11, 64)),
((11, 2048), (2048, 8192)),
((1, 11, 8192), (8192, 2048)),
((1, 11, 2048), (2048, 128256)),
(((11, 2048), (2048, 2048)), False),
(((1, 32, 1), (1, 1, 11)), False),
(((11, 2048), (2048, 512)), False),
(((32, 11, 64), (32, 64, 11)), False),
(((32, 11, 11), (32, 11, 64)), False),
(((11, 2048), (2048, 8192)), False),
(((1, 11, 8192), (8192, 2048)), False),
(((1, 11, 2048), (2048, 128256)), False),
(((11, 2048), (2048, 2048)), True),
(((1, 32, 1), (1, 1, 11)), True),
(((11, 2048), (2048, 512)), True),
(((32, 11, 64), (32, 64, 11)), True),
(((32, 11, 11), (32, 11, 64)), True),
(((11, 2048), (2048, 8192)), True),
(((1, 11, 8192), (8192, 2048)), True),
pytest.param(((1, 11, 2048), (2048, 128256)), True, marks=pytest.mark.xfail(reason="Low PCC")),
],
)
@pytest.mark.push
def test_matmul(shapes):
def test_matmul(shapes, train):
shape1, shape2 = shapes

class Matmul(nn.Module):
@@ -153,14 +162,36 @@ def forward(self, x, y):
return torch.matmul(x, y)

inputs = [
torch.rand(shape1),
torch.rand(shape2),
torch.rand(shape1, requires_grad=train),
torch.rand(shape2, requires_grad=train),
]

framework_model = Matmul()
compiled_model = forge.compile(framework_model, sample_inputs=inputs)

verify(inputs, framework_model, compiled_model, VerifyConfig(value_checker=AutomaticValueChecker(pcc=0.95)))
framework_model.eval() if not train else framework_model.train()

# NOTE: We probably need two framework models with the same state_dict to compare the outputs
# But for now it works without that for some reason?
# model_for_compile = Matmul()
# model_for_compile.eval() if not training else model_for_compile.train()
# model_for_compile.load_state_dict(framework_model.state_dict())
Review comment (Member), on lines +172 to +176: Discussed offline, we probably don't need 2 framework models.

compiled_model = forge.compile(framework_model, sample_inputs=inputs, training=train)

fw_out, co_out = verify(
inputs, framework_model, compiled_model, VerifyConfig(value_checker=AutomaticValueChecker(pcc=0.95))
)
if train:
# Simulate the backward pass of the loss
output_grad = torch.rand_like(fw_out[0])

verify_backward(
inputs,
output_grad,
fw_out[0],
co_out[0],
framework_model,
compiled_model,
verify_cfg=VerifyConfig(value_checker=AutomaticValueChecker(pcc=0.95)),
)


@pytest.mark.parametrize(
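The backward checks honor the same VerifyConfig toggles as the forward path (gradient count, dtype, shape, and values). Below is a small sketch of relaxing individual checks; it assumes VerifyConfig exposes these fields as constructor keyword arguments (verify_backward only reads them as attributes, so the exact constructor signature is an assumption).

# Assumed constructor keywords; verify_backward reads these as attributes on the config.
relaxed_cfg = VerifyConfig(
    verify_dtype=False,  # skip dtype comparison if compiled gradients come back in a different dtype
    value_checker=AutomaticValueChecker(pcc=0.95),  # PCC-based value comparison, as in the matmul test
)

verify_backward(
    inputs,
    output_grad,
    fw_out[0],
    co_out[0],
    framework_model,
    compiled_model,
    verify_cfg=relaxed_cfg,
)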