Add verify_backward to enable testing backward ops #1383
base: main
Conversation
Codecov Report: All modified and coverable lines are covered by tests ✅

Coverage diff vs. main (#1383):
  Coverage:  ? → 43.40%
  Files:     ? → 48
  Lines:     ? → 7860
  Branches:  ? → 0
  Hits:      ? → 3412
  Misses:    ? → 4448
  Partials:  ? → 0
if not isinstance(framework_model, torch.nn.Module):
    raise TypeError(f"Framework model must be of type {torch.nn.Module}, but got {type(framework_model)}")
So for now we just support backward verification for torch?
compiled_output: torch.Tensor,
framework_model: torch.nn.Module,
compiled_model: CompiledModel,
original_model: torch.nn.Module = None,
Do we really need original_model?
fw = _squeeze_tensor(fw)
co = _squeeze_tensor(co)
Can we just use regular squeeze instead?
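For reference, torch.squeeze called without a dim argument removes every size-1 dimension, so whether it can replace the _squeeze_tensor helper depends on what the helper is meant to do (the helper's behavior is not shown in this hunk). A minimal check of the built-in:

```python
import torch

# torch.squeeze with no `dim` drops ALL size-1 dimensions. If
# _squeeze_tensor is meant to preserve, say, a size-1 batch dim
# that the comparison still needs, the two are not interchangeable.
t = torch.zeros(1, 3, 1, 2)
print(tuple(t.squeeze().shape))  # (3, 2)
```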
VerifyConfig,
VerifyTensorMetadata,
should_waive_gradient,
AutomaticValueChecker,
Please remove the AutomaticValueChecker import if we don't use it here.
# NOTE: We probably need two framework models with the same state_dict to compare the outputs
# But for now it works without that for some reason?
# model_for_compile = Matmul()
# model_for_compile.eval() if not training else model_for_compile.train()
# model_for_compile.load_state_dict(framework_model.state_dict())
Discussed offline, we probably don't need 2 framework models.
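If two framework models were ever needed, the commented-out idea above (cloning parameters via state_dict) could be sketched as follows. Matmul here is a hypothetical stand-in for the test model, not the PR's actual class:

```python
import torch

class Matmul(torch.nn.Module):
    """Hypothetical stand-in for the test model in the PR."""
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(4, 4))

    def forward(self, x):
        return x @ self.weight

torch.manual_seed(0)
framework_model = Matmul()
model_for_compile = Matmul()
# Copy parameters so both models start from identical state.
model_for_compile.load_state_dict(framework_model.state_dict())

x = torch.randn(2, 4)
assert torch.equal(framework_model(x), model_for_compile(x))
```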
Ticket
Fixes #1356
Problem description
There was no single place to verify the gradients produced by the backward pass.
What's changed
Added verify_backward:
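A minimal sketch of what such a backward check can look like, assuming the core idea is to push the same upstream gradient through both graphs and compare the resulting gradients (the signature and names below are illustrative, not the PR's actual API):

```python
import torch

def verify_backward_sketch(fw_out, co_out, fw_params, co_params, grad,
                           rtol=1e-2, atol=1e-2):
    """Illustrative sketch: backprop the same upstream gradient through
    the framework and compiled outputs, then compare parameter
    gradients pairwise."""
    fw_out.backward(gradient=grad)
    co_out.backward(gradient=grad)
    for fw_p, co_p in zip(fw_params, co_params):
        assert torch.allclose(fw_p.grad, co_p.grad, rtol=rtol, atol=atol)

# Usage: two identical linear layers stand in for the framework
# model and the compiled model.
torch.manual_seed(0)
fw = torch.nn.Linear(4, 2)
co = torch.nn.Linear(4, 2)
co.load_state_dict(fw.state_dict())
x = torch.randn(3, 4)
grad = torch.ones(3, 2)
verify_backward_sketch(fw(x), co(x), list(fw.parameters()),
                       list(co.parameters()), grad)
print("gradients match")
```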
Checklist