Skip to content

Commit

Permalink
inductor: Don't throw an internal error when an nn.module is missing an attribute (pytorch#145122)
Browse files Browse the repository at this point in the history

If a nn.module getattr call throws, we should make sure that we don't crash with an internal error

Note that I couldn't figure out how to test this, so advice would be awesome. My best attempt is pytorch#145799, but it doesn't seem to reproduce the crash.

Pull Request resolved: pytorch#145122
Approved by: https://github.com/jansel
  • Loading branch information
c00w authored and pytorchmergebot committed Feb 5, 2025
1 parent eb832b7 commit 93d98ac
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 3 deletions.
15 changes: 15 additions & 0 deletions test/dynamo/test_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,21 @@ def fn(x):
with self.assertRaises(AttributeError):
fn(x)

@torch._dynamo.config.patch(inline_inbuilt_nn_modules=False)
def test_compilation_nn_module_invalid_method(self):
class Mod(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
return x + self.doesnotexist

mod = Mod()
opt_mod = torch.compile(mod, backend="eager")
x = torch.randn(1, 1)
with self.assertRaises(AttributeError):
opt_mod(x)


# The private variants of the below functions are extensively tested
# So as long as the signatures match we're good
Expand Down
1 change: 1 addition & 0 deletions test/dynamo/test_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -2764,6 +2764,7 @@ def run():
run()
self.assertTrue(models[0].abc)

@torch._dynamo.config.patch(inline_inbuilt_nn_modules=False)
def test_assign_does_not_exist(self):
class MyModule(torch.nn.Module):
def forward(self, x):
Expand Down
3 changes: 2 additions & 1 deletion torch/_dynamo/variables/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from .. import config, polyfills, variables
from ..exc import (
AttributeMutationError,
ObservedAttributeError,
unimplemented,
Unsupported,
UserError,
Expand Down Expand Up @@ -1907,7 +1908,7 @@ def _lower_version_count_by_1(x):

try:
getattr_var = obj.var_getattr(tx, name_var.as_python_constant())
except AttributeError:
except (AttributeError, ObservedAttributeError):
getattr_var = None

if isinstance(getattr_var, variables.TensorVariable):
Expand Down
7 changes: 5 additions & 2 deletions torch/_dynamo/variables/nn_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,8 +288,11 @@ def var_getattr(self, tx: "InstructionTranslator", name):
)
if result is not None:
return result
# if we can't find a __getattr__, just raise the AttributeError
raise
# if we can't find a __getattr__, we can't parse this, raise attribute error
raise_observed_exception(
AttributeError,
tx,
)

if name == "forward":
guard_to_detect_forward_monkeypatching(self.source, base)
Expand Down

0 comments on commit 93d98ac

Please sign in to comment.