diff --git a/ivy/functional/backends/jax/__init__.py b/ivy/functional/backends/jax/__init__.py index 8468027de6618..bb7b9df6b381c 100644 --- a/ivy/functional/backends/jax/__init__.py +++ b/ivy/functional/backends/jax/__init__.py @@ -238,7 +238,7 @@ def closest_valid_dtype(type=None, /, as_native=False): from . import control_flow_ops from .control_flow_ops import * from . import module -from .module import Model +from .module import Module # sub-backends @@ -249,7 +249,7 @@ def closest_valid_dtype(type=None, /, as_native=False): if importlib.util.find_spec("flax"): import flax - NativeModule = Model + NativeModule = Module elif importlib.util.find_spec("haiku"): import haiku as hk