Skip to content

Commit

Permalink
✨ impl before_send hook
Browse files Browse the repository at this point in the history
  • Loading branch information
RF-Tar-Railt committed Nov 26, 2024
1 parent e73809a commit c63efb3
Show file tree
Hide file tree
Showing 14 changed files with 211 additions and 124 deletions.
1 change: 0 additions & 1 deletion arclet/entari/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from arclet.letoderea import bind as bind
from arclet.letoderea import es as es
from satori import ArgvInteraction as ArgvInteraction
from satori import At as At
from satori import Audio as Audio
Expand Down
9 changes: 5 additions & 4 deletions arclet/entari/command/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@
from tarina.string import split
from tarina.trie import CharTrie

from ..event import MessageCreatedEvent
from ..event.command import CommandExecute
from ..event.protocol import MessageCreatedEvent
from ..message import MessageChain
from .argv import MessageArgv # noqa: F401
from .model import CommandResult, Match, Query
from .plugin import CommandExecute, mount
from .plugin import mount
from .provider import AlconnaProviderFactory, AlconnaSuppiler, MessageJudger, get_cmd

T = TypeVar("T")
Expand All @@ -28,7 +29,7 @@ class EntariCommands:

def __init__(self, need_tome: bool = False, remove_tome: bool = True):
self.trie: CharTrie[Subscriber] = CharTrie()
self.publisher = Publisher("EntariCommands", MessageCreatedEvent)
self.publisher = Publisher("entari.command", MessageCreatedEvent)
self.publisher.providers.append(AlconnaProviderFactory())
self.need_tome = need_tome
self.remove_tome = remove_tome
Expand Down Expand Up @@ -199,7 +200,7 @@ def config_commands(need_tome: bool = False, remove_tome: bool = True):


async def execute(message: Union[str, MessageChain]):
return await es.post(CommandExecute(message), "entari.command/command_execute")
return await es.post(CommandExecute(message), "entari.event/command_execute")


__all__ = ["_commands", "config_commands", "Match", "Query", "execute", "CommandResult", "mount", "command", "on"]
72 changes: 17 additions & 55 deletions arclet/entari/command/plugin.py
Original file line number Diff line number Diff line change
@@ -1,63 +1,19 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Any

from arclet.alconna import Alconna, Arparma, command_manager, output_manager
from arclet.letoderea import AuxType, BaseAuxiliary, Contexts, Interface, Provider, ProviderFactory, Scope, es
from arclet.alconna import Alconna, command_manager
from arclet.letoderea import BaseAuxiliary, Provider, ProviderFactory

from ..event import MessageCreatedEvent
from ..message import MessageChain
from ..event.command import pub as execute_handles
from ..plugin.model import Plugin, PluginDispatcher
from .model import CommandResult, Match, Query
from .provider import AlconnaProviderFactory, AlconnaSuppiler, Assign, MessageJudger, _seminal
from .model import Match, Query
from .provider import AlconnaProviderFactory, AlconnaSuppiler, Assign, ExecuteSuppiler, MessageJudger, _seminal


@dataclass
class CommandExecute:
command: str | MessageChain

async def gather(self, context: Contexts):
if isinstance(self.command, str):
context["command"] = MessageChain(self.command)
else:
context["command"] = self.command

class CommandProvider(Provider[MessageChain]):
async def __call__(self, context: Contexts):
return context.get("command")


execute_handles = es.define("entari.command/command_execute", CommandExecute)
execute_handles.bind(AlconnaProviderFactory())


class ExecuteJudger(BaseAuxiliary):
def __init__(self, cmd: Alconna):
self.cmd = cmd
super().__init__(AuxType.supply, priority=1)

async def __call__(self, scope: Scope, interface: Interface):
message = interface.query(MessageChain, "command")
with output_manager.capture(self.cmd.name) as cap:
output_manager.set_action(lambda x: x, self.cmd.name)
try:
_res = self.cmd.parse(message)
except Exception as e:
_res = Arparma(self.cmd._hash, message, False, error_info=e)
may_help_text: str | None = cap.get("output", None)
result = CommandResult(self.cmd, _res, may_help_text)
return interface.update(alc_result=result)

@property
def scopes(self) -> set[Scope]:
return {Scope.prepare}

