Skip to content

Commit

Permalink
(fix)(Torch Frontend)(module.py): made a couple of fixes to the `name…
Browse files Browse the repository at this point in the history
…d_modules and `apply` methods to handle native keras classes as part of submodules/children.
  • Loading branch information
YushaArif99 committed May 24, 2024
1 parent f5f662d commit e443886
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions ivy/functional/frontends/torch/nn/modules/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,10 @@ def add_module(self, name: str, module: Optional["Module"]) -> None:

def apply(self, fn: Callable[["Module"], None]):
for module in self.children():
module.apply(fn)
if hasattr(module, "apply"):
module.apply(fn)
else:
fn(module)
fn(self)
return self

Expand Down Expand Up @@ -262,9 +265,12 @@ def named_modules(
if module is None:
continue
submodule_prefix = prefix + ("." if prefix else "") + name
yield from module.named_modules(
memo, submodule_prefix, remove_duplicate
)
if not hasattr(module, "named_modules"):
yield submodule_prefix, self
else:
yield from module.named_modules(
memo, submodule_prefix, remove_duplicate
)

def requires_grad_(self, requires_grad: bool = True):
for p in self.parameters():
Expand Down

0 comments on commit e443886

Please sign in to comment.