Skip to content

Commit

Permalink
✨ add hooks on plugin
Browse files Browse the repository at this point in the history
  • Loading branch information
RF-Tar-Railt committed Jul 14, 2024
1 parent 58427a5 commit 64f8549
Show file tree
Hide file tree
Showing 6 changed files with 132 additions and 18 deletions.
22 changes: 20 additions & 2 deletions arclet/entari/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
ProviderFactory,
global_providers,
)
from launart import Launart
from loguru import logger
from satori import LoginStatus
from satori.client import App
from satori.client.account import Account
from satori.client.protocol import ApiProtocol
Expand All @@ -22,7 +24,7 @@

from .command import _commands
from .event import MessageEvent, event_parse
from .plugin.model import _plugins
from .plugin.service import service
from .session import Session


Expand Down Expand Up @@ -52,6 +54,7 @@ def __init__(self, *configs: Config):
self.event_system = EventSystem()
self.event_system.register(_commands.publisher)
self.register(self.handle_event)
self.lifecycle(self.account_hook)
self._ref_tasks = set()

def on(
Expand All @@ -73,13 +76,17 @@ def on_message(
MessageEvent, priority=priority, auxiliaries=auxiliaries, providers=providers
)

def ensure_manager(self, manager: Launart):
self.manager = manager
manager.add_component(service)

async def handle_event(self, account: Account, event: Event):
async def event_parse_task(connection: Account, raw: Event):
loop = asyncio.get_running_loop()
with suppress(NotImplementedError):
ev = event_parse(connection, raw)
self.event_system.publish(ev)
for plugin in _plugins.values():
for plugin in service.plugins.values():
for disp in plugin.dispatchers.values():
if not disp.validate(ev):
continue
Expand All @@ -93,3 +100,14 @@ async def event_parse_task(connection: Account, raw: Event):
logger.warning(f"received unsupported event {raw.type}: {raw}")

await event_parse_task(account, event)

async def account_hook(self, account: Account, state: LoginStatus):
_connected = []
_disconnected = []
for plug in service.plugins.values():
_connected.extend([func(account) for func in plug._connected])
_disconnected.extend([func(account) for func in plug._disconnected])
if state == LoginStatus.CONNECT:
await asyncio.gather(*_connected, return_exceptions=True)
elif state == LoginStatus.DISCONNECT:
await asyncio.gather(*_disconnected, return_exceptions=True)
11 changes: 6 additions & 5 deletions arclet/entari/plugin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@

from .model import Plugin, PluginDispatcher
from .model import PluginMetadata as PluginMetadata
from .model import _current_plugin, _plugins
from .model import _current_plugin
from .module import import_plugin
from .service import service

if TYPE_CHECKING:
from ..event import Event
Expand All @@ -29,8 +30,8 @@ def load_plugin(path: str) -> Plugin | None:
Args:
path (str): 模块路径
"""
if path in _plugins:
return _plugins[path]
if path in service.plugins:
return service.plugins[path]
try:
mod = import_plugin(path)
if not mod:
Expand All @@ -53,7 +54,7 @@ def load_plugins(dir_: str | PathLike | Path):


def dispose(plugin: str):
if plugin not in _plugins:
if plugin not in service.plugins:
return
_plugin = _plugins[plugin]
_plugin = service.plugins[plugin]
_plugin.dispose()
40 changes: 35 additions & 5 deletions arclet/entari/plugin/model.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,25 @@
from __future__ import annotations

from collections.abc import Awaitable
from contextvars import ContextVar
from dataclasses import dataclass, field
from types import ModuleType
from typing import TYPE_CHECKING, Callable
from typing import TYPE_CHECKING, Any, Callable
from weakref import WeakValueDictionary, finalize

from arclet.letoderea import BaseAuxiliary, Provider, Publisher, StepOut, system_ctx
from arclet.letoderea.builtin.breakpoint import R
from arclet.letoderea.typing import TTarget
from satori.client import Account
from tarina import init_spec

from .service import service

if TYPE_CHECKING:
from ..event import Event

_current_plugin: ContextVar[Plugin | None] = ContextVar("_current_plugin", default=None)

_plugins: dict[str, Plugin] = {}


class PluginDispatcher(Publisher):
def __init__(
Expand Down Expand Up @@ -83,6 +85,10 @@ class PluginMetadata:
# component_endpoints: list[str] = field(default_factory=list)


_Lifespan = Callable[..., Awaitable[Any]]
_AccountUpdate = Callable[[Account], Awaitable[Any]]


@dataclass
class Plugin:
id: str
Expand All @@ -91,12 +97,33 @@ class Plugin:
metadata: PluginMetadata | None = None
_is_disposed: bool = False

_preparing: list[_Lifespan] = field(init=False, default_factory=list)
_cleanup: list[_Lifespan] = field(init=False, default_factory=list)
_connected: list[_AccountUpdate] = field(init=False, default_factory=list)
_disconnected: list[_AccountUpdate] = field(init=False, default_factory=list)

def on_prepare(self, func: _Lifespan):
self._preparing.append(func)
return func

def on_cleanup(self, func: _Lifespan):
self._cleanup.append(func)
return func

def on_connect(self, func: _AccountUpdate):
self._connected.append(func)
return func

def on_disconnect(self, func: _AccountUpdate):
self._disconnected.append(func)
return func

@staticmethod
def current() -> Plugin:
return _current_plugin.get() # type: ignore

def __post_init__(self):
_plugins[self.id] = self
service.plugins[self.id] = self
finalize(self, self.dispose)

@init_spec(PluginMetadata, True)
Expand All @@ -111,5 +138,8 @@ def dispose(self):
for disp in self.dispatchers.values():
disp.dispose()
self.dispatchers.clear()
del _plugins[self.id]
del service.plugins[self.id]
del self.module

def dispatch(self, *events: type[Event], predicate: Callable[[Event], bool] | None = None):
return PluginDispatcher(self, *events, predicate=predicate)
9 changes: 5 additions & 4 deletions arclet/entari/plugin/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from types import ModuleType
from typing import Optional

from .model import Plugin, PluginMetadata, _current_plugin, _plugins
from .model import Plugin, PluginMetadata, _current_plugin
from .service import service


class PluginLoader(SourceFileLoader):
Expand All @@ -15,9 +16,9 @@ def __init__(self, fullname: str, path: str) -> None:
super().__init__(fullname, path)

def create_module(self, spec) -> Optional[ModuleType]:
if self.name in _plugins:
if self.name in service.plugins:
self.loaded = True
return _plugins[self.name].module
return service.plugins[self.name].module
return super().create_module(spec)

def exec_module(self, module: ModuleType) -> None:
Expand Down Expand Up @@ -94,7 +95,7 @@ def find_spec(
module_origin = module_spec.origin
if not module_origin:
return
if module_spec.name in _plugins:
if module_spec.name in service.plugins:
module_spec.loader = PluginLoader(fullname, module_origin)
return module_spec
return
Expand Down
50 changes: 50 additions & 0 deletions arclet/entari/plugin/service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import asyncio
from typing import TYPE_CHECKING

from launart import Launart, Service
from launart.status import Phase
from loguru import logger

if TYPE_CHECKING:
from .model import Plugin


class PluginService(Service):
id = "arclet.entari.plugin_service"

plugins: dict[str, "Plugin"]

def __init__(self):
super().__init__()
self.plugins = {}

@property
def required(self) -> set[str]:
return set()

@property
def stages(self) -> set[Phase]:
return {"preparing", "cleanup"}

async def launch(self, manager: Launart):
_preparing = []
_cleanup = []
for plug in self.plugins.values():
_preparing.extend([func() for func in plug._preparing])
_cleanup.extend([func() for func in plug._cleanup])
async with self.stage("preparing"):
await asyncio.gather(*_preparing, return_exceptions=True)
async with self.stage("cleanup"):
await asyncio.gather(*_cleanup, return_exceptions=True)
ids = [*self.plugins.keys()]
for plug_id in ids:
plug = self.plugins[plug_id]
logger.debug(f"disposing plugin {plug.id}")
try:
plug.dispose()
except Exception as e:
logger.error(f"failed to dispose plugin {plug.id} caused by {e!r}")
self.plugins.pop(plug_id, None)


service = PluginService()
18 changes: 16 additions & 2 deletions example_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,18 @@
)
from arclet.entari.command import Match

Plugin.current().meta(__file__)
plug = Plugin.current().meta(__file__)


@plug.on_prepare
async def prepare():
print("Preparing")


@plug.on_cleanup
async def cleanup():
print("Cleanup")


disp_message = MessageCreatedEvent.dispatch()

Expand All @@ -26,16 +37,19 @@ async def _(msg: MessageChain):
return "上传设定的帮助是..."


disp_message1 = plug.dispatch(MessageCreatedEvent)


from satori import select, Author


@disp_message.on(auxiliaries=[is_public_message])
@disp_message1.on(auxiliaries=[is_public_message])
async def _(event: MessageCreatedEvent):
print(event.content)
if event.quote and (authors := select(event.quote, Author)):
author = authors[0]
reply_self = author.id == event.account.self_id
print(reply_self)


on_alc = command.mount(Alconna("echo", Args["content?", AllParam]))
Expand Down

0 comments on commit 64f8549

Please sign in to comment.