diff --git a/arclet/entari/__init__.py b/arclet/entari/__init__.py index b63ad15..45f420a 100644 --- a/arclet/entari/__init__.py +++ b/arclet/entari/__init__.py @@ -1,5 +1,4 @@ from arclet.letoderea import bind as bind -from arclet.letoderea import es as es from satori import ArgvInteraction as ArgvInteraction from satori import At as At from satori import Audio as Audio diff --git a/arclet/entari/command/__init__.py b/arclet/entari/command/__init__.py index 657c2e3..4f16265 100644 --- a/arclet/entari/command/__init__.py +++ b/arclet/entari/command/__init__.py @@ -12,11 +12,12 @@ from tarina.string import split from tarina.trie import CharTrie -from ..event import MessageCreatedEvent +from ..event.command import CommandExecute +from ..event.protocol import MessageCreatedEvent from ..message import MessageChain from .argv import MessageArgv # noqa: F401 from .model import CommandResult, Match, Query -from .plugin import CommandExecute, mount +from .plugin import mount from .provider import AlconnaProviderFactory, AlconnaSuppiler, MessageJudger, get_cmd T = TypeVar("T") @@ -28,7 +29,7 @@ class EntariCommands: def __init__(self, need_tome: bool = False, remove_tome: bool = True): self.trie: CharTrie[Subscriber] = CharTrie() - self.publisher = Publisher("EntariCommands", MessageCreatedEvent) + self.publisher = Publisher("entari.command", MessageCreatedEvent) self.publisher.providers.append(AlconnaProviderFactory()) self.need_tome = need_tome self.remove_tome = remove_tome @@ -199,7 +200,7 @@ def config_commands(need_tome: bool = False, remove_tome: bool = True): async def execute(message: Union[str, MessageChain]): - return await es.post(CommandExecute(message), "entari.command/command_execute") + return await es.post(CommandExecute(message), "entari.event/command_execute") __all__ = ["_commands", "config_commands", "Match", "Query", "execute", "CommandResult", "mount", "command", "on"] diff --git a/arclet/entari/command/plugin.py b/arclet/entari/command/plugin.py index 654d737..85245a2 100644 --- a/arclet/entari/command/plugin.py +++ b/arclet/entari/command/plugin.py @@ -1,63 +1,19 @@ from __future__ import annotations -from dataclasses import dataclass from typing import Any -from arclet.alconna import Alconna, Arparma, command_manager, output_manager -from arclet.letoderea import AuxType, BaseAuxiliary, Contexts, Interface, Provider, ProviderFactory, Scope, es +from arclet.alconna import Alconna, command_manager +from arclet.letoderea import BaseAuxiliary, Provider, ProviderFactory from ..event import MessageCreatedEvent -from ..message import MessageChain +from ..event.command import pub as execute_handles from ..plugin.model import Plugin, PluginDispatcher -from .model import CommandResult, Match, Query -from .provider import AlconnaProviderFactory, AlconnaSuppiler, Assign, MessageJudger, _seminal +from .model import Match, Query +from .provider import AlconnaProviderFactory, AlconnaSuppiler, Assign, ExecuteSuppiler, MessageJudger, _seminal - -@dataclass -class CommandExecute: - command: str | MessageChain - - async def gather(self, context: Contexts): - if isinstance(self.command, str): - context["command"] = MessageChain(self.command) - else: - context["command"] = self.command - - class CommandProvider(Provider[MessageChain]): - async def __call__(self, context: Contexts): - return context.get("command") - - -execute_handles = es.define("entari.command/command_execute", CommandExecute) execute_handles.bind(AlconnaProviderFactory()) -class ExecuteJudger(BaseAuxiliary): - def __init__(self, cmd: Alconna): - self.cmd = cmd - super().__init__(AuxType.supply, priority=1) - - async def __call__(self, scope: Scope, interface: Interface): - message = interface.query(MessageChain, "command") - with output_manager.capture(self.cmd.name) as cap: - output_manager.set_action(lambda x: x, self.cmd.name) - try: - _res = self.cmd.parse(message) - except Exception as e: - _res = Arparma(self.cmd._hash, message, False, error_info=e) - may_help_text: str | None = cap.get("output", None) - result = CommandResult(self.cmd, _res, may_help_text) - return interface.update(alc_result=result) - - @property - def scopes(self) -> set[Scope]: - return {Scope.prepare} - - @property - def id(self) -> str: - return "entari.command/command_execute_judger" - - class AlconnaPluginDispatcher(PluginDispatcher): def __init__( @@ -70,8 +26,8 @@ def __init__( self.supplier = AlconnaSuppiler(command, need_tome, remove_tome) super().__init__(plugin, MessageCreatedEvent) - self.bind(MessageJudger(), self.supplier) - self.bind(AlconnaProviderFactory()) + self.publisher.bind(MessageJudger(), self.supplier) + self.publisher.bind(AlconnaProviderFactory()) def assign( self, @@ -99,8 +55,14 @@ def on_execute( providers: list[Provider | type[Provider] | ProviderFactory | type[ProviderFactory]] | None = None, ): _auxiliaries = auxiliaries or [] - _auxiliaries.append(ExecuteJudger(self.supplier.cmd)) - return execute_handles.register(priority=priority, auxiliaries=_auxiliaries, providers=providers) + _auxiliaries.append(ExecuteSuppiler(self.supplier.cmd)) + + def wrapper(func): + sub = execute_handles.register(func, priority=priority, auxiliaries=_auxiliaries, providers=providers) + self._subscribers.append(sub) + return sub + + return wrapper Match = Match Query = Query @@ -110,7 +72,7 @@ def mount(cmd: Alconna, need_tome: bool = False, remove_tome: bool = True) -> Al if not (plugin := Plugin.current()): raise LookupError("no plugin context found") disp = AlconnaPluginDispatcher(plugin, cmd, need_tome, remove_tome) - if disp.id in plugin.dispatchers: + if disp.publisher.id in plugin.dispatchers: return plugin.dispatchers[disp.id] # type: ignore - plugin.dispatchers[disp.id] = disp + plugin.dispatchers[disp.publisher.id] = disp return disp diff --git a/arclet/entari/command/provider.py b/arclet/entari/command/provider.py index 9b827d1..989266b 100644 --- a/arclet/entari/command/provider.py +++ b/arclet/entari/command/provider.py @@ -88,7 +88,33 @@ def scopes(self) -> set[Scope]: @property def id(self) -> str: - return "entari.command/alconna_supplier" + return "entari.command/common_supplier" + + +class ExecuteSuppiler(SupplyAuxiliary): + def __init__(self, cmd: Alconna): + self.cmd = cmd + super().__init__(priority=1) + + async def __call__(self, scope: Scope, interface: Interface): + message = interface.query(MessageChain, "command") + with output_manager.capture(self.cmd.name) as cap: + output_manager.set_action(lambda x: x, self.cmd.name) + try: + _res = self.cmd.parse(message) + except Exception as e: + _res = Arparma(self.cmd._hash, message, False, error_info=e) + may_help_text: Optional[str] = cap.get("output", None) + result = CommandResult(self.cmd, _res, may_help_text) + return interface.update(alc_result=result) + + @property + def scopes(self) -> set[Scope]: + return {Scope.prepare} + + @property + def id(self) -> str: + return "entari.command/execute_supplier" class AlconnaProvider(Provider[Any]): diff --git a/arclet/entari/core.py b/arclet/entari/core.py index f2a6395..8a3c6be 100644 --- a/arclet/entari/core.py +++ b/arclet/entari/core.py @@ -17,7 +17,7 @@ from .command import _commands from .config import Config as EntariConfig -from .event import MessageCreatedEvent, event_parse +from .event.protocol import MessageCreatedEvent, event_parse from .plugin import load_plugin from .plugin.service import plugin_service from .session import Session @@ -34,6 +34,8 @@ def validate(self, param: Param): return get_origin(param.annotation) == Session async def __call__(self, context: Contexts): + if "session" in context and isinstance(context["session"], Session): + return context["session"] if "$origin_event" in context and "$account" in context: return Session(context["$account"], context["$event"]) diff --git a/arclet/entari/event/__init__.py b/arclet/entari/event/__init__.py new file mode 100644 index 0000000..7ff6021 --- /dev/null +++ b/arclet/entari/event/__init__.py @@ -0,0 +1,2 @@ +from .protocol import MessageCreatedEvent as MessageCreatedEvent +from .protocol import MessageEvent as MessageEvent diff --git a/arclet/entari/event/base.py b/arclet/entari/event/base.py new file mode 100644 index 0000000..89cf658 --- /dev/null +++ b/arclet/entari/event/base.py @@ -0,0 +1,14 @@ +from __future__ import annotations + +from typing import Callable, TypeVar + +from ..plugin import dispatch + +TE = TypeVar("TE", bound="BasedEvent") + + +class BasedEvent: + @classmethod + def dispatch(cls: type[TE], predicate: Callable[[TE], bool] | None = None, name: str | None = None): + name = name or getattr(cls, "__disp_name__", None) + return dispatch(cls, predicate=predicate, name=name) # type: ignore diff --git a/arclet/entari/event/command.py b/arclet/entari/event/command.py new file mode 100644 index 0000000..ce0af0b --- /dev/null +++ b/arclet/entari/event/command.py @@ -0,0 +1,27 @@ +from dataclasses import dataclass +from typing import Union + +from arclet.letoderea import Contexts, Provider, es + +from ..message import MessageChain +from .base import BasedEvent + + +@dataclass +class CommandExecute(BasedEvent): + command: Union[str, MessageChain] + + async def gather(self, context: Contexts): + if isinstance(self.command, str): + context["command"] = MessageChain(self.command) + else: + context["command"] = self.command + + class CommandProvider(Provider[MessageChain]): + async def __call__(self, context: Contexts): + return context.get("command") + + __disp_name__ = "entari.event/command_execute" + + +pub = es.define("entari.event/command_execute", CommandExecute, lambda x: {"command": x.command, "message": x.command}) diff --git a/arclet/entari/event.py b/arclet/entari/event/protocol.py similarity index 94% rename from arclet/entari/event.py rename to arclet/entari/event/protocol.py index 42fcff5..e847245 100644 --- a/arclet/entari/event.py +++ b/arclet/entari/event/protocol.py @@ -1,7 +1,7 @@ from __future__ import annotations from datetime import datetime -from typing import Any, Callable, ClassVar, Generic, TypeVar +from typing import Any, ClassVar, Generic, TypeVar from arclet.letoderea import Contexts, Param, Provider from satori import ArgvInteraction, ButtonInteraction, Channel @@ -11,10 +11,9 @@ from satori.model import LoginType, MessageObject from tarina import gen_subclass -from .message import MessageChain -from .plugin import dispatch +from ..message import MessageChain +from .base import BasedEvent -TE = TypeVar("TE", bound="Event") T = TypeVar("T") D = TypeVar("D") @@ -25,6 +24,8 @@ def __init__(self, key: str | None = None): def __set_name__(self, owner: type[Event], name: str): self.key = self.key or name + if name not in ("id", "timestamp"): + owner._attrs.add(name) def __get__(self, instance: Event, owner: type[Event]) -> T: return getattr(instance._origin, self.key, None) # type: ignore @@ -37,8 +38,9 @@ def attr(key: str | None = None) -> Any: return Attr(key) -class Event: +class Event(BasedEvent): type: ClassVar[EventType] + _attrs: ClassVar[set[str]] = set() _origin: SatoriEvent account: Account @@ -55,27 +57,10 @@ class Event: role: Role | None = attr() user: User | None = attr() - _attrs: ClassVar[set[str]] = { - "argv", - "button", - "channel", - "guild", - "login", - "member", - "message", - "operator", - "role", - "user", - } - def __init__(self, account: Account, origin: SatoriEvent): self.account = account self._origin = origin - @classmethod - def dispatch(cls: type[TE], predicate: Callable[[TE], bool] | None = None): - return dispatch(cls, predicate=predicate) # type: ignore - async def gather(self, context: Contexts): context["$account"] = self.account context["$origin_event"] = self._origin diff --git a/arclet/entari/event/session.py b/arclet/entari/event/session.py new file mode 100644 index 0000000..433495a --- /dev/null +++ b/arclet/entari/event/session.py @@ -0,0 +1,23 @@ +from dataclasses import dataclass +from typing import TYPE_CHECKING + +from arclet.letoderea import deref, es, provide +from arclet.letoderea.ref import generate + +from ..message import MessageChain +from .base import BasedEvent + +if TYPE_CHECKING: + from ..session import Session + + +@dataclass +class SendRequest(BasedEvent): + session: "Session" + message: "MessageChain" + + __disp_name__ = "entari.event/before_send" + + +pub = es.define("entari.event/before_send", SendRequest, lambda x: {"session": x.session, "message": x.message}) +pub.bind(provide(MessageChain, call=generate(deref(SendRequest).message))) diff --git a/arclet/entari/plugin/__init__.py b/arclet/entari/plugin/__init__.py index 209e755..0bb1932 100644 --- a/arclet/entari/plugin/__init__.py +++ b/arclet/entari/plugin/__init__.py @@ -18,13 +18,13 @@ from .service import plugin_service if TYPE_CHECKING: - from ..event import Event + from ..event.base import BasedEvent -def dispatch(*events: type[Event], predicate: Callable[[Event], bool] | None = None): +def dispatch(*events: type[BasedEvent], predicate: Callable[[BasedEvent], bool] | None = None, name: str | None = None): if not (plugin := _current_plugin.get(None)): raise LookupError("no plugin context found") - return plugin.dispatch(*events, predicate=predicate) + return plugin.dispatch(*events, predicate=predicate, name=name) def load_plugin(path: str, config: dict | None = None, recursive_guard: set[str] | None = None) -> Plugin | None: diff --git a/arclet/entari/plugin/model.py b/arclet/entari/plugin/model.py index f55be12..bb80492 100644 --- a/arclet/entari/plugin/model.py +++ b/arclet/entari/plugin/model.py @@ -1,6 +1,6 @@ from __future__ import annotations -from collections.abc import Awaitable +from collections.abc import Awaitable, Sequence from contextvars import ContextVar from dataclasses import dataclass, field from pathlib import Path @@ -9,7 +9,7 @@ from typing import TYPE_CHECKING, Any, Callable, TypeVar from weakref import finalize, proxy -from arclet.letoderea import BaseAuxiliary, Provider, Publisher, StepOut, es +from arclet.letoderea import BaseAuxiliary, Provider, ProviderFactory, Publisher, StepOut, Subscriber, es from arclet.letoderea.typing import TTarget from creart import it from launart import Launart, Service @@ -19,10 +19,11 @@ from .service import PluginLifecycleService, plugin_service if TYPE_CHECKING: - from ..event import Event + from ..event.base import BasedEvent _current_plugin: ContextVar[Plugin] = ContextVar("_current_plugin") +T = TypeVar("T") R = TypeVar("R") @@ -30,22 +31,30 @@ class RegisterNotInPluginError(Exception): pass -class PluginDispatcher(Publisher): +class PluginDispatcher: def __init__( self, plugin: Plugin, - *events: type[Event], - predicate: Callable[[Event], bool] | None = None, + *events: type[BasedEvent], + predicate: Callable[[BasedEvent], bool] | None = None, + name: str | None = None, ): - super().__init__(f"{plugin.id}@{id(self)}", *events, predicate=predicate) # type: ignore + id_ = f"#{plugin.id}@{name or id(self)}" + if name and name in es.publishers: + self.publisher = es.publishers[name] + elif id_ in es.publishers: + self.publisher = es.publishers[id_] + else: + self.publisher = Publisher(id_, *events, predicate=predicate) + es.register(self.publisher) self.plugin = plugin - es.register(self) self._events = events + self._subscribers = [] def waiter( self, - *events: type[Event], - providers: list[Provider | type[Provider]] | None = None, + *events: type[BasedEvent], + providers: Sequence[Provider | type[Provider]] | None = None, auxiliaries: list[BaseAuxiliary] | None = None, priority: int = 15, block: bool = False, @@ -59,22 +68,27 @@ def wrapper(func: TTarget[R]): return wrapper def dispose(self): - es.publishers.pop(self.id, None) - self.subscribers.clear() + for sub in self._subscribers: + sub.dispose() + self._subscribers.clear() if TYPE_CHECKING: register = Publisher.register else: - def register(self, func: Callable | None = None, **kwargs): - wrapper = super().register(**kwargs) + def register(self, func: Callable | None = None, **kwargs) -> Any: + wrapper = self.publisher.register(**kwargs) if func: self.plugin.validate(func) # type: ignore - return wrapper(func) + sub = wrapper(func) + self._subscribers.append(sub) + return sub def decorator(func1): self.plugin.validate(func1) - return wrapper(func1) + sub1 = wrapper(func1) + self._subscribers.append(sub1) + return sub1 return decorator @@ -205,13 +219,32 @@ def dispose(self): del plugin_service.plugins[self.id] del self.module - def dispatch(self, *events: type[Event], predicate: Callable[[Event], bool] | None = None): - disp = PluginDispatcher(self, *events, predicate=predicate) - if disp.id in self.dispatchers: - return self.dispatchers[disp.id] - self.dispatchers[disp.id] = disp + def dispatch( + self, *events: type[BasedEvent], predicate: Callable[[BasedEvent], bool] | None = None, name: str | None = None + ): + disp = PluginDispatcher(self, *events, predicate=predicate, name=name) + if disp.publisher.id in self.dispatchers: + return self.dispatchers[disp.publisher.id] + self.dispatchers[disp.publisher.id] = disp return disp + def use( + self, + pub_id: str, + *, + priority: int = 16, + auxiliaries: list[BaseAuxiliary] | None = None, + providers: ( + Sequence[Provider[Any] | type[Provider[Any]] | ProviderFactory | type[ProviderFactory]] | None + ) = None, + ) -> Callable[[Callable[..., Any]], Subscriber]: + if pub_id not in es.publishers: + raise LookupError(f"no publisher found: {pub_id}") + if not (disp := self.dispatchers.get(pub_id)): + disp = PluginDispatcher(self, name=pub_id) + self.dispatchers[disp.publisher.id] = disp + return disp.register(priority=priority, auxiliaries=auxiliaries, providers=providers) + def validate(self, func): if func.__module__ != self.module.__name__: if "__plugin__" in func.__globals__ and func.__globals__["__plugin__"] is self: @@ -251,9 +284,6 @@ def dispose(self): del self.obj -T = TypeVar("T") - - def keeping(id_: str, obj: T, dispose: Callable[[T], None] | None = None) -> T: if not (plug := _current_plugin.get(None)): raise LookupError("no plugin context found") diff --git a/arclet/entari/session.py b/arclet/entari/session.py index 19bc2ee..8ada867 100644 --- a/arclet/entari/session.py +++ b/arclet/entari/session.py @@ -3,13 +3,13 @@ from collections.abc import Iterable from typing import Generic, NoReturn, TypeVar -from arclet.letoderea import ParsingStop, StepOut +from arclet.letoderea import ParsingStop, StepOut, es from satori.client.account import Account -from satori.client.protocol import ApiProtocol from satori.element import Element from satori.model import Channel, Guild, Member, MessageReceipt, PageResult, Role, User -from .event import Event, FriendRequestEvent, GuildMemberRequestEvent, GuildRequestEvent, MessageEvent +from .event.protocol import Event, FriendRequestEvent, GuildMemberRequestEvent, GuildRequestEvent, MessageEvent +from .event.session import SendRequest from .message import MessageChain TEvent = TypeVar("TEvent", bound=Event) @@ -106,17 +106,29 @@ def elements(self) -> MessageChain: return self._content raise RuntimeError(f"Event {self.context.type!r} has no Content") + @elements.setter + def elements(self, value: MessageChain): + self._content = value + def __getattr__(self, item): return getattr(self.account.protocol, item) + async def _send(self, channel_id: str, message: str | Iterable[str | Element]): + msg = MessageChain(message) + sess = self.__class__(self.account, self.context) + sess.elements = msg + res = await es.post(SendRequest(sess, msg), "entari.event/before_send") + if res and res.value is True: + return [] + return await self.account.send_message(channel_id, sess.elements) + async def send( self, message: str | Iterable[str | Element], - protocol_cls: type[ApiProtocol] | None = None, ) -> list[MessageReceipt]: - if not protocol_cls: - return await self.account.protocol.send(self.context, message) - return await self.account.custom(self.account.config, protocol_cls).send(self.context._origin, message) + if not self.context._origin.channel: + raise RuntimeError("Event cannot be replied to!") + return await self._send(self.context._origin.channel.id, message) async def send_message( self, @@ -129,7 +141,7 @@ async def send_message( """ if not self.context.channel: raise RuntimeError("Event cannot be replied to!") - return await self.account.protocol.send_message(self.context.channel, message) + return await self._send(self.context.channel.id, message) async def send_private_message( self, @@ -140,9 +152,8 @@ async def send_private_message( Args: message: 要发送的消息 """ - if not self.context.user: - raise RuntimeError("Event cannot be replied to!") - return await self.account.protocol.send_private_message(self.context.user, message) + channel = await self.user_channel_create() + return await self._send(channel.id, message) async def update_message( self, @@ -165,7 +176,7 @@ async def message_create( ) -> list[MessageReceipt]: if not self.context.channel: raise RuntimeError("Event cannot be replied to!") - return await self.account.protocol.message_create(self.context.channel.id, content) + return await self._send(self.context.channel.id, content) async def message_delete(self) -> None: if not self.context.channel: diff --git a/example_plugin.py b/example_plugin.py index 4195cf5..deaccf5 100644 --- a/example_plugin.py +++ b/example_plugin.py @@ -9,7 +9,7 @@ is_public_message, bind, metadata, - keeping + keeping, ) metadata(__file__) @@ -78,3 +78,8 @@ async def show(session: Session): print([*Plugin.current().dispatchers.keys()]) print(Plugin.current().subplugins) print("example_plugin not in sys.modules (expect True):", "example_plugin" not in sys.modules) + + +@plug.use("entari.event/before_send") +async def send_hook(session: Session, message: MessageChain): + session.elements = message + "喵"