diff --git a/README.md b/README.md index 0e719c7..867a007 100644 --- a/README.md +++ b/README.md @@ -85,7 +85,7 @@ app.run() 编写插件: ```python -from arclet.entari import Session, MessageCreatedEvent, metadata +from arclet.entari import Session, MessageCreatedEvent, metadata, plugin metadata( name="Hello, World!", @@ -95,7 +95,7 @@ metadata( ) # or __plugin_metadata__ = PluginMetadata(...) -@MessageCreatedEvent.dispatch() +@plugin.dispatch(MessageCreatedEvent) async def _(session: Session): await session.send("Hello, World!") ``` diff --git a/arclet/entari/__init__.py b/arclet/entari/__init__.py index 1d580fc..a04a9eb 100644 --- a/arclet/entari/__init__.py +++ b/arclet/entari/__init__.py @@ -49,6 +49,7 @@ from .message import MessageChain as MessageChain from .plugin import Plugin as Plugin from .plugin import PluginMetadata as PluginMetadata +from .plugin import declare_static as declare_static from .plugin import dispose as dispose_plugin # noqa: F401 from .plugin import keeping as keeping from .plugin import load_plugin as load_plugin diff --git a/arclet/entari/builtins/auto_reload.py b/arclet/entari/builtins/auto_reload.py index 64eb8bc..05f70b7 100644 --- a/arclet/entari/builtins/auto_reload.py +++ b/arclet/entari/builtins/auto_reload.py @@ -7,12 +7,14 @@ from launart.status import Phase from watchfiles import PythonFilter, awatch -from arclet.entari import Plugin, dispose_plugin, load_plugin, metadata +from arclet.entari import Plugin, declare_static, dispose_plugin, load_plugin, metadata from arclet.entari.config import EntariConfig from arclet.entari.event.config import ConfigReload from arclet.entari.logger import log from arclet.entari.plugin import find_plugin, find_plugin_by_file +declare_static() + class Config: watch_dirs: list[str] = ["."] @@ -66,7 +68,7 @@ async def watch(self): async for event in awatch(*self.dirs, watch_filter=PythonFilter()): for change in event: if plugin := find_plugin_by_file(change[1]): - if plugin.id == "arclet.entari.builtins.auto_reload" or plugin.is_static: + if plugin.is_static: logger("INFO", f"Plugin {plugin.id!r} is static, ignored.") continue logger("INFO", f"Detected change in {plugin.id!r}, reloading...") diff --git a/arclet/entari/command/__init__.py b/arclet/entari/command/__init__.py index 61e4852..c687e69 100644 --- a/arclet/entari/command/__init__.py +++ b/arclet/entari/command/__init__.py @@ -14,9 +14,9 @@ from tarina.string import split from tarina.trie import CharTrie +from ..event.base import MessageCreatedEvent from ..event.command import CommandExecute from ..event.config import ConfigReload -from ..event.protocol import MessageCreatedEvent from ..message import MessageChain from ..plugin import RootlessPlugin from ..session import Session diff --git a/arclet/entari/core.py b/arclet/entari/core.py index 57cef3e..91062b8 100644 --- a/arclet/entari/core.py +++ b/arclet/entari/core.py @@ -15,9 +15,9 @@ from tarina.generic import get_origin from .config import EntariConfig +from .event.base import MessageCreatedEvent, event_parse from .event.config import ConfigReload from .event.lifespan import AccountUpdate -from .event.protocol import MessageCreatedEvent, event_parse from .event.send import SendResponse from .logger import log from .plugin import load_plugin, plugin_config, requires @@ -130,7 +130,7 @@ def __init__( log.core.opt(colors=True).debug(f"Log level set to {log_level}") requires(*EntariConfig.instance.prelude_plugin) for plug in EntariConfig.instance.prelude_plugin: - load_plugin(plug, static=True) + load_plugin(plug, prelude=True) requires(*EntariConfig.instance.plugin) for plug in EntariConfig.instance.plugin: load_plugin(plug) diff --git a/arclet/entari/event/__init__.py b/arclet/entari/event/__init__.py index 7ff6021..a462a0d 100644 --- a/arclet/entari/event/__init__.py +++ b/arclet/entari/event/__init__.py @@ -1,2 +1,2 @@ -from .protocol import MessageCreatedEvent as MessageCreatedEvent -from .protocol import MessageEvent as MessageEvent +from .base import MessageCreatedEvent as MessageCreatedEvent +from .base import MessageEvent as MessageEvent diff --git a/arclet/entari/event/base.py b/arclet/entari/event/base.py index ab06489..4130e0f 100644 --- a/arclet/entari/event/base.py +++ b/arclet/entari/event/base.py @@ -1,14 +1,403 @@ from __future__ import annotations -from typing import Callable, TypeVar +from dataclasses import dataclass +from datetime import datetime +from typing import Any, ClassVar, Generic, TypeVar -TE = TypeVar("TE", bound="BasedEvent") +from arclet.letoderea import Contexts, Param, Provider +from satori import ArgvInteraction, ButtonInteraction, Channel +from satori import Event as OriginEvent +from satori import EventType, Guild, Member, Role, User +from satori.client import Account +from satori.element import At, Author, Quote, Text, select +from satori.model import Login, MessageObject +from tarina import gen_subclass +from ..message import MessageChain -class BasedEvent: - @classmethod - def dispatch(cls: type[TE], predicate: Callable[[TE], bool] | None = None, name: str | None = None): - from ..plugin import dispatch +T = TypeVar("T") +D = TypeVar("D") - name = name or getattr(cls, "__publisher__", None) - return dispatch(cls, predicate=predicate, name=name) # type: ignore + +@dataclass +class Reply: + quote: Quote + origin: MessageObject + + +def _is_reply_me(reply: Reply, account: Account): + if reply.origin.user: + return reply.origin.user.id == account.self_id + if authors := select(reply.quote, Author): + return any(author.id == account.self_id for author in authors) + return False + + +def _is_notice_me(message: MessageChain, account: Account): + if message and isinstance(message[0], At): + at: At = message[0] # type: ignore + if at.id and at.id == account.self_id: + return True + return False + + +def _remove_notice_me(message: MessageChain, account: Account): + message = message.copy() + message.pop(0) + if _is_notice_me(message, account): + message.pop(0) + if message and isinstance(message[0], Text): + text = message[0].text.lstrip() # type: ignore + if not text: + message.pop(0) + else: + message[0] = Text(text) + return message + + +class Attr(Generic[T]): + def __init__(self, key: str | None = None): + self.key = key + + def __set_name__(self, owner: type[SatoriEvent], name: str): + self.key = self.key or name + if name not in ("id", "timestamp"): + owner._attrs.add(name) + + def __get__(self, instance: SatoriEvent, owner: type[SatoriEvent]) -> T: + return getattr(instance._origin, self.key, None) # type: ignore + + def __set__(self, instance: SatoriEvent, value): + raise AttributeError("can't set attribute") + + +def attr(key: str | None = None) -> Any: + return Attr(key) + + +class SatoriEvent: + type: ClassVar[EventType] + _attrs: ClassVar[set[str]] = set() + _origin: OriginEvent + account: Account + + sn: int = attr() + timestamp: datetime = attr() + login: Login = attr() + argv: ArgvInteraction | None = attr() + button: ButtonInteraction | None = attr() + channel: Channel | None = attr() + guild: Guild | None = attr() + member: Member | None = attr() + message: MessageObject | None = attr() + operator: User | None = attr() + role: Role | None = attr() + user: User | None = attr() + + def __init__(self, account: Account, origin: OriginEvent): + self.account = account + self._origin = origin + + async def gather(self, context: Contexts): + context["account"] = self.account + context["$origin_event"] = self._origin + + for name in self.__class__._attrs: + value = getattr(self, name) + if value is not None: + context["$message_origin" if name == "message" else name] = value + + class TimeProvider(Provider[datetime]): + async def __call__(self, context: Contexts): + if "$event" in context: + return context["$event"].timestamp + + class OperatorProvider(Provider[User]): + priority = 10 + + def validate(self, param: Param): + return param.name == "operator" and super().validate(param) + + async def __call__(self, context: Contexts): + if "operator" in context: + return context["operator"] + if "$origin_event" not in context: + return + return context["$origin_event"].operator + + class UserProvider(Provider[User]): + async def __call__(self, context: Contexts): + if "user" in context: + return context["user"] + if "$origin_event" not in context: + return + return context["$origin_event"].user + + class MessageProvider(Provider[MessageObject]): + async def __call__(self, context: Contexts): + if "$message_origin" in context: + return context["$message_origin"] + if "$origin_event" not in context: + return + return context["$origin_event"].message + + class ChannelProvider(Provider[Channel]): + async def __call__(self, context: Contexts): + if "channel" in context: + return context["channel"] + if "$origin_event" not in context: + return + return context["$origin_event"].channel + + class GuildProvider(Provider[Guild]): + async def __call__(self, context: Contexts): + if "guild" in context: + return context["guild"] + if "$origin_event" not in context: + return + return context["$origin_event"].guild + + class MemberProvider(Provider[Member]): + async def __call__(self, context: Contexts): + if "member" in context: + return context["member"] + if "$origin_event" not in context: + return + return context["$origin_event"].member + + class RoleProvider(Provider[Role]): + async def __call__(self, context: Contexts): + if "role" in context: + return context["role"] + if "$origin_event" not in context: + return + return context["$origin_event"].role + + class LoginProvider(Provider[Login]): + async def __call__(self, context: Contexts): + if "login" in context: + return context["login"] + if "$origin_event" not in context: + return + return context["$origin_event"].login + + def __repr__(self): + return f"<{self.__class__.__name__[:-5]}{self._origin!r}>" + + +class NoticeEvent(SatoriEvent): + pass + + +class FriendEvent(NoticeEvent): + user: User = attr() + + +class FriendRequestEvent(FriendEvent): + type = EventType.FRIEND_REQUEST + + message: MessageObject = attr() + + +class GuildEvent(NoticeEvent): + guild: Guild = attr() + + async def gather(self, context: Contexts): + await super().gather(context) + context["guild"] = self.guild + + +class GuildAddedEvent(GuildEvent): + type = EventType.GUILD_ADDED + + +class GuildRemovedEvent(GuildEvent): + type = EventType.GUILD_REMOVED + + +class GuildRequestEvent(GuildEvent): + type = EventType.GUILD_REQUEST + + message: MessageObject = attr() + + +class GuildUpdatedEvent(GuildEvent): + type = EventType.GUILD_UPDATED + + +class GuildMemberEvent(GuildEvent): + user: User = attr() + + +class GuildMemberAddedEvent(GuildMemberEvent): + type = EventType.GUILD_MEMBER_ADDED + + +class GuildMemberRemovedEvent(GuildMemberEvent): + type = EventType.GUILD_MEMBER_REMOVED + + +class GuildMemberRequestEvent(GuildMemberEvent): + type = EventType.GUILD_MEMBER_REQUEST + + message: MessageObject = attr() + + +class GuildMemberUpdatedEvent(GuildMemberEvent): + type = EventType.GUILD_MEMBER_UPDATED + + +class GuildRoleEvent(GuildEvent): + role: Role = attr() + + +class GuildRoleCreatedEvent(GuildRoleEvent): + type = EventType.GUILD_ROLE_CREATED + + +class GuildRoleDeletedEvent(GuildRoleEvent): + type = EventType.GUILD_ROLE_DELETED + + +class GuildRoleUpdatedEvent(GuildRoleEvent): + type = EventType.GUILD_ROLE_UPDATED + + +class LoginEvent(NoticeEvent): + pass + + +class LoginAddedEvent(LoginEvent): + type = EventType.LOGIN_ADDED + + +class LoginRemovedEvent(LoginEvent): + type = EventType.LOGIN_REMOVED + + +class LoginUpdatedEvent(LoginEvent): + type = EventType.LOGIN_UPDATED + + +class MessageContentProvider(Provider[MessageChain]): + priority = 30 + + async def __call__(self, context: Contexts): + return context.get("$message_content") + + +class ReplyProvider(Provider[Reply]): + async def __call__(self, context: Contexts): + return context.get("$message_reply") + + +class MessageEvent(SatoriEvent): + channel: Channel = attr() + user: User = attr() + message: MessageObject = attr() + + content: MessageChain + quote: Quote | None = None + + providers = [MessageContentProvider, ReplyProvider] + + def __init__(self, account: Account, origin: OriginEvent): + super().__init__(account, origin) + self.content = MessageChain(self.message.message) + if self.content.has(Quote): + self.quote = self.content.get(Quote, 1)[0] + self.content = self.content.exclude(Quote) + + async def gather(self, context: Contexts): + await super().gather(context) + reply = None + if self.quote and self.quote.id: + mo = await self.account.protocol.message_get(self.channel.id, self.quote.id) + reply = context["$message_reply"] = Reply(self.quote, mo) + if not reply: + is_reply_me = False + else: + is_reply_me = _is_reply_me(reply, self.account) + context["is_reply_me"] = is_reply_me + if is_reply_me and self.content and isinstance(self.content[0], Text): + text = self.content[0].text.lstrip() + if not text: + self.content.pop(0) + else: + self.content[0] = Text(text) + is_notice_me = context["is_notice_me"] = _is_notice_me(self.content, self.account) + if is_notice_me: + self.content = _remove_notice_me(self.content, self.account) + context["$message_content"] = self.content + + +class MessageCreatedEvent(MessageEvent): + type = EventType.MESSAGE_CREATED + + +class MessageDeletedEvent(MessageEvent): + type = EventType.MESSAGE_DELETED + + +class MessageUpdatedEvent(MessageEvent): + type = EventType.MESSAGE_UPDATED + + +class ReactionEvent(NoticeEvent, MessageEvent): + pass + + +class ReactionAddedEvent(ReactionEvent): + type = EventType.REACTION_ADDED + + +class ReactionRemovedEvent(ReactionEvent): + type = EventType.REACTION_REMOVED + + +class InternalEvent(SatoriEvent): + type = EventType.INTERNAL + + +class InteractionEvent(NoticeEvent): + pass + + +class InteractionButtonEvent(InteractionEvent): + type = EventType.INTERACTION_BUTTON + + button: ButtonInteraction = attr() + + class ButtonProvider(Provider[ButtonInteraction]): + async def __call__(self, context: Contexts): + return context.get("button") + + +class InteractionCommandEvent(InteractionEvent): + type = EventType.INTERACTION_COMMAND + + +class InteractionCommandArgvEvent(InteractionCommandEvent): + argv: ArgvInteraction = attr() + + class ArgvProvider(Provider[ArgvInteraction]): + async def __call__(self, context: Contexts): + return context.get("argv") + + +class InteractionCommandMessageEvent(InteractionCommandEvent, MessageEvent): + pass + + +MAPPING: dict[str, type[SatoriEvent]] = {} + +for cls in gen_subclass(SatoriEvent): + if hasattr(cls, "type"): + MAPPING[cls.type.value] = cls + + +def event_parse(account: Account, event: OriginEvent): + try: + return MAPPING[event.type](account, event) + except KeyError: + raise NotImplementedError from None diff --git a/arclet/entari/event/command.py b/arclet/entari/event/command.py index 3621481..e50aa97 100644 --- a/arclet/entari/event/command.py +++ b/arclet/entari/event/command.py @@ -4,11 +4,10 @@ from arclet.letoderea import Contexts, Provider, es from ..message import MessageChain -from .base import BasedEvent @dataclass -class CommandExecute(BasedEvent): +class CommandExecute: command: Union[str, MessageChain] async def gather(self, context: Contexts): diff --git a/arclet/entari/event/config.py b/arclet/entari/event/config.py index df75c9c..1e2a60b 100644 --- a/arclet/entari/event/config.py +++ b/arclet/entari/event/config.py @@ -3,11 +3,9 @@ from arclet.letoderea import es -from .base import BasedEvent - @dataclass -class ConfigReload(BasedEvent): +class ConfigReload: scope: str key: str value: Any diff --git a/arclet/entari/event/lifespan.py b/arclet/entari/event/lifespan.py index 9af356e..4a9860c 100644 --- a/arclet/entari/event/lifespan.py +++ b/arclet/entari/event/lifespan.py @@ -4,32 +4,30 @@ from satori.client import Account from satori.model import LoginStatus -from .base import BasedEvent - @dataclass -class Startup(BasedEvent): +class Startup: pass __publisher__ = "entari.event/startup" @dataclass -class Ready(BasedEvent): +class Ready: pass __publisher__ = "entari.event/ready" @dataclass -class Cleanup(BasedEvent): +class Cleanup: pass __publisher__ = "entari.event/cleanup" @dataclass -class AccountUpdate(BasedEvent): +class AccountUpdate: account: Account status: LoginStatus diff --git a/arclet/entari/event/protocol.py b/arclet/entari/event/protocol.py deleted file mode 100644 index e423947..0000000 --- a/arclet/entari/event/protocol.py +++ /dev/null @@ -1,404 +0,0 @@ -from __future__ import annotations - -from dataclasses import dataclass -from datetime import datetime -from typing import Any, ClassVar, Generic, TypeVar - -from arclet.letoderea import Contexts, Param, Provider -from satori import ArgvInteraction, ButtonInteraction, Channel -from satori import Event as SatoriEvent -from satori import EventType, Guild, Member, Role, User -from satori.client import Account -from satori.element import At, Author, Quote, Text, select -from satori.model import Login, MessageObject -from tarina import gen_subclass - -from ..message import MessageChain -from .base import BasedEvent - -T = TypeVar("T") -D = TypeVar("D") - - -@dataclass -class Reply: - quote: Quote - origin: MessageObject - - -def _is_reply_me(reply: Reply, account: Account): - if reply.origin.user: - return reply.origin.user.id == account.self_id - if authors := select(reply.quote, Author): - return any(author.id == account.self_id for author in authors) - return False - - -def _is_notice_me(message: MessageChain, account: Account): - if message and isinstance(message[0], At): - at: At = message[0] # type: ignore - if at.id and at.id == account.self_id: - return True - return False - - -def _remove_notice_me(message: MessageChain, account: Account): - message = message.copy() - message.pop(0) - if _is_notice_me(message, account): - message.pop(0) - if message and isinstance(message[0], Text): - text = message[0].text.lstrip() # type: ignore - if not text: - message.pop(0) - else: - message[0] = Text(text) - return message - - -class Attr(Generic[T]): - def __init__(self, key: str | None = None): - self.key = key - - def __set_name__(self, owner: type[Event], name: str): - self.key = self.key or name - if name not in ("id", "timestamp"): - owner._attrs.add(name) - - def __get__(self, instance: Event, owner: type[Event]) -> T: - return getattr(instance._origin, self.key, None) # type: ignore - - def __set__(self, instance: Event, value): - raise AttributeError("can't set attribute") - - -def attr(key: str | None = None) -> Any: - return Attr(key) - - -class Event(BasedEvent): - type: ClassVar[EventType] - _attrs: ClassVar[set[str]] = set() - _origin: SatoriEvent - account: Account - - sn: int = attr() - timestamp: datetime = attr() - login: Login = attr() - argv: ArgvInteraction | None = attr() - button: ButtonInteraction | None = attr() - channel: Channel | None = attr() - guild: Guild | None = attr() - member: Member | None = attr() - message: MessageObject | None = attr() - operator: User | None = attr() - role: Role | None = attr() - user: User | None = attr() - - def __init__(self, account: Account, origin: SatoriEvent): - self.account = account - self._origin = origin - - async def gather(self, context: Contexts): - context["account"] = self.account - context["$origin_event"] = self._origin - - for name in self.__class__._attrs: - value = getattr(self, name) - if value is not None: - context["$message_origin" if name == "message" else name] = value - - class TimeProvider(Provider[datetime]): - async def __call__(self, context: Contexts): - if "$event" in context: - return context["$event"].timestamp - - class OperatorProvider(Provider[User]): - priority = 10 - - def validate(self, param: Param): - return param.name == "operator" and super().validate(param) - - async def __call__(self, context: Contexts): - if "operator" in context: - return context["operator"] - if "$origin_event" not in context: - return - return context["$origin_event"].operator - - class UserProvider(Provider[User]): - async def __call__(self, context: Contexts): - if "user" in context: - return context["user"] - if "$origin_event" not in context: - return - return context["$origin_event"].user - - class MessageProvider(Provider[MessageObject]): - async def __call__(self, context: Contexts): - if "$message_origin" in context: - return context["$message_origin"] - if "$origin_event" not in context: - return - return context["$origin_event"].message - - class ChannelProvider(Provider[Channel]): - async def __call__(self, context: Contexts): - if "channel" in context: - return context["channel"] - if "$origin_event" not in context: - return - return context["$origin_event"].channel - - class GuildProvider(Provider[Guild]): - async def __call__(self, context: Contexts): - if "guild" in context: - return context["guild"] - if "$origin_event" not in context: - return - return context["$origin_event"].guild - - class MemberProvider(Provider[Member]): - async def __call__(self, context: Contexts): - if "member" in context: - return context["member"] - if "$origin_event" not in context: - return - return context["$origin_event"].member - - class RoleProvider(Provider[Role]): - async def __call__(self, context: Contexts): - if "role" in context: - return context["role"] - if "$origin_event" not in context: - return - return context["$origin_event"].role - - class LoginProvider(Provider[Login]): - async def __call__(self, context: Contexts): - if "login" in context: - return context["login"] - if "$origin_event" not in context: - return - return context["$origin_event"].login - - def __repr__(self): - return f"<{self.__class__.__name__[:-5]}{self._origin!r}>" - - -class NoticeEvent(Event): - pass - - -class FriendEvent(NoticeEvent): - user: User = attr() - - -class FriendRequestEvent(FriendEvent): - type = EventType.FRIEND_REQUEST - - message: MessageObject = attr() - - -class GuildEvent(NoticeEvent): - guild: Guild = attr() - - async def gather(self, context: Contexts): - await super().gather(context) - context["guild"] = self.guild - - -class GuildAddedEvent(GuildEvent): - type = EventType.GUILD_ADDED - - -class GuildRemovedEvent(GuildEvent): - type = EventType.GUILD_REMOVED - - -class GuildRequestEvent(GuildEvent): - type = EventType.GUILD_REQUEST - - message: MessageObject = attr() - - -class GuildUpdatedEvent(GuildEvent): - type = EventType.GUILD_UPDATED - - -class GuildMemberEvent(GuildEvent): - user: User = attr() - - -class GuildMemberAddedEvent(GuildMemberEvent): - type = EventType.GUILD_MEMBER_ADDED - - -class GuildMemberRemovedEvent(GuildMemberEvent): - type = EventType.GUILD_MEMBER_REMOVED - - -class GuildMemberRequestEvent(GuildMemberEvent): - type = EventType.GUILD_MEMBER_REQUEST - - message: MessageObject = attr() - - -class GuildMemberUpdatedEvent(GuildMemberEvent): - type = EventType.GUILD_MEMBER_UPDATED - - -class GuildRoleEvent(GuildEvent): - role: Role = attr() - - -class GuildRoleCreatedEvent(GuildRoleEvent): - type = EventType.GUILD_ROLE_CREATED - - -class GuildRoleDeletedEvent(GuildRoleEvent): - type = EventType.GUILD_ROLE_DELETED - - -class GuildRoleUpdatedEvent(GuildRoleEvent): - type = EventType.GUILD_ROLE_UPDATED - - -class LoginEvent(NoticeEvent): - pass - - -class LoginAddedEvent(LoginEvent): - type = EventType.LOGIN_ADDED - - -class LoginRemovedEvent(LoginEvent): - type = EventType.LOGIN_REMOVED - - -class LoginUpdatedEvent(LoginEvent): - type = EventType.LOGIN_UPDATED - - -class MessageContentProvider(Provider[MessageChain]): - priority = 30 - - async def __call__(self, context: Contexts): - return context.get("$message_content") - - -class ReplyProvider(Provider[Reply]): - async def __call__(self, context: Contexts): - return context.get("$message_reply") - - -class MessageEvent(Event): - channel: Channel = attr() - user: User = attr() - message: MessageObject = attr() - - content: MessageChain - quote: Quote | None = None - - providers = [MessageContentProvider, ReplyProvider] - - def __init__(self, account: Account, origin: SatoriEvent): - super().__init__(account, origin) - self.content = MessageChain(self.message.message) - if self.content.has(Quote): - self.quote = self.content.get(Quote, 1)[0] - self.content = self.content.exclude(Quote) - - async def gather(self, context: Contexts): - await super().gather(context) - reply = None - if self.quote and self.quote.id: - mo = await self.account.protocol.message_get(self.channel.id, self.quote.id) - reply = context["$message_reply"] = Reply(self.quote, mo) - if not reply: - is_reply_me = False - else: - is_reply_me = _is_reply_me(reply, self.account) - context["is_reply_me"] = is_reply_me - if is_reply_me and self.content and isinstance(self.content[0], Text): - text = self.content[0].text.lstrip() - if not text: - self.content.pop(0) - else: - self.content[0] = Text(text) - is_notice_me = context["is_notice_me"] = _is_notice_me(self.content, self.account) - if is_notice_me: - self.content = _remove_notice_me(self.content, self.account) - context["$message_content"] = self.content - - -class MessageCreatedEvent(MessageEvent): - type = EventType.MESSAGE_CREATED - - -class MessageDeletedEvent(MessageEvent): - type = EventType.MESSAGE_DELETED - - -class MessageUpdatedEvent(MessageEvent): - type = EventType.MESSAGE_UPDATED - - -class ReactionEvent(NoticeEvent, MessageEvent): - pass - - -class ReactionAddedEvent(ReactionEvent): - type = EventType.REACTION_ADDED - - -class ReactionRemovedEvent(ReactionEvent): - type = EventType.REACTION_REMOVED - - -class InternalEvent(Event): - type = EventType.INTERNAL - - -class InteractionEvent(NoticeEvent): - pass - - -class InteractionButtonEvent(InteractionEvent): - type = EventType.INTERACTION_BUTTON - - button: ButtonInteraction = attr() - - class ButtonProvider(Provider[ButtonInteraction]): - async def __call__(self, context: Contexts): - return context.get("button") - - -class InteractionCommandEvent(InteractionEvent): - type = EventType.INTERACTION_COMMAND - - -class InteractionCommandArgvEvent(InteractionCommandEvent): - argv: ArgvInteraction = attr() - - class ArgvProvider(Provider[ArgvInteraction]): - async def __call__(self, context: Contexts): - return context.get("argv") - - -class InteractionCommandMessageEvent(InteractionCommandEvent, MessageEvent): - pass - - -MAPPING: dict[str, type[Event]] = {} - -for cls in gen_subclass(Event): - if hasattr(cls, "type"): - MAPPING[cls.type.value] = cls - - -def event_parse(account: Account, event: SatoriEvent): - try: - return MAPPING[event.type](account, event) - except KeyError: - raise NotImplementedError from None diff --git a/arclet/entari/event/send.py b/arclet/entari/event/send.py index c3a1ef7..a411b40 100644 --- a/arclet/entari/event/send.py +++ b/arclet/entari/event/send.py @@ -6,14 +6,13 @@ from satori.model import MessageReceipt from ..message import MessageChain -from .base import BasedEvent if TYPE_CHECKING: from ..session import Session @dataclass -class SendRequest(BasedEvent): +class SendRequest: account: Account channel: str message: MessageChain @@ -35,7 +34,7 @@ async def gather(self, context: Contexts): @dataclass -class SendResponse(BasedEvent): +class SendResponse: account: Account channel: str message: MessageChain diff --git a/arclet/entari/filter/common.py b/arclet/entari/filter/common.py index 5b0e5b0..76f96f5 100644 --- a/arclet/entari/filter/common.py +++ b/arclet/entari/filter/common.py @@ -9,6 +9,7 @@ from satori.client import Account from tarina import is_async +from ..event.base import SatoriEvent from ..session import Session from .message import DirectMessageJudger, NoticeMeJudger, PublicMessageJudger, ReplyMeJudger, ToMeJudger from .op import ExcludeFilter, IntersectFilter, UnionFilter @@ -140,6 +141,8 @@ def __init__(self, callback: Optional[_SessionFilter] = None, priority: int = 10 self.callback = None async def __call__(self, scope: Scope, interface: Interface): + if not isinstance(interface.event, SatoriEvent): # we only care about event from satori + return True for step in sorted(self.steps, key=lambda x: x.priority): if not await step(scope, interface): return False diff --git a/arclet/entari/plugin/__init__.py b/arclet/entari/plugin/__init__.py index 0ebca20..793b8e0 100644 --- a/arclet/entari/plugin/__init__.py +++ b/arclet/entari/plugin/__init__.py @@ -2,35 +2,32 @@ from os import PathLike from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable +from typing import Any, Callable from tarina import init_spec from ..config import EntariConfig from ..logger import log -from .model import Plugin from .model import PluginMetadata as PluginMetadata from .model import RegisterNotInPluginError from .model import RootlessPlugin as RootlessPlugin from .model import StaticPluginDispatchError, _current_plugin +from .model import TE, Plugin from .model import keeping as keeping from .module import import_plugin from .module import package as package from .module import requires as requires from .service import plugin_service -if TYPE_CHECKING: - from ..event.base import BasedEvent - -def dispatch(*events: type[BasedEvent], predicate: Callable[[BasedEvent], bool] | None = None, name: str | None = None): +def dispatch(*events: type[TE], predicate: Callable[[TE], bool] | None = None, name: str | None = None): if not (plugin := _current_plugin.get(None)): raise LookupError("no plugin context found") return plugin.dispatch(*events, predicate=predicate, name=name) def load_plugin( - path: str, config: dict | None = None, recursive_guard: set[str] | None = None, static: bool = False + path: str, config: dict | None = None, recursive_guard: set[str] | None = None, prelude: bool = False ) -> Plugin | None: """ 以导入路径方式加载模块 @@ -39,7 +36,7 @@ def load_plugin( path (str): 模块路径 config (dict): 模块配置 recursive_guard (set[str]): 递归保护 - static (bool): 是否为静态插件 + prelude (bool): 是否为前置插件 """ if config is None: _config = EntariConfig.instance.plugin.get(path) @@ -58,7 +55,9 @@ def load_plugin( return plugin_service._apply[path](config or {}) try: conf = config or {} - if static: + if "$static" in conf: + del conf["$static"] + if prelude: conf["$static"] = True mod = import_plugin(path, config=conf) if not mod: @@ -119,11 +118,19 @@ def metadata(data: PluginMetadata): def plugin_config() -> dict[str, Any]: + """获取当前插件的配置""" if not (plugin := _current_plugin.get(None)): raise LookupError("no plugin context found") return plugin.config +def declare_static(): + """声明当前插件为静态插件""" + if not (plugin := _current_plugin.get(None)): + raise LookupError("no plugin context found") + plugin.is_static = True + + def find_plugin(name: str) -> Plugin | None: return plugin_service.plugins.get(name) diff --git a/arclet/entari/plugin/model.py b/arclet/entari/plugin/model.py index cfa007b..aab6cf3 100644 --- a/arclet/entari/plugin/model.py +++ b/arclet/entari/plugin/model.py @@ -20,12 +20,10 @@ from ..logger import log from .service import plugin_service -if TYPE_CHECKING: - from ..event.base import BasedEvent - _current_plugin: ContextModel[Plugin] = ContextModel("_current_plugin") T = TypeVar("T") +TE = TypeVar("TE") R = TypeVar("R") @@ -41,8 +39,8 @@ class PluginDispatcher: def __init__( self, plugin: Plugin, - *events: type[BasedEvent], - predicate: Callable[[BasedEvent], bool] | None = None, + *events: type[TE], + predicate: Callable[[TE], bool] | None = None, name: str | None = None, ): if len(events) == 1: @@ -61,7 +59,7 @@ def __init__( def waiter( self, - *events: type[BasedEvent], + *events: Any, providers: Sequence[Provider | type[Provider]] | None = None, auxiliaries: list[BaseAuxiliary] | None = None, priority: int = 15, @@ -237,9 +235,7 @@ def dispose(self): del plugin_service.plugins[self.id] del self.module - def dispatch( - self, *events: type[BasedEvent], predicate: Callable[[BasedEvent], bool] | None = None, name: str | None = None - ): + def dispatch(self, *events: type[TE], predicate: Callable[[TE], bool] | None = None, name: str | None = None): if self.is_static: raise StaticPluginDispatchError("static plugin cannot dispatch events") disp = PluginDispatcher(self, *events, predicate=predicate, name=name) diff --git a/arclet/entari/scheduler.py b/arclet/entari/scheduler.py index 79c4945..cf5856e 100644 --- a/arclet/entari/scheduler.py +++ b/arclet/entari/scheduler.py @@ -1,4 +1,5 @@ import asyncio +from asyncio.events import _get_running_loop # type: ignore from datetime import datetime, timedelta from traceback import print_exc from typing import Callable, Literal @@ -9,7 +10,7 @@ from launart import Launart, Service, any_completed from launart.status import Phase -from .plugin import RootlessPlugin +from .plugin import RootlessPlugin, _current_plugin class _ScheduleEvent: @@ -28,7 +29,7 @@ def __init__(self, supplier: Callable[[], timedelta], sub_id: str): self.handle = None def start(self, queue: asyncio.Queue): - loop = asyncio.get_event_loop() + loop = asyncio.get_running_loop() self.handle = loop.call_later(self.supplier().total_seconds(), queue.put_nowait, self) def cancel(self): @@ -63,11 +64,16 @@ async def fetch(self): except Exception: print_exc() - def schedule(self, timer: Callable[[], timedelta]): + def schedule(self, timer: Callable[[], timedelta], once: bool = False): def wrapper(func: Callable): - sub = pub.register(func) + if plugin := _current_plugin.get(): + sub = plugin.dispatch(_ScheduleEvent).register(func, temporary=once) + else: + sub = pub.register(func, temporary=once) self.timers[sub.id] = TimerTask(timer, sub.id) + if _get_running_loop(): + self.timers[sub.id].start(self.queue) return sub return wrapper @@ -205,3 +211,8 @@ def every( "hour": every_hours, } return service.schedule(_TIMER_MAPPING[mode](value)) + + +def invoke(delay: float): + """延迟执行""" + return service.schedule(lambda: timedelta(seconds=delay), once=True) diff --git a/arclet/entari/session.py b/arclet/entari/session.py index a954635..4ed36e0 100644 --- a/arclet/entari/session.py +++ b/arclet/entari/session.py @@ -10,24 +10,24 @@ from satori.element import Element from satori.model import Channel, Guild, Member, MessageReceipt, PageResult, Role, User -from .event.protocol import Event, FriendRequestEvent, GuildMemberRequestEvent, GuildRequestEvent, MessageEvent, Reply +from .event.base import FriendRequestEvent, GuildMemberRequestEvent, GuildRequestEvent, MessageEvent, Reply, SatoriEvent from .event.send import SendRequest, SendResponse from .message import MessageChain -TEvent = TypeVar("TEvent", bound=Event) +TEvent = TypeVar("TEvent", bound=SatoriEvent) class EntariProtocol(ApiProtocol): async def send_message( - self, channel: str | Channel, message: str | Iterable[str | Element], source: Event | None = None + self, channel: str | Channel, message: str | Iterable[str | Element], source: SatoriEvent | None = None ) -> list[MessageReceipt]: """发送消息。返回一个 `MessageReceipt` 对象构成的数组。 Args: channel (str | Channel): 要发送的频道 ID message (str | Iterable[str | Element]): 要发送的消息 - source (Event | None): 源事件 + source (SatoriEvent | None): 源事件 Returns: list[MessageReceipt]: `MessageReceipt` 对象构成的数组 @@ -36,14 +36,14 @@ async def send_message( return await self.message_create(channel_id=channel_id, content=message, source=source) async def send_private_message( - self, user: str | User, message: str | Iterable[str | Element], source: Event | None = None + self, user: str | User, message: str | Iterable[str | Element], source: SatoriEvent | None = None ) -> list[MessageReceipt]: """发送私聊消息。返回一个 `MessageReceipt` 对象构成的数组。 Args: user (str | User): 要发送的用户 ID message (str | Iterable[str | Element]): 要发送的消息 - source (Event | None): 源事件 + source (SatoriEvent | None): 源事件 Returns: list[MessageReceipt]: `MessageReceipt` 对象构成的数组 @@ -53,14 +53,14 @@ async def send_private_message( return await self.message_create(channel_id=channel.id, content=message, source=source) async def message_create( - self, channel_id: str, content: str | Iterable[str | Element], source: Event | None = None + self, channel_id: str, content: str | Iterable[str | Element], source: SatoriEvent | None = None ) -> list[MessageReceipt]: """发送消息。返回一个 `MessageReceipt` 对象构成的数组。 Args: channel_id (str): 频道 ID content (str | Iterable[str | Element]): 消息内容 - source (Event | None): 源事件 + source (SatoriEvent | None): 源事件 Returns: list[MessageReceipt]: `MessageReceipt` 对象构成的数组 @@ -242,14 +242,12 @@ async def message_create( raise RuntimeError("Event cannot be replied to!") return await self.account.protocol.send_message(self.context.channel.id, content, self.context) - async def message_delete(self) -> None: + async def message_delete(self, message_id: str) -> None: if not self.context.channel: raise RuntimeError("Event cannot be replied to!") - if not self.context.message: - raise RuntimeError("Event cannot update message") await self.account.protocol.message_delete( self.context.channel.id, - self.context.message.id, + message_id, ) async def message_update( diff --git a/example_plugin.py b/example_plugin.py index edc3654..a247a24 100644 --- a/example_plugin.py +++ b/example_plugin.py @@ -10,7 +10,7 @@ metadata, keeping, scheduler, - Entari, + # Entari, ) from arclet.entari.filter import Interval @@ -29,7 +29,7 @@ async def cleanup(): print("example: Cleanup") -disp_message = MessageCreatedEvent.dispatch() +disp_message = plug.dispatch(MessageCreatedEvent) @disp_message @@ -38,17 +38,20 @@ async def _(msg: MessageChain, session: Session): content = msg.extract_plain_text() if re.match(r"(.{0,3})(上传|设定)(.{0,3})(上传|设定)(.{0,3})", content): return await session.send("上传设定的帮助是...") + if content == "test": + resp = await session.send("This message will recall in 5s...") + @scheduler.invoke(5) + async def _(): + await session.message_delete(resp[0].id) -disp_message1 = plug.dispatch(MessageCreatedEvent) - -@disp_message1.on(auxiliaries=[Filter().public().to_me().and_(lambda sess: str(sess.content) == "aaa")]) +@disp_message.on(auxiliaries=[Filter().public().to_me().and_(lambda sess: str(sess.content) == "aaa")]) async def _(session: Session): return await session.send("Filter: public message, to me, and content is 'aaa'") -@disp_message1.on(auxiliaries=[Filter().public().to_me().not_(lambda sess: str(sess.content) == "aaa")]) +@disp_message.on(auxiliaries=[Filter().public().to_me().not_(lambda sess: str(sess.content) == "aaa")]) async def _(session: Session): return await session.send("Filter: public message, to me, but content is not 'aaa'") @@ -84,10 +87,9 @@ async def show(session: Session): async def send_hook(message: MessageChain): return message + "喵" - -@scheduler.cron("* * * * *") -async def broadcast(app: Entari): - for account in app.accounts.values(): - channels = [channel for guild in (await account.guild_list()).data for channel in (await account.channel_list(guild.id)).data] - for channel in channels: - await account.send_message(channel, "Hello, World!") +# @scheduler.cron("* * * * *") +# async def broadcast(app: Entari): +# for account in app.accounts.values(): +# channels = [channel for guild in (await account.guild_list()).data for channel in (await account.channel_list(guild.id)).data] +# for channel in channels: +# await account.send_message(channel, "Hello, World!")