Add verify_backward to enable testing backward ops #1383
base: main
@@ -12,11 +12,18 @@
 from loguru import logger
 from forge.forgeglobal import align_up_tile
 import torch
 from torch import nn
 import tensorflow as tf
-from forge.tensor import to_pt_tensors
+from forge.tensor import to_pt_tensors, TensorFromTrace

 from ..tensor import Tensor, TensorShape, pad_pytorch_tensor_to_forge, narrow_forge_tensor_to_pytorch
-from .config import DepricatedVerifyConfig, VerifyConfig, VerifyTensorMetadata, should_waive_gradient
+from .config import (
+    DepricatedVerifyConfig,
+    VerifyConfig,
+    VerifyTensorMetadata,
+    should_waive_gradient,
+    AutomaticValueChecker,
+)
 import forge._C.graph as pygraph
 from forge.tools.run_net2pipe import net2pipe
 from forge.compiled_graph_state import CompiledModel
@@ -42,6 +49,11 @@ def _generate_random_losses(outputs, is_forge):
     return losses


+def _squeeze_tensor(t: torch.Tensor):
+    t = t.squeeze()
+    return t
+
+
 def _run_pytorch_backward(outputs, device, losses):
     retain_graph = True
     for i, o in enumerate(outputs):
@@ -263,6 +275,133 @@ def verify_golden(
     assert False  # Run ttnn golden


+def verify_backward(
+    inputs: List[torch.Tensor],
+    output_grad: torch.Tensor,
+    framework_output: torch.Tensor,
+    compiled_output: torch.Tensor,
+    framework_model: torch.nn.Module,
+    compiled_model: CompiledModel,
+    original_model: torch.nn.Module = None,

Review comment: Do we really need original_model?

+    verify_cfg: VerifyConfig = VerifyConfig(),
+):
""" | ||
Performs verification of a compiled model by comparing its outputs against a reference framework model. | ||
|
||
Runs backward on both models with the same inputs and performs various validation checks | ||
based on the provided verification configuration. Checks can include output size matching, | ||
dtype consistency, shape equivalence, and numeric value comparison. | ||
|
||
Parameters: | ||
inputs: List of tensor inputs | ||
output_grad: Output gradient tensor | ||
framework_output: Output tensor from the reference framework model | ||
compiled_output: Output tensor from the compiled model | ||
framework_model: Reference model | ||
compiled_model: compiled model to verify | ||
verify_cfg: Configuration object controlling which verification checks to perform | ||
""" | ||
+    if not verify_cfg.enabled:
+        logger.warning("Verification is disabled")
+        return
+
+    assert compiled_model.training(), "Compiled model must be compiled for training to run backward verification"
+
+    # Check if inputs are of the correct type
+    if not inputs:
+        raise ValueError("Input tensors must be provided")
+
+    req_grad = False
+    for input_tensor in inputs:
+        if not isinstance(input_tensor, torch.Tensor):
+            raise TypeError(f"Input tensor must be of type {torch.Tensor}, but got {type(input_tensor)}")
+        req_grad |= input_tensor.requires_grad
+
+    if not req_grad:
+        raise ValueError("At least one of the input tensors must require gradient")
+
+    if not isinstance(output_grad, torch.Tensor):
+        raise TypeError(f"Output gradient tensor must be of type {torch.Tensor}, but got {type(output_grad)}")
+
+    if not isinstance(framework_output, torch.Tensor):
+        raise TypeError(f"Framework output tensor must be of type {torch.Tensor}, but got {type(framework_output)}")
+    if not isinstance(compiled_output, torch.Tensor):
+        raise TypeError(f"Compiled output tensor must be of type {torch.Tensor}, but got {type(compiled_output)}")
+
+    if not isinstance(framework_model, torch.nn.Module):
+        raise TypeError(f"Framework model must be of type {torch.nn.Module}, but got {type(framework_model)}")
Review comment (on lines +332 to +333): So for now we just support backward verification for torch.nn.Module.
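On that point: the forward verify entry point (its signature is visible at the end of this file's diff) also accepts tf.Module and tf.keras.Model. Extending verify_backward similarly would mostly mean bridging TensorFlow gradients into torch tensors, much as to_pt_tensors is already used for fw_grads below; a hypothetical sketch of that bridge, not part of this PR:

    import tensorflow as tf
    import torch

    tf_grad = tf.constant([[1.0, 2.0]])          # hypothetical TF gradient
    pt_grad = torch.from_numpy(tf_grad.numpy())  # standard TF -> PyTorch conversion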
+    if not isinstance(compiled_model, verify_cfg.compiled_model_types):
+        raise TypeError(
+            f"Compiled model must be of type {verify_cfg.compiled_model_types}, but got {type(compiled_model)}"
+        )
+
+    if original_model is None:
+        original_model = framework_model
+    elif not isinstance(original_model, torch.nn.Module):
+        raise TypeError(f"Original model must be of type {torch.nn.Module}, but got {type(original_model)}")
+
+    # 1st step: run backward pass for the networks and get gradients
+    [input.grad.zero_() for input in inputs if input.grad is not None]
+    original_model.zero_grad()
+    compiled_model.gradient_inputs = [output_grad]
+    co_gradient_outputs = compiled_model.backward()
+    co_grads = [
+        grad
+        for grad, name in zip(co_gradient_outputs, compiled_model.bwd_compiled_graph_state.ordered_output_names)
+        if not name.startswith("grad_acc_")  # This is added to the gradient of the parameter node
+    ]
+    # HACK: This will fail if the first argument is not used first, the second argument is not used second, etc.
+    co_grads = list(reversed(co_grads))
+    co_grads += [param.grad.clone() for param in original_model.parameters() if param.requires_grad]
+
+    # Run backward on the framework model
+    [input.grad.zero_() for input in inputs if input.grad is not None]
+    framework_model.zero_grad()
+    framework_output.backward(gradient=output_grad)
+    fw_grads = [input.grad.clone() for input in inputs if input.requires_grad]
+    fw_grads += [param.grad for param in framework_model.parameters() if param.requires_grad]
||
# 2nd step: apply preprocessing (push tensors to cpu, perform any reshape if necessary, | ||
# cast from tensorflow tensors to pytorch tensors if needed) | ||
fw_grads = to_pt_tensors(fw_grads) | ||
co_grads = [co.to("cpu") for co in co_grads] | ||
|
||
assert all(isinstance(co, torch.Tensor) for co in co_grads), f"Compiled model output is not a list of torch.Tensor" | ||
|
||
# 3rd step: verifications of outputs | ||
# - size check | ||
# - dtype check | ||
# - shape check | ||
# - compare with golden | ||
if verify_cfg.verify_size: | ||
if len(fw_grads) != len(co_grads): | ||
raise ValueError( | ||
f"Number of gradients from framework model and compiled model do not match: framework model has {len(fw_grads)} outputs, compiled model has {len(co_grads)} outputs" | ||
) | ||
|
||
for fw, co in zip(fw_grads, co_grads): | ||
# Squeeze the tensors to remove any extra dimensions | ||
# NOTE: realized that there is narrow_forge_tensor_to_pytorch that should do the same but seams to be unused | ||
fw = _squeeze_tensor(fw) | ||
co = _squeeze_tensor(co) | ||
Review comment (on lines +387 to +388): Can we just use regular squeeze?
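For context on that question: torch.Tensor.squeeze() with no arguments removes all size-1 dimensions, so _squeeze_tensor is a thin wrapper around the built-in. A minimal illustration (shapes are illustrative only):

    import torch

    t = torch.rand(1, 11, 2048, 1)
    assert t.squeeze().shape == torch.Size([11, 2048])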
||
if verify_cfg.verify_dtype: | ||
if fw.dtype != co.dtype: | ||
raise TypeError(f"Dtype mismatch: framework_model.dtype={fw.dtype}, compiled_model.dtype={co.dtype}") | ||
|
||
if verify_cfg.verify_shape: | ||
if fw.shape != co.shape: | ||
if fw.shape != co.shape: | ||
raise TypeError( | ||
f"Shape mismatch: framework_model.shape={fw.shape}, compiled_model.shape={co.shape}" | ||
) | ||
|
||
if verify_cfg.verify_values: | ||
verify_cfg.value_checker.check(fw, co) | ||
|
||
|
||
def verify( | ||
inputs: List[Union[torch.Tensor, tf.Tensor, tf.Variable]], | ||
framework_model: Union[torch.nn.Module, tf.Module, tf.keras.Model], | ||
|
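To make the intended calling pattern concrete, here is a minimal usage sketch assembled from this diff and the test changes below; the Matmul module and shapes are illustrative, while the training flag of forge.compile and verify returning both models' forward outputs are taken from the test code:

    import torch
    from torch import nn

    import forge
    from forge.verify.config import VerifyConfig
    from forge.verify.verify import verify, verify_backward

    class Matmul(nn.Module):
        def forward(self, x, y):
            return torch.matmul(x, y)

    framework_model = Matmul()
    framework_model.train()
    inputs = [torch.rand(11, 2048, requires_grad=True), torch.rand(2048, 512, requires_grad=True)]
    compiled_model = forge.compile(framework_model, sample_inputs=inputs, training=True)

    # Forward verification returns both models' outputs for reuse.
    fw_out, co_out = verify(inputs, framework_model, compiled_model, VerifyConfig())

    # A random gradient stands in for an upstream loss gradient.
    output_grad = torch.rand_like(fw_out[0])
    verify_backward(inputs, output_grad, fw_out[0], co_out[0], framework_model, compiled_model)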
@@ -9,7 +9,7 @@
 import forge
 from forge.verify.config import VerifyConfig
 from forge.verify.value_checkers import AutomaticValueChecker
-from forge.verify.verify import verify
+from forge.verify.verify import verify, verify_backward


 @pytest.mark.parametrize(
@@ -128,21 +128,30 @@ def forward(self, x):
     verify(inputs, framework_model, compiled_model)


+# @pytest.mark.parametrize("train", [False, True])
 @pytest.mark.parametrize(
-    "shapes",
+    "shapes, train",
     [
-        ((11, 2048), (2048, 2048)),
-        ((1, 32, 1), (1, 1, 11)),
-        ((11, 2048), (2048, 512)),
-        ((32, 11, 64), (32, 64, 11)),
-        ((32, 11, 11), (32, 11, 64)),
-        ((11, 2048), (2048, 8192)),
-        ((1, 11, 8192), (8192, 2048)),
-        ((1, 11, 2048), (2048, 128256)),
+        (((11, 2048), (2048, 2048)), False),
+        (((1, 32, 1), (1, 1, 11)), False),
+        (((11, 2048), (2048, 512)), False),
+        (((32, 11, 64), (32, 64, 11)), False),
+        (((32, 11, 11), (32, 11, 64)), False),
+        (((11, 2048), (2048, 8192)), False),
+        (((1, 11, 8192), (8192, 2048)), False),
+        (((1, 11, 2048), (2048, 128256)), False),
+        (((11, 2048), (2048, 2048)), True),
+        (((1, 32, 1), (1, 1, 11)), True),
+        (((11, 2048), (2048, 512)), True),
+        (((32, 11, 64), (32, 64, 11)), True),
+        (((32, 11, 11), (32, 11, 64)), True),
+        (((11, 2048), (2048, 8192)), True),
+        (((1, 11, 8192), (8192, 2048)), True),
+        pytest.param(((1, 11, 2048), (2048, 128256)), True, marks=pytest.mark.xfail(reason="Low PCC")),
     ],
 )
 @pytest.mark.push
-def test_matmul(shapes):
+def test_matmul(shapes, train):
     shape1, shape2 = shapes

     class Matmul(nn.Module):
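For readers less familiar with the pattern above: pytest.param lets a single parametrize case carry marks such as xfail, as used for the low-PCC combination. A generic sketch:

    import pytest

    @pytest.mark.parametrize(
        "value",
        [
            1,
            pytest.param(2, marks=pytest.mark.xfail(reason="known issue")),
        ],
    )
    def test_value(value):
        assert isinstance(value, int)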
@@ -153,14 +162,36 @@ def forward(self, x, y):
         return torch.matmul(x, y)

     inputs = [
-        torch.rand(shape1),
-        torch.rand(shape2),
+        torch.rand(shape1, requires_grad=train),
+        torch.rand(shape2, requires_grad=train),
     ]

     framework_model = Matmul()
-    compiled_model = forge.compile(framework_model, sample_inputs=inputs)
-
-    verify(inputs, framework_model, compiled_model, VerifyConfig(value_checker=AutomaticValueChecker(pcc=0.95)))
+    framework_model.eval() if not train else framework_model.train()
+
+    # NOTE: We probably need two framework models with the same state_dict to compare the outputs
+    # But for now it works without that for some reason?
+    # model_for_compile = Matmul()
+    # model_for_compile.eval() if not training else model_for_compile.train()
+    # model_for_compile.load_state_dict(framework_model.state_dict())

Review comment (on lines +172 to +176): Discussed offline, we probably don't need 2 framework models.

+    compiled_model = forge.compile(framework_model, sample_inputs=inputs, training=train)
+
+    fw_out, co_out = verify(
+        inputs, framework_model, compiled_model, VerifyConfig(value_checker=AutomaticValueChecker(pcc=0.95))
+    )
+    if train:
+        # Simulate the backward pass of the loss
+        output_grad = torch.rand_like(fw_out[0])
+
+        verify_backward(
+            inputs,
+            output_grad,
+            fw_out[0],
+            co_out[0],
+            framework_model,
+            compiled_model,
+            verify_cfg=VerifyConfig(value_checker=AutomaticValueChecker(pcc=0.95)),
+        )


 @pytest.mark.parametrize(
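One detail worth noting in the test above: in PyTorch, out.backward(gradient=g) backpropagates exactly as if the scalar loss (out * g).sum() had been used, which is why a random output_grad can simulate the backward pass of a loss. A standalone illustration:

    import torch

    x = torch.rand(3, 4, requires_grad=True)
    out = x * 2.0
    g = torch.rand_like(out)  # plays the role of output_grad
    out.backward(gradient=g)  # equivalent to (out * g).sum().backward()
    assert torch.allclose(x.grad, 2.0 * g)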
Review comment: Please just remove "import AutomaticValueChecker" if we don't use it here?
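If the reviewer is right that AutomaticValueChecker goes unused in that module, the config import added at the top of this diff would shrink back to:

    from .config import (
        DepricatedVerifyConfig,
        VerifyConfig,
        VerifyTensorMetadata,
        should_waive_gradient,
    )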