Skip to content

Commit

Permalink
✨ use AST to resolve related import
Browse files Browse the repository at this point in the history
  • Loading branch information
RF-Tar-Railt committed Sep 30, 2024
1 parent 772d96a commit 44f159c
Show file tree
Hide file tree
Showing 8 changed files with 113 additions and 88 deletions.
32 changes: 6 additions & 26 deletions arclet/entari/command/__init__.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,7 @@
import asyncio
from typing import Any, Callable, Optional, TypeVar, Union, cast, overload

from arclet.alconna import (
Alconna,
Arg,
Args,
Arparma,
CommandMeta,
Namespace,
command_manager,
config,
output_manager,
)
from arclet.alconna import Alconna, Arg, Args, Arparma, CommandMeta, Namespace, command_manager, config, output_manager
from arclet.alconna.tools.construct import AlconnaString, alconna_from_format
from arclet.alconna.typing import TAValue
from arclet.letoderea import BaseAuxiliary, Provider, Publisher, Scope, Subscriber
Expand Down Expand Up @@ -54,9 +44,7 @@ async def listener(event: MessageCreatedEvent):
if not msg:
return
if matches := list(self.trie.prefixes(msg)):
await asyncio.gather(
*(depend_handler(res.value, event, inner=True) for res in matches if res.value)
)
await asyncio.gather(*(depend_handler(res.value, event, inner=True) for res in matches if res.value))
return
# shortcut
data = split(msg, (" ",))
Expand Down Expand Up @@ -116,9 +104,7 @@ def command(
need_tome: bool = False,
remove_tome: bool = True,
auxiliaries: Optional[list[BaseAuxiliary]] = None,
providers: Optional[
list[Union[Provider, type[Provider], ProviderFactory, type[ProviderFactory]]]
] = None,
providers: Optional[list[Union[Provider, type[Provider], ProviderFactory, type[ProviderFactory]]]] = None,
):
class Command(AlconnaString):
def __call__(_cmd_self, func: TCallable) -> TCallable:
Expand All @@ -133,9 +119,7 @@ def on(
need_tome: bool = False,
remove_tome: bool = True,
auxiliaries: Optional[list[BaseAuxiliary]] = None,
providers: Optional[
list[Union[Provider, type[Provider], ProviderFactory, type[ProviderFactory]]]
] = None,
providers: Optional[list[Union[Provider, type[Provider], ProviderFactory, type[ProviderFactory]]]] = None,
) -> Callable[[TCallable], TCallable]: ...

