diff --git a/arclet/entari/plugin/module.py b/arclet/entari/plugin/module.py index 848399c..68833a3 100644 --- a/arclet/entari/plugin/module.py +++ b/arclet/entari/plugin/module.py @@ -21,20 +21,7 @@ def package(*names: str): _SUBMODULE_WAITLIST.setdefault(plugin.module.__name__, set()).update(names) -def _check_mod(name, package=None): - module = import_plugin(name, package) - if not module: - raise ModuleNotFoundError(f"module {name!r} not found") - if hasattr(module, "__plugin__"): - if not package: - if name != module.__plugin__.id: - service._referents[name].add(module.__plugin__.id) - return module.__plugin__.proxy() - return module.__plugin__.subproxy(f"{package}{name}") - return module - - -def _check_import(name: str, plugin_name: str): +def __entari_import__(name: str, plugin_name: str, ensure_plugin: bool = False): if name in service.plugins: plug = service.plugins[name] if plugin_name != plug.id: @@ -46,6 +33,17 @@ def _check_import(name: str, plugin_name: str): if plugin_name != mod.__plugin__.id: service._referents[mod.__plugin__.id].add(plugin_name) return mod.__plugin__.subproxy(name) + if ensure_plugin: + module = import_plugin(name, plugin_name) + if not module: + raise ModuleNotFoundError(f"module {name!r} not found") + if hasattr(module, "__plugin__"): + if not plugin_name: + if name != module.__plugin__.id: + service._referents[name].add(module.__plugin__.id) + return module.__plugin__.proxy() + return module.__plugin__.subproxy(f"{plugin_name}{name}") + return module return __import__(name, fromlist=["__path__"]) @@ -67,7 +65,7 @@ def source_to_code(self, data, path, *, _optimize=-1): # type: ignore if body.level == 0: if len(body.names) == 1 and body.names[0].name == "*": new = ast.parse( - f"__mod = __check_import({body.module!r}, {self.name!r});" + f"__mod = __entari_import__({body.module!r}, {self.name!r});" f"__mod_all = getattr(__mod, '__all__', dir(__mod));" "globals().update(" "{name: getattr(__mod, name) for name in __mod_all if not name.startswith('__')}" @@ -76,7 +74,7 @@ def source_to_code(self, data, path, *, _optimize=-1): # type: ignore ) else: new = ast.parse( - f"__mod = __check_import({body.module!r}, {self.name!r});" + f"__mod = __entari_import__({body.module!r}, {self.name!r});" f"{';'.join(f'{alias.asname or alias.name} = __mod.{alias.name}' for alias in body.names)};" f"del __mod" ) @@ -91,7 +89,8 @@ def source_to_code(self, data, path, *, _optimize=-1): # type: ignore else: new = ast.parse( ";".join( - f"{alias.asname or alias.name}=__check_mod('{relative}{alias.name}', {self.name!r})" + f"{alias.asname or alias.name}=" + f"__entari_import__('{relative}{alias.name}', {self.name!r}, True)" for alias in body.names ) ) @@ -103,7 +102,7 @@ def source_to_code(self, data, path, *, _optimize=-1): # type: ignore relative = "." * body.level if len(body.names) == 1 and body.names[0].name == "*": new = ast.parse( - f"__mod = __check_mod('{relative}{body.module}', {self.name!r});" + f"__mod = __entari_import__('{relative}{body.module}', {self.name!r}, True);" f"__mod_all = getattr(__mod, '__all__', dir(__mod));" "globals().update(" "{name: getattr(__mod, name) for name in __mod_all if not name.startswith('__')}" @@ -112,7 +111,7 @@ def source_to_code(self, data, path, *, _optimize=-1): # type: ignore ) else: new = ast.parse( - f"__mod = __check_mod('{relative}{body.module}', {self.name!r});" + f"__mod = __entari_import__('{relative}{body.module}', {self.name!r}, True);" f"{';'.join(f'{alias.asname or alias.name} = __mod.{alias.name}' for alias in body.names)};" f"del __mod" ) @@ -125,7 +124,7 @@ def source_to_code(self, data, path, *, _optimize=-1): # type: ignore new = ast.parse( ",".join(aliases) + "=" - + ",".join(f"__check_import({alias.name!r}, {self.name!r})" for alias in body.names) + + ",".join(f"__entari_import__({alias.name!r}, {self.name!r})" for alias in body.names) ) for node in ast.walk(new): node.lineno = body.lineno # type: ignore @@ -134,6 +133,7 @@ def source_to_code(self, data, path, *, _optimize=-1): # type: ignore else: bodys.append(body) nodes.body = bodys + print(ast.unparse(nodes)) return _bootstrap._call_with_frames_removed(compile, nodes, path, "exec", dont_inherit=True, optimize=_optimize) def create_module(self, spec) -> Optional[ModuleType]: @@ -147,8 +147,7 @@ def exec_module(self, module: ModuleType) -> None: if module.__name__ == plugin.module.__name__: # from . import xxxx return setattr(module, "__plugin__", plugin) - setattr(module, "__check_mod", _check_mod) - setattr(module, "__check_import", _check_import) + setattr(module, "__entari_import__", __entari_import__) try: super().exec_module(module) except Exception: @@ -165,8 +164,7 @@ def exec_module(self, module: ModuleType) -> None: # create plugin before executing plugin = Plugin(module.__name__, module) setattr(module, "__plugin__", plugin) - setattr(module, "__check_mod", _check_mod) - setattr(module, "__check_import", _check_import) + setattr(module, "__entari_import__", __entari_import__) # enter plugin context _plugin_token = _current_plugin.set(plugin)