diff --git a/ivy/functional/frontends/torch/nn/modules/module.py b/ivy/functional/frontends/torch/nn/modules/module.py index fe86e908088c0..b18a2f212b021 100644 --- a/ivy/functional/frontends/torch/nn/modules/module.py +++ b/ivy/functional/frontends/torch/nn/modules/module.py @@ -129,7 +129,10 @@ def add_module(self, name: str, module: Optional["Module"]) -> None: def apply(self, fn: Callable[["Module"], None]): for module in self.children(): - module.apply(fn) + if hasattr(module, "apply"): + module.apply(fn) + else: + fn(module) fn(self) return self @@ -262,9 +265,12 @@ def named_modules( if module is None: continue submodule_prefix = prefix + ("." if prefix else "") + name - yield from module.named_modules( - memo, submodule_prefix, remove_duplicate - ) + if not hasattr(module, "named_modules"): + yield submodule_prefix, self + else: + yield from module.named_modules( + memo, submodule_prefix, remove_duplicate + ) def requires_grad_(self, requires_grad: bool = True): for p in self.parameters():