diff --git a/README.md b/README.md
index 0e719c7..867a007 100644
--- a/README.md
+++ b/README.md
@@ -85,7 +85,7 @@ app.run()
编写插件:
```python
-from arclet.entari import Session, MessageCreatedEvent, metadata
+from arclet.entari import Session, MessageCreatedEvent, metadata, plugin
metadata(
name="Hello, World!",
@@ -95,7 +95,7 @@ metadata(
)
# or __plugin_metadata__ = PluginMetadata(...)
-@MessageCreatedEvent.dispatch()
+@plugin.dispatch(MessageCreatedEvent)
async def _(session: Session):
await session.send("Hello, World!")
```
diff --git a/arclet/entari/__init__.py b/arclet/entari/__init__.py
index 1d580fc..a04a9eb 100644
--- a/arclet/entari/__init__.py
+++ b/arclet/entari/__init__.py
@@ -49,6 +49,7 @@
from .message import MessageChain as MessageChain
from .plugin import Plugin as Plugin
from .plugin import PluginMetadata as PluginMetadata
+from .plugin import declare_static as declare_static
from .plugin import dispose as dispose_plugin # noqa: F401
from .plugin import keeping as keeping
from .plugin import load_plugin as load_plugin
diff --git a/arclet/entari/builtins/auto_reload.py b/arclet/entari/builtins/auto_reload.py
index 64eb8bc..05f70b7 100644
--- a/arclet/entari/builtins/auto_reload.py
+++ b/arclet/entari/builtins/auto_reload.py
@@ -7,12 +7,14 @@
from launart.status import Phase
from watchfiles import PythonFilter, awatch
-from arclet.entari import Plugin, dispose_plugin, load_plugin, metadata
+from arclet.entari import Plugin, declare_static, dispose_plugin, load_plugin, metadata
from arclet.entari.config import EntariConfig
from arclet.entari.event.config import ConfigReload
from arclet.entari.logger import log
from arclet.entari.plugin import find_plugin, find_plugin_by_file
+declare_static()
+
class Config:
watch_dirs: list[str] = ["."]
@@ -66,7 +68,7 @@ async def watch(self):
async for event in awatch(*self.dirs, watch_filter=PythonFilter()):
for change in event:
if plugin := find_plugin_by_file(change[1]):
- if plugin.id == "arclet.entari.builtins.auto_reload" or plugin.is_static:
+ if plugin.is_static:
logger("INFO", f"Plugin {plugin.id!r} is static, ignored.")
continue
logger("INFO", f"Detected change in {plugin.id!r}, reloading...")
diff --git a/arclet/entari/command/__init__.py b/arclet/entari/command/__init__.py
index 61e4852..c687e69 100644
--- a/arclet/entari/command/__init__.py
+++ b/arclet/entari/command/__init__.py
@@ -14,9 +14,9 @@
from tarina.string import split
from tarina.trie import CharTrie
+from ..event.base import MessageCreatedEvent
from ..event.command import CommandExecute
from ..event.config import ConfigReload
-from ..event.protocol import MessageCreatedEvent
from ..message import MessageChain
from ..plugin import RootlessPlugin
from ..session import Session
diff --git a/arclet/entari/core.py b/arclet/entari/core.py
index 57cef3e..91062b8 100644
--- a/arclet/entari/core.py
+++ b/arclet/entari/core.py
@@ -15,9 +15,9 @@
from tarina.generic import get_origin
from .config import EntariConfig
+from .event.base import MessageCreatedEvent, event_parse
from .event.config import ConfigReload
from .event.lifespan import AccountUpdate
-from .event.protocol import MessageCreatedEvent, event_parse
from .event.send import SendResponse
from .logger import log
from .plugin import load_plugin, plugin_config, requires
@@ -130,7 +130,7 @@ def __init__(
log.core.opt(colors=True).debug(f"Log level set to {log_level}")
requires(*EntariConfig.instance.prelude_plugin)
for plug in EntariConfig.instance.prelude_plugin:
- load_plugin(plug, static=True)
+ load_plugin(plug, prelude=True)
requires(*EntariConfig.instance.plugin)
for plug in EntariConfig.instance.plugin:
load_plugin(plug)
diff --git a/arclet/entari/event/__init__.py b/arclet/entari/event/__init__.py
index 7ff6021..a462a0d 100644
--- a/arclet/entari/event/__init__.py
+++ b/arclet/entari/event/__init__.py
@@ -1,2 +1,2 @@
-from .protocol import MessageCreatedEvent as MessageCreatedEvent
-from .protocol import MessageEvent as MessageEvent
+from .base import MessageCreatedEvent as MessageCreatedEvent
+from .base import MessageEvent as MessageEvent
diff --git a/arclet/entari/event/base.py b/arclet/entari/event/base.py
index ab06489..4130e0f 100644
--- a/arclet/entari/event/base.py
+++ b/arclet/entari/event/base.py
@@ -1,14 +1,403 @@
from __future__ import annotations
-from typing import Callable, TypeVar
+from dataclasses import dataclass
+from datetime import datetime
+from typing import Any, ClassVar, Generic, TypeVar
-TE = TypeVar("TE", bound="BasedEvent")
+from arclet.letoderea import Contexts, Param, Provider
+from satori import ArgvInteraction, ButtonInteraction, Channel
+from satori import Event as OriginEvent
+from satori import EventType, Guild, Member, Role, User
+from satori.client import Account
+from satori.element import At, Author, Quote, Text, select
+from satori.model import Login, MessageObject
+from tarina import gen_subclass
+from ..message import MessageChain
-class BasedEvent:
- @classmethod
- def dispatch(cls: type[TE], predicate: Callable[[TE], bool] | None = None, name: str | None = None):
- from ..plugin import dispatch
+T = TypeVar("T")
+D = TypeVar("D")
- name = name or getattr(cls, "__publisher__", None)
- return dispatch(cls, predicate=predicate, name=name) # type: ignore
+
+@dataclass
+class Reply:
+ quote: Quote
+ origin: MessageObject
+
+
+def _is_reply_me(reply: Reply, account: Account):
+ if reply.origin.user:
+ return reply.origin.user.id == account.self_id
+ if authors := select(reply.quote, Author):
+ return any(author.id == account.self_id for author in authors)
+ return False
+
+
+def _is_notice_me(message: MessageChain, account: Account):
+ if message and isinstance(message[0], At):
+ at: At = message[0] # type: ignore
+ if at.id and at.id == account.self_id:
+ return True
+ return False
+
+
+def _remove_notice_me(message: MessageChain, account: Account):
+ message = message.copy()
+ message.pop(0)
+ if _is_notice_me(message, account):
+ message.pop(0)
+ if message and isinstance(message[0], Text):
+ text = message[0].text.lstrip() # type: ignore
+ if not text:
+ message.pop(0)
+ else:
+ message[0] = Text(text)
+ return message
+
+
+class Attr(Generic[T]):
+ def __init__(self, key: str | None = None):
+ self.key = key
+
+ def __set_name__(self, owner: type[SatoriEvent], name: str):
+ self.key = self.key or name
+ if name not in ("id", "timestamp"):
+ owner._attrs.add(name)
+
+ def __get__(self, instance: SatoriEvent, owner: type[SatoriEvent]) -> T:
+ return getattr(instance._origin, self.key, None) # type: ignore
+
+ def __set__(self, instance: SatoriEvent, value):
+ raise AttributeError("can't set attribute")
+
+
+def attr(key: str | None = None) -> Any:
+ return Attr(key)
+
+
+class SatoriEvent:
+ type: ClassVar[EventType]
+ _attrs: ClassVar[set[str]] = set()
+ _origin: OriginEvent
+ account: Account
+
+ sn: int = attr()
+ timestamp: datetime = attr()
+ login: Login = attr()
+ argv: ArgvInteraction | None = attr()
+ button: ButtonInteraction | None = attr()
+ channel: Channel | None = attr()
+ guild: Guild | None = attr()
+ member: Member | None = attr()
+ message: MessageObject | None = attr()
+ operator: User | None = attr()
+ role: Role | None = attr()
+ user: User | None = attr()
+
+ def __init__(self, account: Account, origin: OriginEvent):
+ self.account = account
+ self._origin = origin
+
+ async def gather(self, context: Contexts):
+ context["account"] = self.account
+ context["$origin_event"] = self._origin
+
+ for name in self.__class__._attrs:
+ value = getattr(self, name)
+ if value is not None:
+ context["$message_origin" if name == "message" else name] = value
+
+ class TimeProvider(Provider[datetime]):
+ async def __call__(self, context: Contexts):
+ if "$event" in context:
+ return context["$event"].timestamp
+
+ class OperatorProvider(Provider[User]):
+ priority = 10
+
+ def validate(self, param: Param):
+ return param.name == "operator" and super().validate(param)
+
+ async def __call__(self, context: Contexts):
+ if "operator" in context:
+ return context["operator"]
+ if "$origin_event" not in context:
+ return
+ return context["$origin_event"].operator
+
+ class UserProvider(Provider[User]):
+ async def __call__(self, context: Contexts):
+ if "user" in context:
+ return context["user"]
+ if "$origin_event" not in context:
+ return
+ return context["$origin_event"].user
+
+ class MessageProvider(Provider[MessageObject]):
+ async def __call__(self, context: Contexts):
+ if "$message_origin" in context:
+ return context["$message_origin"]
+ if "$origin_event" not in context:
+ return
+ return context["$origin_event"].message
+
+ class ChannelProvider(Provider[Channel]):
+ async def __call__(self, context: Contexts):
+ if "channel" in context:
+ return context["channel"]
+ if "$origin_event" not in context:
+ return
+ return context["$origin_event"].channel
+
+ class GuildProvider(Provider[Guild]):
+ async def __call__(self, context: Contexts):
+ if "guild" in context:
+ return context["guild"]
+ if "$origin_event" not in context:
+ return
+ return context["$origin_event"].guild
+
+ class MemberProvider(Provider[Member]):
+ async def __call__(self, context: Contexts):
+ if "member" in context:
+ return context["member"]
+ if "$origin_event" not in context:
+ return
+ return context["$origin_event"].member
+
+ class RoleProvider(Provider[Role]):
+ async def __call__(self, context: Contexts):
+ if "role" in context:
+ return context["role"]
+ if "$origin_event" not in context:
+ return
+ return context["$origin_event"].role
+
+ class LoginProvider(Provider[Login]):
+ async def __call__(self, context: Contexts):
+ if "login" in context:
+ return context["login"]
+ if "$origin_event" not in context:
+ return
+ return context["$origin_event"].login
+
+ def __repr__(self):
+ return f"<{self.__class__.__name__[:-5]}{self._origin!r}>"
+
+
+class NoticeEvent(SatoriEvent):
+ pass
+
+
+class FriendEvent(NoticeEvent):
+ user: User = attr()
+
+
+class FriendRequestEvent(FriendEvent):
+ type = EventType.FRIEND_REQUEST
+
+ message: MessageObject = attr()
+
+
+class GuildEvent(NoticeEvent):
+ guild: Guild = attr()
+
+ async def gather(self, context: Contexts):
+ await super().gather(context)
+ context["guild"] = self.guild
+
+
+class GuildAddedEvent(GuildEvent):
+ type = EventType.GUILD_ADDED
+
+
+class GuildRemovedEvent(GuildEvent):
+ type = EventType.GUILD_REMOVED
+
+
+class GuildRequestEvent(GuildEvent):
+ type = EventType.GUILD_REQUEST
+
+ message: MessageObject = attr()
+
+
+class GuildUpdatedEvent(GuildEvent):
+ type = EventType.GUILD_UPDATED
+
+
+class GuildMemberEvent(GuildEvent):
+ user: User = attr()
+
+
+class GuildMemberAddedEvent(GuildMemberEvent):
+ type = EventType.GUILD_MEMBER_ADDED
+
+
+class GuildMemberRemovedEvent(GuildMemberEvent):
+ type = EventType.GUILD_MEMBER_REMOVED
+
+
+class GuildMemberRequestEvent(GuildMemberEvent):
+ type = EventType.GUILD_MEMBER_REQUEST
+
+ message: MessageObject = attr()
+
+
+class GuildMemberUpdatedEvent(GuildMemberEvent):
+ type = EventType.GUILD_MEMBER_UPDATED
+
+
+class GuildRoleEvent(GuildEvent):
+ role: Role = attr()
+
+
+class GuildRoleCreatedEvent(GuildRoleEvent):
+ type = EventType.GUILD_ROLE_CREATED
+
+
+class GuildRoleDeletedEvent(GuildRoleEvent):
+ type = EventType.GUILD_ROLE_DELETED
+
+
+class GuildRoleUpdatedEvent(GuildRoleEvent):
+ type = EventType.GUILD_ROLE_UPDATED
+
+
+class LoginEvent(NoticeEvent):
+ pass
+
+
+class LoginAddedEvent(LoginEvent):
+ type = EventType.LOGIN_ADDED
+
+
+class LoginRemovedEvent(LoginEvent):
+ type = EventType.LOGIN_REMOVED
+
+
+class LoginUpdatedEvent(LoginEvent):
+ type = EventType.LOGIN_UPDATED
+
+
+class MessageContentProvider(Provider[MessageChain]):
+ priority = 30
+
+ async def __call__(self, context: Contexts):
+ return context.get("$message_content")
+
+
+class ReplyProvider(Provider[Reply]):
+ async def __call__(self, context: Contexts):
+ return context.get("$message_reply")
+
+
+class MessageEvent(SatoriEvent):
+ channel: Channel = attr()
+ user: User = attr()
+ message: MessageObject = attr()
+
+ content: MessageChain
+ quote: Quote | None = None
+
+ providers = [MessageContentProvider, ReplyProvider]
+
+ def __init__(self, account: Account, origin: OriginEvent):
+ super().__init__(account, origin)
+ self.content = MessageChain(self.message.message)
+ if self.content.has(Quote):
+ self.quote = self.content.get(Quote, 1)[0]
+ self.content = self.content.exclude(Quote)
+
+ async def gather(self, context: Contexts):
+ await super().gather(context)
+ reply = None
+ if self.quote and self.quote.id:
+ mo = await self.account.protocol.message_get(self.channel.id, self.quote.id)
+ reply = context["$message_reply"] = Reply(self.quote, mo)
+ if not reply:
+ is_reply_me = False
+ else:
+ is_reply_me = _is_reply_me(reply, self.account)
+ context["is_reply_me"] = is_reply_me
+ if is_reply_me and self.content and isinstance(self.content[0], Text):
+ text = self.content[0].text.lstrip()
+ if not text:
+ self.content.pop(0)
+ else:
+ self.content[0] = Text(text)
+ is_notice_me = context["is_notice_me"] = _is_notice_me(self.content, self.account)
+ if is_notice_me:
+ self.content = _remove_notice_me(self.content, self.account)
+ context["$message_content"] = self.content
+
+
+class MessageCreatedEvent(MessageEvent):
+ type = EventType.MESSAGE_CREATED
+
+
+class MessageDeletedEvent(MessageEvent):
+ type = EventType.MESSAGE_DELETED
+
+
+class MessageUpdatedEvent(MessageEvent):
+ type = EventType.MESSAGE_UPDATED
+
+
+class ReactionEvent(NoticeEvent, MessageEvent):
+ pass
+
+
+class ReactionAddedEvent(ReactionEvent):
+ type = EventType.REACTION_ADDED
+
+
+class ReactionRemovedEvent(ReactionEvent):
+ type = EventType.REACTION_REMOVED
+
+
+class InternalEvent(SatoriEvent):
+ type = EventType.INTERNAL
+
+
+class InteractionEvent(NoticeEvent):
+ pass
+
+
+class InteractionButtonEvent(InteractionEvent):
+ type = EventType.INTERACTION_BUTTON
+
+ button: ButtonInteraction = attr()
+
+ class ButtonProvider(Provider[ButtonInteraction]):
+ async def __call__(self, context: Contexts):
+ return context.get("button")
+
+
+class InteractionCommandEvent(InteractionEvent):
+ type = EventType.INTERACTION_COMMAND
+
+
+class InteractionCommandArgvEvent(InteractionCommandEvent):
+ argv: ArgvInteraction = attr()
+
+ class ArgvProvider(Provider[ArgvInteraction]):
+ async def __call__(self, context: Contexts):
+ return context.get("argv")
+
+
+class InteractionCommandMessageEvent(InteractionCommandEvent, MessageEvent):
+ pass
+
+
+MAPPING: dict[str, type[SatoriEvent]] = {}
+
+for cls in gen_subclass(SatoriEvent):
+ if hasattr(cls, "type"):
+ MAPPING[cls.type.value] = cls
+
+
+def event_parse(account: Account, event: OriginEvent):
+ try:
+ return MAPPING[event.type](account, event)
+ except KeyError:
+ raise NotImplementedError from None
diff --git a/arclet/entari/event/command.py b/arclet/entari/event/command.py
index 3621481..e50aa97 100644
--- a/arclet/entari/event/command.py
+++ b/arclet/entari/event/command.py
@@ -4,11 +4,10 @@
from arclet.letoderea import Contexts, Provider, es
from ..message import MessageChain
-from .base import BasedEvent
@dataclass
-class CommandExecute(BasedEvent):
+class CommandExecute:
command: Union[str, MessageChain]
async def gather(self, context: Contexts):
diff --git a/arclet/entari/event/config.py b/arclet/entari/event/config.py
index df75c9c..1e2a60b 100644
--- a/arclet/entari/event/config.py
+++ b/arclet/entari/event/config.py
@@ -3,11 +3,9 @@
from arclet.letoderea import es
-from .base import BasedEvent
-
@dataclass
-class ConfigReload(BasedEvent):
+class ConfigReload:
scope: str
key: str
value: Any
diff --git a/arclet/entari/event/lifespan.py b/arclet/entari/event/lifespan.py
index 9af356e..4a9860c 100644
--- a/arclet/entari/event/lifespan.py
+++ b/arclet/entari/event/lifespan.py
@@ -4,32 +4,30 @@
from satori.client import Account
from satori.model import LoginStatus
-from .base import BasedEvent
-
@dataclass
-class Startup(BasedEvent):
+class Startup:
pass
__publisher__ = "entari.event/startup"
@dataclass
-class Ready(BasedEvent):
+class Ready:
pass
__publisher__ = "entari.event/ready"
@dataclass
-class Cleanup(BasedEvent):
+class Cleanup:
pass
__publisher__ = "entari.event/cleanup"
@dataclass
-class AccountUpdate(BasedEvent):
+class AccountUpdate:
account: Account
status: LoginStatus
diff --git a/arclet/entari/event/protocol.py b/arclet/entari/event/protocol.py
deleted file mode 100644
index e423947..0000000
--- a/arclet/entari/event/protocol.py
+++ /dev/null
@@ -1,404 +0,0 @@
-from __future__ import annotations
-
-from dataclasses import dataclass
-from datetime import datetime
-from typing import Any, ClassVar, Generic, TypeVar
-
-from arclet.letoderea import Contexts, Param, Provider
-from satori import ArgvInteraction, ButtonInteraction, Channel
-from satori import Event as SatoriEvent
-from satori import EventType, Guild, Member, Role, User
-from satori.client import Account
-from satori.element import At, Author, Quote, Text, select
-from satori.model import Login, MessageObject
-from tarina import gen_subclass
-
-from ..message import MessageChain
-from .base import BasedEvent
-
-T = TypeVar("T")
-D = TypeVar("D")
-
-
-@dataclass
-class Reply:
- quote: Quote
- origin: MessageObject
-
-
-def _is_reply_me(reply: Reply, account: Account):
- if reply.origin.user:
- return reply.origin.user.id == account.self_id
- if authors := select(reply.quote, Author):
- return any(author.id == account.self_id for author in authors)
- return False
-
-
-def _is_notice_me(message: MessageChain, account: Account):
- if message and isinstance(message[0], At):
- at: At = message[0] # type: ignore
- if at.id and at.id == account.self_id:
- return True
- return False
-
-
-def _remove_notice_me(message: MessageChain, account: Account):
- message = message.copy()
- message.pop(0)
- if _is_notice_me(message, account):
- message.pop(0)
- if message and isinstance(message[0], Text):
- text = message[0].text.lstrip() # type: ignore
- if not text:
- message.pop(0)
- else:
- message[0] = Text(text)
- return message
-
-
-class Attr(Generic[T]):
- def __init__(self, key: str | None = None):
- self.key = key
-
- def __set_name__(self, owner: type[Event], name: str):
- self.key = self.key or name
- if name not in ("id", "timestamp"):
- owner._attrs.add(name)
-
- def __get__(self, instance: Event, owner: type[Event]) -> T:
- return getattr(instance._origin, self.key, None) # type: ignore
-
- def __set__(self, instance: Event, value):
- raise AttributeError("can't set attribute")
-
-
-def attr(key: str | None = None) -> Any:
- return Attr(key)
-
-
-class Event(BasedEvent):
- type: ClassVar[EventType]
- _attrs: ClassVar[set[str]] = set()
- _origin: SatoriEvent
- account: Account
-
- sn: int = attr()
- timestamp: datetime = attr()
- login: Login = attr()
- argv: ArgvInteraction | None = attr()
- button: ButtonInteraction | None = attr()
- channel: Channel | None = attr()
- guild: Guild | None = attr()
- member: Member | None = attr()
- message: MessageObject | None = attr()
- operator: User | None = attr()
- role: Role | None = attr()
- user: User | None = attr()
-
- def __init__(self, account: Account, origin: SatoriEvent):
- self.account = account
- self._origin = origin
-
- async def gather(self, context: Contexts):
- context["account"] = self.account
- context["$origin_event"] = self._origin
-
- for name in self.__class__._attrs:
- value = getattr(self, name)
- if value is not None:
- context["$message_origin" if name == "message" else name] = value
-
- class TimeProvider(Provider[datetime]):
- async def __call__(self, context: Contexts):
- if "$event" in context:
- return context["$event"].timestamp
-
- class OperatorProvider(Provider[User]):
- priority = 10
-
- def validate(self, param: Param):
- return param.name == "operator" and super().validate(param)
-
- async def __call__(self, context: Contexts):
- if "operator" in context:
- return context["operator"]
- if "$origin_event" not in context:
- return
- return context["$origin_event"].operator
-
- class UserProvider(Provider[User]):
- async def __call__(self, context: Contexts):
- if "user" in context:
- return context["user"]
- if "$origin_event" not in context:
- return
- return context["$origin_event"].user
-
- class MessageProvider(Provider[MessageObject]):
- async def __call__(self, context: Contexts):
- if "$message_origin" in context:
- return context["$message_origin"]
- if "$origin_event" not in context:
- return
- return context["$origin_event"].message
-
- class ChannelProvider(Provider[Channel]):
- async def __call__(self, context: Contexts):
- if "channel" in context:
- return context["channel"]
- if "$origin_event" not in context:
- return
- return context["$origin_event"].channel
-
- class GuildProvider(Provider[Guild]):
- async def __call__(self, context: Contexts):
- if "guild" in context:
- return context["guild"]
- if "$origin_event" not in context:
- return
- return context["$origin_event"].guild
-
- class MemberProvider(Provider[Member]):
- async def __call__(self, context: Contexts):
- if "member" in context:
- return context["member"]
- if "$origin_event" not in context:
- return
- return context["$origin_event"].member
-
- class RoleProvider(Provider[Role]):
- async def __call__(self, context: Contexts):
- if "role" in context:
- return context["role"]
- if "$origin_event" not in context:
- return
- return context["$origin_event"].role
-
- class LoginProvider(Provider[Login]):
- async def __call__(self, context: Contexts):
- if "login" in context:
- return context["login"]
- if "$origin_event" not in context:
- return
- return context["$origin_event"].login
-
- def __repr__(self):
- return f"<{self.__class__.__name__[:-5]}{self._origin!r}>"
-
-
-class NoticeEvent(Event):
- pass
-
-
-class FriendEvent(NoticeEvent):
- user: User = attr()
-
-
-class FriendRequestEvent(FriendEvent):
- type = EventType.FRIEND_REQUEST
-
- message: MessageObject = attr()
-
-
-class GuildEvent(NoticeEvent):
- guild: Guild = attr()
-
- async def gather(self, context: Contexts):
- await super().gather(context)
- context["guild"] = self.guild
-
-
-class GuildAddedEvent(GuildEvent):
- type = EventType.GUILD_ADDED
-
-
-class GuildRemovedEvent(GuildEvent):
- type = EventType.GUILD_REMOVED
-
-
-class GuildRequestEvent(GuildEvent):
- type = EventType.GUILD_REQUEST
-
- message: MessageObject = attr()
-
-
-class GuildUpdatedEvent(GuildEvent):
- type = EventType.GUILD_UPDATED
-
-
-class GuildMemberEvent(GuildEvent):
- user: User = attr()
-
-
-class GuildMemberAddedEvent(GuildMemberEvent):
- type = EventType.GUILD_MEMBER_ADDED
-
-
-class GuildMemberRemovedEvent(GuildMemberEvent):
- type = EventType.GUILD_MEMBER_REMOVED
-
-
-class GuildMemberRequestEvent(GuildMemberEvent):
- type = EventType.GUILD_MEMBER_REQUEST
-
- message: MessageObject = attr()
-
-
-class GuildMemberUpdatedEvent(GuildMemberEvent):
- type = EventType.GUILD_MEMBER_UPDATED
-
-
-class GuildRoleEvent(GuildEvent):
- role: Role = attr()
-
-
-class GuildRoleCreatedEvent(GuildRoleEvent):
- type = EventType.GUILD_ROLE_CREATED
-
-
-class GuildRoleDeletedEvent(GuildRoleEvent):
- type = EventType.GUILD_ROLE_DELETED
-
-
-class GuildRoleUpdatedEvent(GuildRoleEvent):
- type = EventType.GUILD_ROLE_UPDATED
-
-
-class LoginEvent(NoticeEvent):
- pass
-
-
-class LoginAddedEvent(LoginEvent):
- type = EventType.LOGIN_ADDED
-
-
-class LoginRemovedEvent(LoginEvent):
- type = EventType.LOGIN_REMOVED
-
-
-class LoginUpdatedEvent(LoginEvent):
- type = EventType.LOGIN_UPDATED
-
-
-class MessageContentProvider(Provider[MessageChain]):
- priority = 30
-
- async def __call__(self, context: Contexts):
- return context.get("$message_content")
-
-
-class ReplyProvider(Provider[Reply]):
- async def __call__(self, context: Contexts):
- return context.get("$message_reply")
-
-
-class MessageEvent(Event):
- channel: Channel = attr()
- user: User = attr()
- message: MessageObject = attr()
-
- content: MessageChain
- quote: Quote | None = None
-
- providers = [MessageContentProvider, ReplyProvider]
-
- def __init__(self, account: Account, origin: SatoriEvent):
- super().__init__(account, origin)
- self.content = MessageChain(self.message.message)
- if self.content.has(Quote):
- self.quote = self.content.get(Quote, 1)[0]
- self.content = self.content.exclude(Quote)
-
- async def gather(self, context: Contexts):
- await super().gather(context)
- reply = None
- if self.quote and self.quote.id:
- mo = await self.account.protocol.message_get(self.channel.id, self.quote.id)
- reply = context["$message_reply"] = Reply(self.quote, mo)
- if not reply:
- is_reply_me = False
- else:
- is_reply_me = _is_reply_me(reply, self.account)
- context["is_reply_me"] = is_reply_me
- if is_reply_me and self.content and isinstance(self.content[0], Text):
- text = self.content[0].text.lstrip()
- if not text:
- self.content.pop(0)
- else:
- self.content[0] = Text(text)
- is_notice_me = context["is_notice_me"] = _is_notice_me(self.content, self.account)
- if is_notice_me:
- self.content = _remove_notice_me(self.content, self.account)
- context["$message_content"] = self.content
-
-
-class MessageCreatedEvent(MessageEvent):
- type = EventType.MESSAGE_CREATED
-
-
-class MessageDeletedEvent(MessageEvent):
- type = EventType.MESSAGE_DELETED
-
-
-class MessageUpdatedEvent(MessageEvent):
- type = EventType.MESSAGE_UPDATED
-
-
-class ReactionEvent(NoticeEvent, MessageEvent):
- pass
-
-
-class ReactionAddedEvent(ReactionEvent):
- type = EventType.REACTION_ADDED
-
-
-class ReactionRemovedEvent(ReactionEvent):
- type = EventType.REACTION_REMOVED
-
-
-class InternalEvent(Event):
- type = EventType.INTERNAL
-
-
-class InteractionEvent(NoticeEvent):
- pass
-
-
-class InteractionButtonEvent(InteractionEvent):
- type = EventType.INTERACTION_BUTTON
-
- button: ButtonInteraction = attr()
-
- class ButtonProvider(Provider[ButtonInteraction]):
- async def __call__(self, context: Contexts):
- return context.get("button")
-
-
-class InteractionCommandEvent(InteractionEvent):
- type = EventType.INTERACTION_COMMAND
-
-
-class InteractionCommandArgvEvent(InteractionCommandEvent):
- argv: ArgvInteraction = attr()
-
- class ArgvProvider(Provider[ArgvInteraction]):
- async def __call__(self, context: Contexts):
- return context.get("argv")
-
-
-class InteractionCommandMessageEvent(InteractionCommandEvent, MessageEvent):
- pass
-
-
-MAPPING: dict[str, type[Event]] = {}
-
-for cls in gen_subclass(Event):
- if hasattr(cls, "type"):
- MAPPING[cls.type.value] = cls
-
-
-def event_parse(account: Account, event: SatoriEvent):
- try:
- return MAPPING[event.type](account, event)
- except KeyError:
- raise NotImplementedError from None
diff --git a/arclet/entari/event/send.py b/arclet/entari/event/send.py
index c3a1ef7..a411b40 100644
--- a/arclet/entari/event/send.py
+++ b/arclet/entari/event/send.py
@@ -6,14 +6,13 @@
from satori.model import MessageReceipt
from ..message import MessageChain
-from .base import BasedEvent
if TYPE_CHECKING:
from ..session import Session
@dataclass
-class SendRequest(BasedEvent):
+class SendRequest:
account: Account
channel: str
message: MessageChain
@@ -35,7 +34,7 @@ async def gather(self, context: Contexts):
@dataclass
-class SendResponse(BasedEvent):
+class SendResponse:
account: Account
channel: str
message: MessageChain
diff --git a/arclet/entari/filter/common.py b/arclet/entari/filter/common.py
index 5b0e5b0..76f96f5 100644
--- a/arclet/entari/filter/common.py
+++ b/arclet/entari/filter/common.py
@@ -9,6 +9,7 @@
from satori.client import Account
from tarina import is_async
+from ..event.base import SatoriEvent
from ..session import Session
from .message import DirectMessageJudger, NoticeMeJudger, PublicMessageJudger, ReplyMeJudger, ToMeJudger
from .op import ExcludeFilter, IntersectFilter, UnionFilter
@@ -140,6 +141,8 @@ def __init__(self, callback: Optional[_SessionFilter] = None, priority: int = 10
self.callback = None
async def __call__(self, scope: Scope, interface: Interface):
+ if not isinstance(interface.event, SatoriEvent): # we only care about event from satori
+ return True
for step in sorted(self.steps, key=lambda x: x.priority):
if not await step(scope, interface):
return False
diff --git a/arclet/entari/plugin/__init__.py b/arclet/entari/plugin/__init__.py
index 0ebca20..793b8e0 100644
--- a/arclet/entari/plugin/__init__.py
+++ b/arclet/entari/plugin/__init__.py
@@ -2,35 +2,32 @@
from os import PathLike
from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable
+from typing import Any, Callable
from tarina import init_spec
from ..config import EntariConfig
from ..logger import log
-from .model import Plugin
from .model import PluginMetadata as PluginMetadata
from .model import RegisterNotInPluginError
from .model import RootlessPlugin as RootlessPlugin
from .model import StaticPluginDispatchError, _current_plugin
+from .model import TE, Plugin
from .model import keeping as keeping
from .module import import_plugin
from .module import package as package
from .module import requires as requires
from .service import plugin_service
-if TYPE_CHECKING:
- from ..event.base import BasedEvent
-
-def dispatch(*events: type[BasedEvent], predicate: Callable[[BasedEvent], bool] | None = None, name: str | None = None):
+def dispatch(*events: type[TE], predicate: Callable[[TE], bool] | None = None, name: str | None = None):
if not (plugin := _current_plugin.get(None)):
raise LookupError("no plugin context found")
return plugin.dispatch(*events, predicate=predicate, name=name)
def load_plugin(
- path: str, config: dict | None = None, recursive_guard: set[str] | None = None, static: bool = False
+ path: str, config: dict | None = None, recursive_guard: set[str] | None = None, prelude: bool = False
) -> Plugin | None:
"""
以导入路径方式加载模块
@@ -39,7 +36,7 @@ def load_plugin(
path (str): 模块路径
config (dict): 模块配置
recursive_guard (set[str]): 递归保护
- static (bool): 是否为静态插件
+ prelude (bool): 是否为前置插件
"""
if config is None:
_config = EntariConfig.instance.plugin.get(path)
@@ -58,7 +55,9 @@ def load_plugin(
return plugin_service._apply[path](config or {})
try:
conf = config or {}
- if static:
+ if "$static" in conf:
+ del conf["$static"]
+ if prelude:
conf["$static"] = True
mod = import_plugin(path, config=conf)
if not mod:
@@ -119,11 +118,19 @@ def metadata(data: PluginMetadata):
def plugin_config() -> dict[str, Any]:
+ """获取当前插件的配置"""
if not (plugin := _current_plugin.get(None)):
raise LookupError("no plugin context found")
return plugin.config
+def declare_static():
+ """声明当前插件为静态插件"""
+ if not (plugin := _current_plugin.get(None)):
+ raise LookupError("no plugin context found")
+ plugin.is_static = True
+
+
def find_plugin(name: str) -> Plugin | None:
return plugin_service.plugins.get(name)
diff --git a/arclet/entari/plugin/model.py b/arclet/entari/plugin/model.py
index cfa007b..aab6cf3 100644
--- a/arclet/entari/plugin/model.py
+++ b/arclet/entari/plugin/model.py
@@ -20,12 +20,10 @@
from ..logger import log
from .service import plugin_service
-if TYPE_CHECKING:
- from ..event.base import BasedEvent
-
_current_plugin: ContextModel[Plugin] = ContextModel("_current_plugin")
T = TypeVar("T")
+TE = TypeVar("TE")
R = TypeVar("R")
@@ -41,8 +39,8 @@ class PluginDispatcher:
def __init__(
self,
plugin: Plugin,
- *events: type[BasedEvent],
- predicate: Callable[[BasedEvent], bool] | None = None,
+ *events: type[TE],
+ predicate: Callable[[TE], bool] | None = None,
name: str | None = None,
):
if len(events) == 1:
@@ -61,7 +59,7 @@ def __init__(
def waiter(
self,
- *events: type[BasedEvent],
+ *events: Any,
providers: Sequence[Provider | type[Provider]] | None = None,
auxiliaries: list[BaseAuxiliary] | None = None,
priority: int = 15,
@@ -237,9 +235,7 @@ def dispose(self):
del plugin_service.plugins[self.id]
del self.module
- def dispatch(
- self, *events: type[BasedEvent], predicate: Callable[[BasedEvent], bool] | None = None, name: str | None = None
- ):
+ def dispatch(self, *events: type[TE], predicate: Callable[[TE], bool] | None = None, name: str | None = None):
if self.is_static:
raise StaticPluginDispatchError("static plugin cannot dispatch events")
disp = PluginDispatcher(self, *events, predicate=predicate, name=name)
diff --git a/arclet/entari/scheduler.py b/arclet/entari/scheduler.py
index 79c4945..cf5856e 100644
--- a/arclet/entari/scheduler.py
+++ b/arclet/entari/scheduler.py
@@ -1,4 +1,5 @@
import asyncio
+from asyncio.events import _get_running_loop # type: ignore
from datetime import datetime, timedelta
from traceback import print_exc
from typing import Callable, Literal
@@ -9,7 +10,7 @@
from launart import Launart, Service, any_completed
from launart.status import Phase
-from .plugin import RootlessPlugin
+from .plugin import RootlessPlugin, _current_plugin
class _ScheduleEvent:
@@ -28,7 +29,7 @@ def __init__(self, supplier: Callable[[], timedelta], sub_id: str):
self.handle = None
def start(self, queue: asyncio.Queue):
- loop = asyncio.get_event_loop()
+ loop = asyncio.get_running_loop()
self.handle = loop.call_later(self.supplier().total_seconds(), queue.put_nowait, self)
def cancel(self):
@@ -63,11 +64,16 @@ async def fetch(self):
except Exception:
print_exc()
- def schedule(self, timer: Callable[[], timedelta]):
+ def schedule(self, timer: Callable[[], timedelta], once: bool = False):
def wrapper(func: Callable):
- sub = pub.register(func)
+ if plugin := _current_plugin.get():
+ sub = plugin.dispatch(_ScheduleEvent).register(func, temporary=once)
+ else:
+ sub = pub.register(func, temporary=once)
self.timers[sub.id] = TimerTask(timer, sub.id)
+ if _get_running_loop():
+ self.timers[sub.id].start(self.queue)
return sub
return wrapper
@@ -205,3 +211,8 @@ def every(
"hour": every_hours,
}
return service.schedule(_TIMER_MAPPING[mode](value))
+
+
+def invoke(delay: float):
+ """延迟执行"""
+ return service.schedule(lambda: timedelta(seconds=delay), once=True)
diff --git a/arclet/entari/session.py b/arclet/entari/session.py
index a954635..4ed36e0 100644
--- a/arclet/entari/session.py
+++ b/arclet/entari/session.py
@@ -10,24 +10,24 @@
from satori.element import Element
from satori.model import Channel, Guild, Member, MessageReceipt, PageResult, Role, User
-from .event.protocol import Event, FriendRequestEvent, GuildMemberRequestEvent, GuildRequestEvent, MessageEvent, Reply
+from .event.base import FriendRequestEvent, GuildMemberRequestEvent, GuildRequestEvent, MessageEvent, Reply, SatoriEvent
from .event.send import SendRequest, SendResponse
from .message import MessageChain
-TEvent = TypeVar("TEvent", bound=Event)
+TEvent = TypeVar("TEvent", bound=SatoriEvent)
class EntariProtocol(ApiProtocol):
async def send_message(
- self, channel: str | Channel, message: str | Iterable[str | Element], source: Event | None = None
+ self, channel: str | Channel, message: str | Iterable[str | Element], source: SatoriEvent | None = None
) -> list[MessageReceipt]:
"""发送消息。返回一个 `MessageReceipt` 对象构成的数组。
Args:
channel (str | Channel): 要发送的频道 ID
message (str | Iterable[str | Element]): 要发送的消息
- source (Event | None): 源事件
+ source (SatoriEvent | None): 源事件
Returns:
list[MessageReceipt]: `MessageReceipt` 对象构成的数组
@@ -36,14 +36,14 @@ async def send_message(
return await self.message_create(channel_id=channel_id, content=message, source=source)
async def send_private_message(
- self, user: str | User, message: str | Iterable[str | Element], source: Event | None = None
+ self, user: str | User, message: str | Iterable[str | Element], source: SatoriEvent | None = None
) -> list[MessageReceipt]:
"""发送私聊消息。返回一个 `MessageReceipt` 对象构成的数组。
Args:
user (str | User): 要发送的用户 ID
message (str | Iterable[str | Element]): 要发送的消息
- source (Event | None): 源事件
+ source (SatoriEvent | None): 源事件
Returns:
list[MessageReceipt]: `MessageReceipt` 对象构成的数组
@@ -53,14 +53,14 @@ async def send_private_message(
return await self.message_create(channel_id=channel.id, content=message, source=source)
async def message_create(
- self, channel_id: str, content: str | Iterable[str | Element], source: Event | None = None
+ self, channel_id: str, content: str | Iterable[str | Element], source: SatoriEvent | None = None
) -> list[MessageReceipt]:
"""发送消息。返回一个 `MessageReceipt` 对象构成的数组。
Args:
channel_id (str): 频道 ID
content (str | Iterable[str | Element]): 消息内容
- source (Event | None): 源事件
+ source (SatoriEvent | None): 源事件
Returns:
list[MessageReceipt]: `MessageReceipt` 对象构成的数组
@@ -242,14 +242,12 @@ async def message_create(
raise RuntimeError("Event cannot be replied to!")
return await self.account.protocol.send_message(self.context.channel.id, content, self.context)
- async def message_delete(self) -> None:
+ async def message_delete(self, message_id: str) -> None:
if not self.context.channel:
raise RuntimeError("Event cannot be replied to!")
- if not self.context.message:
- raise RuntimeError("Event cannot update message")
await self.account.protocol.message_delete(
self.context.channel.id,
- self.context.message.id,
+ message_id,
)
async def message_update(
diff --git a/example_plugin.py b/example_plugin.py
index edc3654..a247a24 100644
--- a/example_plugin.py
+++ b/example_plugin.py
@@ -10,7 +10,7 @@
metadata,
keeping,
scheduler,
- Entari,
+ # Entari,
)
from arclet.entari.filter import Interval
@@ -29,7 +29,7 @@ async def cleanup():
print("example: Cleanup")
-disp_message = MessageCreatedEvent.dispatch()
+disp_message = plug.dispatch(MessageCreatedEvent)
@disp_message
@@ -38,17 +38,20 @@ async def _(msg: MessageChain, session: Session):
content = msg.extract_plain_text()
if re.match(r"(.{0,3})(上传|设定)(.{0,3})(上传|设定)(.{0,3})", content):
return await session.send("上传设定的帮助是...")
+ if content == "test":
+ resp = await session.send("This message will recall in 5s...")
+ @scheduler.invoke(5)
+ async def _():
+ await session.message_delete(resp[0].id)
-disp_message1 = plug.dispatch(MessageCreatedEvent)
-
-@disp_message1.on(auxiliaries=[Filter().public().to_me().and_(lambda sess: str(sess.content) == "aaa")])
+@disp_message.on(auxiliaries=[Filter().public().to_me().and_(lambda sess: str(sess.content) == "aaa")])
async def _(session: Session):
return await session.send("Filter: public message, to me, and content is 'aaa'")
-@disp_message1.on(auxiliaries=[Filter().public().to_me().not_(lambda sess: str(sess.content) == "aaa")])
+@disp_message.on(auxiliaries=[Filter().public().to_me().not_(lambda sess: str(sess.content) == "aaa")])
async def _(session: Session):
return await session.send("Filter: public message, to me, but content is not 'aaa'")
@@ -84,10 +87,9 @@ async def show(session: Session):
async def send_hook(message: MessageChain):
return message + "喵"
-
-@scheduler.cron("* * * * *")
-async def broadcast(app: Entari):
- for account in app.accounts.values():
- channels = [channel for guild in (await account.guild_list()).data for channel in (await account.channel_list(guild.id)).data]
- for channel in channels:
- await account.send_message(channel, "Hello, World!")
+# @scheduler.cron("* * * * *")
+# async def broadcast(app: Entari):
+# for account in app.accounts.values():
+# channels = [channel for guild in (await account.guild_list()).data for channel in (await account.channel_list(guild.id)).data]
+# for channel in channels:
+# await account.send_message(channel, "Hello, World!")