Skip to content

Commit

Permalink
🐛 fix ImportFrom level 0
Browse files Browse the repository at this point in the history
  • Loading branch information
RF-Tar-Railt committed Sep 30, 2024
1 parent afa634f commit 72efb23
Showing 1 changed file with 54 additions and 44 deletions.
98 changes: 54 additions & 44 deletions arclet/entari/plugin/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,18 @@ def _check_mod(name, package=None):
return module


def _unpack_import_from(__fullname: str, mod: str, aliases: list[str]):
if mod == ".":
def _unpack_import_from_level_x(__fullname: str, mod: str, level: int, aliases: list[str]):
if not mod:
if len(aliases) == 1:
return _check_mod(f".{aliases[0]}", __fullname)
return tuple(_check_mod(f".{alias}", __fullname) for alias in aliases)
_mod = _check_mod(f".{mod}", __fullname) if mod else _check_mod(__fullname)
return _check_mod(f"{'.' * level}{aliases[0]}", __fullname)
return tuple(_check_mod(f"{'.' * level}{alias}", __fullname) for alias in aliases)
_mod = _check_mod(f"{'.' * level}{mod}", __fullname) # if mod else _check_mod(__fullname)
if len(aliases) == 1:
return getattr(_mod, aliases[0])
return tuple(getattr(_mod, alias) for alias in aliases)
args = []
for alias in aliases:
args.append(getattr(_mod, alias))
return tuple(args)


def _check_import(name: str, plugin_name: str):
Expand All @@ -57,7 +60,17 @@ def _check_import(name: str, plugin_name: str):
if plugin_name != mod.__plugin__.id:
service._referents[mod.__plugin__.id].add(plugin_name)
return mod.__plugin__.subproxy(name)
return __import__(name)
return __import__(name, fromlist=["__path__"])


def _unpack_import_from_level_0(name, plugin_name, aliases):
mod = _check_import(name, plugin_name)
if len(aliases) == 1:
return getattr(mod, aliases[0])
args = []
for alias in aliases:
args.append(getattr(mod, alias))
return tuple(args)


class PluginLoader(SourceFileLoader):
Expand All @@ -74,53 +87,48 @@ def source_to_code(self, data, path, *, _optimize=-1): # type: ignore
nodes = ast.parse(data, type_comments=True)
for i, body in enumerate(nodes.body):
if isinstance(body, ast.ImportFrom):
if body.level == 0 and (
body.module in _SUBMODULE_WAITLIST.get(self.name, ()) or body.module in service.plugins
):
if body.level == 0:
aliases = [alias.asname or alias.name for alias in body.names]
nodes.body[i] = ast.parse(
",".join(aliases)
+ (
f"=__unpack_import_from_level_0({body.module!r}, {self.name!r}, "
f"{[alias.name for alias in body.names]!r})"
)
).body[0]
for node in ast.walk(nodes.body[i]):
node.lineno = body.lineno # type: ignore
node.end_lineno = body.end_lineno # type: ignore
elif body.module is None:
aliases = [alias.asname or alias.name for alias in body.names]
nodes.body[i] = ast.parse(
",".join(aliases)
+ (
f"=__unpack_import_from_level_x('{self.name}', '', {body.level}, "
f"{[alias.name for alias in body.names]!r})"
)
).body[0]
for node in ast.walk(nodes.body[i]):
node.lineno = body.lineno # type: ignore
node.end_lineno = body.end_lineno # type: ignore
else:
aliases = [alias.asname or alias.name for alias in body.names]
nodes.body[i] = ast.parse(
",".join(aliases)
+ f"=__unpack_import_from('{body.module}', '', {[alias.name for alias in body.names]!r})"
+ (
f"=__unpack_import_from_level_x('{self.name}', {body.module!r}, {body.level}, "
f"{[alias.name for alias in body.names]!r})"
)
).body[0]
for node in ast.walk(nodes.body[i]):
node.lineno = body.lineno # type: ignore
node.end_lineno = body.end_lineno # type: ignore
if body.level == 1:
if body.module is None:
aliases = [alias.asname or alias.name for alias in body.names]
nodes.body[i] = ast.parse(
",".join(aliases)
+ f"=__unpack_import_from('{self.name}', '.', {[alias.name for alias in body.names]!r})"
).body[0]
for node in ast.walk(nodes.body[i]):
node.lineno = body.lineno # type: ignore
node.end_lineno = body.end_lineno # type: ignore
else:
aliases = [alias.asname or alias.name for alias in body.names]
nodes.body[i] = ast.parse(
",".join(aliases)
+ (
f"=__unpack_import_from('{self.name}', {body.module!r}, "
f"{[alias.name for alias in body.names]!r})"
)
).body[0]
for node in ast.walk(nodes.body[i]):
node.lineno = body.lineno # type: ignore
node.end_lineno = body.end_lineno # type: ignore
elif (
isinstance(body, ast.Expr)
and isinstance(body.value, ast.Call)
and isinstance(body.value.func, ast.Name)
and body.value.func.id == "package"
):
if body.value.args and isinstance(body.value.args[0], ast.Constant):
_SUBMODULE_WAITLIST.setdefault(self.name, set()).update(arg.value for arg in body.value.args) # type: ignore
elif isinstance(body, ast.Import):
aliases = [alias.asname or alias.name for alias in body.names]
nodes.body[i] = ast.parse(
",".join(aliases)
+ "="
+ ",".join((f"__check_import({alias.name!r}, {self.name!r})") for alias in body.names)
+ ",".join(f"__check_import({alias.name!r}, {self.name!r})" for alias in body.names)
).body[0]
for node in ast.walk(nodes.body[i]):
node.lineno = body.lineno # type: ignore
Expand All @@ -138,7 +146,8 @@ def exec_module(self, module: ModuleType) -> None:
if module.__name__ == plugin.module.__name__: # from . import xxxx
return
setattr(module, "__plugin__", plugin)
setattr(module, "__unpack_import_from", _unpack_import_from)
setattr(module, "__unpack_import_from_level_x", _unpack_import_from_level_x)
setattr(module, "__unpack_import_from_level_0", _unpack_import_from_level_0)
setattr(module, "__check_import", _check_import)
try:
super().exec_module(module)
Expand All @@ -156,7 +165,8 @@ def exec_module(self, module: ModuleType) -> None:
# create plugin before executing
plugin = Plugin(module.__name__, module)
setattr(module, "__plugin__", plugin)
setattr(module, "__unpack_import_from", _unpack_import_from)
setattr(module, "__unpack_import_from_level_x", _unpack_import_from_level_x)
setattr(module, "__unpack_import_from_level_0", _unpack_import_from_level_0)
setattr(module, "__check_import", _check_import)

# enter plugin context
Expand Down

0 comments on commit 72efb23

Please sign in to comment.