@overload
Expand All @@ -145,9 +129,7 @@ def on(
need_tome: bool = False,
remove_tome: bool = True,
auxiliaries: Optional[list[BaseAuxiliary]] = None,
providers: Optional[
list[Union[Provider, type[Provider], ProviderFactory, type[ProviderFactory]]]
] = None,
providers: Optional[list[Union[Provider, type[Provider], ProviderFactory, type[ProviderFactory]]]] = None,
*,
args: Optional[dict[str, Union[TAValue, Args, Arg]]] = None,
meta: Optional[CommandMeta] = None,
Expand All @@ -159,9 +141,7 @@ def on(
need_tome: bool = False,
remove_tome: bool = True,
auxiliaries: Optional[list[BaseAuxiliary]] = None,
providers: Optional[
list[Union[Provider, type[Provider], ProviderFactory, type[ProviderFactory]]]
] = None,
providers: Optional[list[Union[Provider, type[Provider], ProviderFactory, type[ProviderFactory]]]] = None,
*,
args: Optional[dict[str, Union[TAValue, Args, Arg]]] = None,
meta: Optional[CommandMeta] = None,
Expand Down
11 changes: 1 addition & 10 deletions arclet/entari/command/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,7 @@

from arclet.alconna import Alconna, Arparma, Duplication, Empty, output_manager
from arclet.alconna.builtin import generate_duplication
from arclet.letoderea import (
Contexts,
Interface,
JudgeAuxiliary,
Param,
Provider,
Scope,
Subscriber,
SupplyAuxiliary,
)
from arclet.letoderea import Contexts, Interface, JudgeAuxiliary, Param, Provider, Scope, Subscriber, SupplyAuxiliary
from arclet.letoderea.provider import ProviderFactory
from nepattern.util import CUnionType
from satori.client import Account
Expand Down
14 changes: 2 additions & 12 deletions arclet/entari/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,7 @@
import asyncio
from contextlib import suppress

from arclet.letoderea import (
BaseAuxiliary,
Contexts,
EventSystem,
Param,
Provider,
ProviderFactory,
global_providers,
)
from arclet.letoderea import BaseAuxiliary, Contexts, EventSystem, Param, Provider, ProviderFactory, global_providers
from launart import Launart
from loguru import logger
from satori import LoginStatus
Expand Down Expand Up @@ -80,9 +72,7 @@ def on_message(
auxiliaries: list[BaseAuxiliary] | None = None,
providers: list[Provider | type[Provider] | ProviderFactory | type[ProviderFactory]] | None = None,
):
return self.event_system.on(
MessageEvent, priority=priority, auxiliaries=auxiliaries, providers=providers
)
return self.event_system.on(MessageEvent, priority=priority, auxiliaries=auxiliaries, providers=providers)

def ensure_manager(self, manager: Launart):
self.manager = manager
Expand Down
2 changes: 1 addition & 1 deletion arclet/entari/plugin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def load_plugin(path: str) -> Plugin | None:
except RegisterNotInPluginError as e:
logger.exception(f"{e.args[0]}", exc_info=e)
except Exception as e:
logger.error(f"failed to load plugin {path!r} caused by {e!r}")
logger.exception(f"failed to load plugin {path!r} caused by {e!r}", exc_info=e)


def load_plugins(dir_: str | PathLike | Path):
Expand Down
2 changes: 2 additions & 0 deletions arclet/entari/plugin/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,9 @@ def dispose(self):
if self.module.__spec__ and self.module.__spec__.cached:
Path(self.module.__spec__.cached).unlink(missing_ok=True)
sys.modules.pop(self.module.__name__, None)
delattr(self.module, "__plugin__")
for submod in self.submodules.values():
delattr(submod, "__plugin__")
sys.modules.pop(submod.__name__, None)
if submod.__spec__ and submod.__spec__.cached:
Path(submod.__spec__.cached).unlink(missing_ok=True)
Expand Down
100 changes: 90 additions & 10 deletions arclet/entari/plugin/module.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import ast
from collections.abc import Sequence
from importlib import _bootstrap # type: ignore
from importlib.abc import MetaPathFinder
from importlib.machinery import PathFinder, SourceFileLoader
from importlib.util import module_from_spec, resolve_name
Expand All @@ -9,12 +11,21 @@
from .model import Plugin, PluginMetadata, _current_plugin
from .service import service

_SUBMODULE_WAITLIST = set()
_SUBMODULE_WAITLIST: dict[str, set[str]] = {}


def package(*names: str):
"""手动指定特定模块作为插件的子模块"""
_SUBMODULE_WAITLIST.update(names)
if not (plugin := _current_plugin.get(None)):
raise LookupError("no plugin context found")
_SUBMODULE_WAITLIST.setdefault(plugin.module.__name__, set()).update(names)


def _unpack_import_from(__fullname: str, mod: str, aliases: list[str]):
if mod == ".":
return tuple(import_plugin(f".{alias}", __fullname) for alias in aliases)
_mod = import_plugin(f".{mod}", __fullname) if mod else import_plugin(__fullname)
return tuple(getattr(_mod, alias) for alias in aliases)


class PluginLoader(SourceFileLoader):
Expand All @@ -23,19 +34,83 @@ def __init__(self, fullname: str, path: str, parent_plugin_id: Optional[str] = N
self.parent_plugin_id = parent_plugin_id
super().__init__(fullname, path)

def source_to_code(self, data, path, *, _optimize=-1): # type: ignore
"""Return the code object compiled from source.
The 'data' argument can be any object type that compile() supports.
"""
nodes = ast.parse(data, type_comments=True)
for i, body in enumerate(nodes.body):
if isinstance(body, ast.ImportFrom):
if body.level == 0 and (
body.module in _SUBMODULE_WAITLIST.get(self.name, ()) or body.module in service.plugins
):
aliases = [alias.asname or alias.name for alias in body.names]
nodes.body[i] = ast.parse(
",".join(aliases)
+ f"=__unpack_import_from('{self.name}', '', {[alias.name for alias in body.names]!r})"
).body[0]
nodes.body[i].lineno = body.lineno
nodes.body[i].end_lineno = body.end_lineno
if body.level == 1:
if body.module is None:
aliases = [alias.asname or alias.name for alias in body.names]
nodes.body[i] = ast.parse(
",".join(aliases)
+ f"=__unpack_import_from('{self.name}', '.', {[alias.name for alias in body.names]!r})"
).body[0]
nodes.body[i].lineno = body.lineno
nodes.body[i].end_lineno = body.end_lineno
else:
aliases = [alias.asname or alias.name for alias in body.names]
nodes.body[i] = ast.parse(
",".join(aliases)
+ (
f"=__unpack_import_from('{self.name}', {body.module!r}, "
f"{[alias.name for alias in body.names]!r})"
)
).body[0]
nodes.body[i].lineno = body.lineno
nodes.body[i].end_lineno = body.end_lineno
elif (
isinstance(body, ast.Expr)
and isinstance(body.value, ast.Call)
and isinstance(body.value.func, ast.Name)
and body.value.func.id == "package"
):
if body.value.args and isinstance(body.value.args[0], ast.Constant):
_SUBMODULE_WAITLIST.setdefault(self.name, set()).update(arg.value for arg in body.value.args) # type: ignore
elif isinstance(body, ast.Import):
aliases = [alias.asname or alias.name for alias in body.names]
nodes.body[i] = ast.parse(
",".join(aliases)
+ "="
+ ",".join(
(
f"__import_plugin({alias.name!r})"
if (alias.name in _SUBMODULE_WAITLIST.get(self.name, ()) or alias.name in service.plugins)
else f"__import__({alias.name!r})"
)
for alias in body.names
)
).body[0]
nodes.body[i].lineno = body.lineno
nodes.body[i].end_lineno = body.end_lineno
return _bootstrap._call_with_frames_removed(compile, nodes, path, "exec", dont_inherit=True, optimize=_optimize)

def create_module(self, spec) -> Optional[ModuleType]:
if self.name in service.plugins:
self.loaded = True
return service.plugins[self.name].module
return super().create_module(spec)

def exec_module(self, module: ModuleType) -> None:
if plugin := _current_plugin.get(
service.plugins.get(self.parent_plugin_id) if self.parent_plugin_id else None
):
if plugin := _current_plugin.get(service.plugins.get(self.parent_plugin_id) if self.parent_plugin_id else None):
if module.__name__ == plugin.module.__name__: # from . import xxxx
return
setattr(module, "__plugin__", plugin)
setattr(module, "__unpack_import_from", _unpack_import_from)
setattr(module, "__import_plugin", import_plugin)
try:
super().exec_module(module)
except Exception:
Expand All @@ -51,6 +126,8 @@ def exec_module(self, module: ModuleType) -> None:
# create plugin before executing
plugin = Plugin(module.__name__, module)
setattr(module, "__plugin__", plugin)
setattr(module, "__unpack_import_from", _unpack_import_from)
setattr(module, "__import_plugin", import_plugin)

# enter plugin context
_plugin_token = _current_plugin.set(plugin)
Expand All @@ -75,7 +152,10 @@ def find_spec(name, package=None):
fullname = resolve_name(name, package) if name.startswith(".") else name
parent_name = fullname.rpartition(".")[0]
if parent_name:
parent = __import__(parent_name, fromlist=["__path__"])
if parent_name in service.plugins:
parent = service.plugins[parent_name].module
else:
parent = __import__(parent_name, fromlist=["__path__"])
try:
parent_path = parent.__path__
except AttributeError as e:
Expand Down Expand Up @@ -122,11 +202,11 @@ def find_spec(
if plug.module.__spec__ and plug.module.__spec__.origin == module_spec.origin:
return plug.module.__spec__
if module_spec.parent and module_spec.parent == plug.module.__name__:
module_spec.loader = PluginLoader(fullname, module_origin)
module_spec.loader = PluginLoader(fullname, module_origin, plug.id)
return module_spec
elif module_spec.name in _SUBMODULE_WAITLIST:
module_spec.loader = PluginLoader(fullname, module_origin)
_SUBMODULE_WAITLIST.remove(module_spec.name)
elif module_spec.name in _SUBMODULE_WAITLIST[plug.module.__name__]:
module_spec.loader = PluginLoader(fullname, module_origin, plug.id)
# _SUBMODULE_WAITLIST[plug.module.__name__].remove(module_spec.name)
return module_spec

if module_spec.name in service.plugins:
Expand Down
34 changes: 8 additions & 26 deletions arclet/entari/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,7 @@ async def waiter(content: MessageChain, session: Session[MessageEvent]):
if self.context.channel:
if self.context.channel.id == session.context.channel.id and (
not keep_sender
or (
self.context.user
and session.context.user
and self.context.user.id == session.context.user.id
)
or (self.context.user and session.context.user and self.context.user.id == session.context.user.id)
):
return content
elif self.context.user:
Expand Down Expand Up @@ -120,9 +116,7 @@ async def send(
) -> list[MessageObject]:
if not protocol_cls:
return await self.account.protocol.send(self.context, message)
return await self.account.custom(self.account.config, protocol_cls).send(
self.context._origin, message
)
return await self.account.custom(self.account.config, protocol_cls).send(self.context._origin, message)

async def send_message(
self,
Expand Down Expand Up @@ -163,9 +157,7 @@ async def update_message(
raise RuntimeError("Event cannot be replied to!")
if not self.context.message:
raise RuntimeError("Event cannot update message")
return await self.account.protocol.update_message(
self.context.channel, self.context.message.id, message
)
return await self.account.protocol.update_message(self.context.channel, self.context.message.id, message)

async def message_create(
self,
Expand Down Expand Up @@ -252,9 +244,7 @@ async def guild_member_kick(self, user_id: str | None = None, permanent: bool =
return await self.account.protocol.guild_member_kick(self.context.guild.id, user_id, permanent)
if not self.context.user:
raise RuntimeError("Event cannot use to kick member!")
return await self.account.protocol.guild_member_kick(
self.context.guild.id, self.context.user.id, permanent
)
return await self.account.protocol.guild_member_kick(self.context.guild.id, self.context.user.id, permanent)

async def guild_member_role_set(self, role_id: str, user_id: str | None = None) -> None:
if not self.context.guild:
Expand All @@ -263,22 +253,16 @@ async def guild_member_role_set(self, role_id: str, user_id: str | None = None)
return await self.account.protocol.guild_member_role_set(self.context.guild.id, user_id, role_id)
if not self.context.user:
raise RuntimeError("Event cannot use to guild member role set!")
return await self.account.protocol.guild_member_role_set(
self.context.guild.id, self.context.user.id, role_id
)
return await self.account.protocol.guild_member_role_set(self.context.guild.id, self.context.user.id, role_id)

async def guild_member_role_unset(self, role_id: str, user_id: str | None = None) -> None:
if not self.context.guild:
raise RuntimeError("Event cannot use to guild member role unset!")
if user_id:
return await self.account.protocol.guild_member_role_unset(
self.context.guild.id, user_id, role_id
)
return await self.account.protocol.guild_member_role_unset(self.context.guild.id, user_id, role_id)
if not self.context.user:
raise RuntimeError("Event cannot use to guild member role unset!")
return await self.account.protocol.guild_member_role_unset(
self.context.guild.id, self.context.user.id, role_id
)
return await self.account.protocol.guild_member_role_unset(self.context.guild.id, self.context.user.id, role_id)

async def guild_role_list(self, next_token: str | None = None) -> PageResult[Role]:
if not self.context.guild:
Expand Down Expand Up @@ -318,9 +302,7 @@ async def reaction_create(
raise RuntimeError("Event cannot be replied to!")
if not self.context.message:
raise RuntimeError("Event cannot create reaction")
return await self.account.protocol.reaction_create(
self.context.channel.id, self.context.message.id, emoji
)
return await self.account.protocol.reaction_create(self.context.channel.id, self.context.message.id, emoji)

async def reaction_delete(
self,
Expand Down
Loading

0 comments on commit 44f159c

Please sign in to comment.