@property
def id(self) -> str:
return "entari.command/command_execute_judger"


class AlconnaPluginDispatcher(PluginDispatcher):

def __init__(
Expand All @@ -70,8 +26,8 @@ def __init__(
self.supplier = AlconnaSuppiler(command, need_tome, remove_tome)
super().__init__(plugin, MessageCreatedEvent)

self.bind(MessageJudger(), self.supplier)
self.bind(AlconnaProviderFactory())
self.publisher.bind(MessageJudger(), self.supplier)
self.publisher.bind(AlconnaProviderFactory())

def assign(
self,
Expand Down Expand Up @@ -99,8 +55,14 @@ def on_execute(
providers: list[Provider | type[Provider] | ProviderFactory | type[ProviderFactory]] | None = None,
):
_auxiliaries = auxiliaries or []
_auxiliaries.append(ExecuteJudger(self.supplier.cmd))
return execute_handles.register(priority=priority, auxiliaries=_auxiliaries, providers=providers)
_auxiliaries.append(ExecuteSuppiler(self.supplier.cmd))

def wrapper(func):
sub = execute_handles.register(func, priority=priority, auxiliaries=_auxiliaries, providers=providers)
self._subscribers.append(sub)
return sub

return wrapper

Match = Match
Query = Query
Expand All @@ -110,7 +72,7 @@ def mount(cmd: Alconna, need_tome: bool = False, remove_tome: bool = True) -> Al
if not (plugin := Plugin.current()):
raise LookupError("no plugin context found")
disp = AlconnaPluginDispatcher(plugin, cmd, need_tome, remove_tome)
if disp.id in plugin.dispatchers:
if disp.publisher.id in plugin.dispatchers:
return plugin.dispatchers[disp.id] # type: ignore
plugin.dispatchers[disp.id] = disp
plugin.dispatchers[disp.publisher.id] = disp
return disp
28 changes: 27 additions & 1 deletion arclet/entari/command/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,33 @@ def scopes(self) -> set[Scope]:

@property
def id(self) -> str:
return "entari.command/alconna_supplier"
return "entari.command/common_supplier"


class ExecuteSuppiler(SupplyAuxiliary):
def __init__(self, cmd: Alconna):
self.cmd = cmd
super().__init__(priority=1)

async def __call__(self, scope: Scope, interface: Interface):
message = interface.query(MessageChain, "command")
with output_manager.capture(self.cmd.name) as cap:
output_manager.set_action(lambda x: x, self.cmd.name)
try:
_res = self.cmd.parse(message)
except Exception as e:
_res = Arparma(self.cmd._hash, message, False, error_info=e)
may_help_text: Optional[str] = cap.get("output", None)
result = CommandResult(self.cmd, _res, may_help_text)
return interface.update(alc_result=result)

@property
def scopes(self) -> set[Scope]:
return {Scope.prepare}

@property
def id(self) -> str:
return "entari.command/execute_supplier"


class AlconnaProvider(Provider[Any]):
Expand Down
4 changes: 3 additions & 1 deletion arclet/entari/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from .command import _commands
from .config import Config as EntariConfig
from .event import MessageCreatedEvent, event_parse
from .event.protocol import MessageCreatedEvent, event_parse
from .plugin import load_plugin
from .plugin.service import plugin_service
from .session import Session
Expand All @@ -34,6 +34,8 @@ def validate(self, param: Param):
return get_origin(param.annotation) == Session

async def __call__(self, context: Contexts):
if "session" in context and isinstance(context["session"], Session):
return context["session"]
if "$origin_event" in context and "$account" in context:
return Session(context["$account"], context["$event"])

Expand Down
2 changes: 2 additions & 0 deletions arclet/entari/event/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .protocol import MessageCreatedEvent as MessageCreatedEvent
from .protocol import MessageEvent as MessageEvent
14 changes: 14 additions & 0 deletions arclet/entari/event/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from __future__ import annotations

from typing import Callable, TypeVar

from ..plugin import dispatch

TE = TypeVar("TE", bound="BasedEvent")


class BasedEvent:
@classmethod
def dispatch(cls: type[TE], predicate: Callable[[TE], bool] | None = None, name: str | None = None):
name = name or getattr(cls, "__disp_name__", None)
return dispatch(cls, predicate=predicate, name=name) # type: ignore
27 changes: 27 additions & 0 deletions arclet/entari/event/command.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from dataclasses import dataclass
from typing import Union

from arclet.letoderea import Contexts, Provider, es

from ..message import MessageChain
from .base import BasedEvent


@dataclass
class CommandExecute(BasedEvent):
command: Union[str, MessageChain]

async def gather(self, context: Contexts):
if isinstance(self.command, str):
context["command"] = MessageChain(self.command)
else:
context["command"] = self.command

class CommandProvider(Provider[MessageChain]):
async def __call__(self, context: Contexts):
return context.get("command")

__disp_name__ = "entari.event/command_execute"


pub = es.define("entari.event/command_execute", CommandExecute, lambda x: {"command": x.command, "message": x.command})
29 changes: 7 additions & 22 deletions arclet/entari/event.py → arclet/entari/event/protocol.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from datetime import datetime
from typing import Any, Callable, ClassVar, Generic, TypeVar
from typing import Any, ClassVar, Generic, TypeVar

from arclet.letoderea import Contexts, Param, Provider
from satori import ArgvInteraction, ButtonInteraction, Channel
Expand All @@ -11,10 +11,9 @@
from satori.model import LoginType, MessageObject
from tarina import gen_subclass

from .message import MessageChain
from .plugin import dispatch
from ..message import MessageChain
from .base import BasedEvent

TE = TypeVar("TE", bound="Event")
T = TypeVar("T")
D = TypeVar("D")

Expand All @@ -25,6 +24,8 @@ def __init__(self, key: str | None = None):

def __set_name__(self, owner: type[Event], name: str):
self.key = self.key or name
if name not in ("id", "timestamp"):
owner._attrs.add(name)

def __get__(self, instance: Event, owner: type[Event]) -> T:
return getattr(instance._origin, self.key, None) # type: ignore
Expand All @@ -37,8 +38,9 @@ def attr(key: str | None = None) -> Any:
return Attr(key)


class Event:
class Event(BasedEvent):
type: ClassVar[EventType]
_attrs: ClassVar[set[str]] = set()
_origin: SatoriEvent
account: Account

Expand All @@ -55,27 +57,10 @@ class Event:
role: Role | None = attr()
user: User | None = attr()

_attrs: ClassVar[set[str]] = {
"argv",
"button",
"channel",
"guild",
"login",
"member",
"message",
"operator",
"role",
"user",
}

def __init__(self, account: Account, origin: SatoriEvent):
self.account = account
self._origin = origin

@classmethod
def dispatch(cls: type[TE], predicate: Callable[[TE], bool] | None = None):
return dispatch(cls, predicate=predicate) # type: ignore

async def gather(self, context: Contexts):
context["$account"] = self.account
context["$origin_event"] = self._origin
Expand Down
23 changes: 23 additions & 0 deletions arclet/entari/event/session.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING

from arclet.letoderea import deref, es, provide
from arclet.letoderea.ref import generate

from ..message import MessageChain
from .base import BasedEvent

if TYPE_CHECKING:
from ..session import Session


@dataclass
class SendRequest(BasedEvent):
session: "Session"
message: "MessageChain"

__disp_name__ = "entari.event/before_send"


pub = es.define("entari.event/before_send", SendRequest, lambda x: {"session": x.session, "message": x.message})
pub.bind(provide(MessageChain, call=generate(deref(SendRequest).message)))
6 changes: 3 additions & 3 deletions arclet/entari/plugin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@
from .service import plugin_service

if TYPE_CHECKING:
from ..event import Event
from ..event.base import BasedEvent


def dispatch(*events: type[Event], predicate: Callable[[Event], bool] | None = None):
def dispatch(*events: type[BasedEvent], predicate: Callable[[BasedEvent], bool] | None = None, name: str | None = None):
if not (plugin := _current_plugin.get(None)):
raise LookupError("no plugin context found")
return plugin.dispatch(*events, predicate=predicate)
return plugin.dispatch(*events, predicate=predicate, name=name)


def load_plugin(path: str, config: dict | None = None, recursive_guard: set[str] | None = None) -> Plugin | None:
Expand Down
Loading

0 comments on commit c63efb3

Please sign in to comment.