diff --git a/arclet/entari/core.py b/arclet/entari/core.py index 2c28bfe..21728d2 100644 --- a/arclet/entari/core.py +++ b/arclet/entari/core.py @@ -12,7 +12,9 @@ ProviderFactory, global_providers, ) +from launart import Launart from loguru import logger +from satori import LoginStatus from satori.client import App from satori.client.account import Account from satori.client.protocol import ApiProtocol @@ -22,7 +24,7 @@ from .command import _commands from .event import MessageEvent, event_parse -from .plugin.model import _plugins +from .plugin.service import service from .session import Session @@ -52,6 +54,7 @@ def __init__(self, *configs: Config): self.event_system = EventSystem() self.event_system.register(_commands.publisher) self.register(self.handle_event) + self.lifecycle(self.account_hook) self._ref_tasks = set() def on( @@ -73,13 +76,17 @@ def on_message( MessageEvent, priority=priority, auxiliaries=auxiliaries, providers=providers ) + def ensure_manager(self, manager: Launart): + self.manager = manager + manager.add_component(service) + async def handle_event(self, account: Account, event: Event): async def event_parse_task(connection: Account, raw: Event): loop = asyncio.get_running_loop() with suppress(NotImplementedError): ev = event_parse(connection, raw) self.event_system.publish(ev) - for plugin in _plugins.values(): + for plugin in service.plugins.values(): for disp in plugin.dispatchers.values(): if not disp.validate(ev): continue @@ -93,3 +100,14 @@ async def event_parse_task(connection: Account, raw: Event): logger.warning(f"received unsupported event {raw.type}: {raw}") await event_parse_task(account, event) + + async def account_hook(self, account: Account, state: LoginStatus): + _connected = [] + _disconnected = [] + for plug in service.plugins.values(): + _connected.extend([func(account) for func in plug._connected]) + _disconnected.extend([func(account) for func in plug._disconnected]) + if state == LoginStatus.CONNECT: + await asyncio.gather(*_connected, return_exceptions=True) + elif state == LoginStatus.DISCONNECT: + await asyncio.gather(*_disconnected, return_exceptions=True) diff --git a/arclet/entari/plugin/__init__.py b/arclet/entari/plugin/__init__.py index 62007df..0b534b6 100644 --- a/arclet/entari/plugin/__init__.py +++ b/arclet/entari/plugin/__init__.py @@ -8,8 +8,9 @@ from .model import Plugin, PluginDispatcher from .model import PluginMetadata as PluginMetadata -from .model import _current_plugin, _plugins +from .model import _current_plugin from .module import import_plugin +from .service import service if TYPE_CHECKING: from ..event import Event @@ -29,8 +30,8 @@ def load_plugin(path: str) -> Plugin | None: Args: path (str): 模块路径 """ - if path in _plugins: - return _plugins[path] + if path in service.plugins: + return service.plugins[path] try: mod = import_plugin(path) if not mod: @@ -53,7 +54,7 @@ def load_plugins(dir_: str | PathLike | Path): def dispose(plugin: str): - if plugin not in _plugins: + if plugin not in service.plugins: return - _plugin = _plugins[plugin] + _plugin = service.plugins[plugin] _plugin.dispose() diff --git a/arclet/entari/plugin/model.py b/arclet/entari/plugin/model.py index 8790274..0f29d2b 100644 --- a/arclet/entari/plugin/model.py +++ b/arclet/entari/plugin/model.py @@ -1,23 +1,25 @@ from __future__ import annotations +from collections.abc import Awaitable from contextvars import ContextVar from dataclasses import dataclass, field from types import ModuleType -from typing import TYPE_CHECKING, Callable +from typing import TYPE_CHECKING, Any, Callable from weakref import WeakValueDictionary, finalize from arclet.letoderea import BaseAuxiliary, Provider, Publisher, StepOut, system_ctx from arclet.letoderea.builtin.breakpoint import R from arclet.letoderea.typing import TTarget +from satori.client import Account from tarina import init_spec +from .service import service + if TYPE_CHECKING: from ..event import Event _current_plugin: ContextVar[Plugin | None] = ContextVar("_current_plugin", default=None) -_plugins: dict[str, Plugin] = {} - class PluginDispatcher(Publisher): def __init__( @@ -83,6 +85,10 @@ class PluginMetadata: # component_endpoints: list[str] = field(default_factory=list) +_Lifespan = Callable[..., Awaitable[Any]] +_AccountUpdate = Callable[[Account], Awaitable[Any]] + + @dataclass class Plugin: id: str @@ -91,12 +97,33 @@ class Plugin: metadata: PluginMetadata | None = None _is_disposed: bool = False + _preparing: list[_Lifespan] = field(init=False, default_factory=list) + _cleanup: list[_Lifespan] = field(init=False, default_factory=list) + _connected: list[_AccountUpdate] = field(init=False, default_factory=list) + _disconnected: list[_AccountUpdate] = field(init=False, default_factory=list) + + def on_prepare(self, func: _Lifespan): + self._preparing.append(func) + return func + + def on_cleanup(self, func: _Lifespan): + self._cleanup.append(func) + return func + + def on_connect(self, func: _AccountUpdate): + self._connected.append(func) + return func + + def on_disconnect(self, func: _AccountUpdate): + self._disconnected.append(func) + return func + @staticmethod def current() -> Plugin: return _current_plugin.get() # type: ignore def __post_init__(self): - _plugins[self.id] = self + service.plugins[self.id] = self finalize(self, self.dispose) @init_spec(PluginMetadata, True) @@ -111,5 +138,8 @@ def dispose(self): for disp in self.dispatchers.values(): disp.dispose() self.dispatchers.clear() - del _plugins[self.id] + del service.plugins[self.id] del self.module + + def dispatch(self, *events: type[Event], predicate: Callable[[Event], bool] | None = None): + return PluginDispatcher(self, *events, predicate=predicate) diff --git a/arclet/entari/plugin/module.py b/arclet/entari/plugin/module.py index 166381b..693f720 100644 --- a/arclet/entari/plugin/module.py +++ b/arclet/entari/plugin/module.py @@ -6,7 +6,8 @@ from types import ModuleType from typing import Optional -from .model import Plugin, PluginMetadata, _current_plugin, _plugins +from .model import Plugin, PluginMetadata, _current_plugin +from .service import service class PluginLoader(SourceFileLoader): @@ -15,9 +16,9 @@ def __init__(self, fullname: str, path: str) -> None: super().__init__(fullname, path) def create_module(self, spec) -> Optional[ModuleType]: - if self.name in _plugins: + if self.name in service.plugins: self.loaded = True - return _plugins[self.name].module + return service.plugins[self.name].module return super().create_module(spec) def exec_module(self, module: ModuleType) -> None: @@ -94,7 +95,7 @@ def find_spec( module_origin = module_spec.origin if not module_origin: return - if module_spec.name in _plugins: + if module_spec.name in service.plugins: module_spec.loader = PluginLoader(fullname, module_origin) return module_spec return diff --git a/arclet/entari/plugin/service.py b/arclet/entari/plugin/service.py new file mode 100644 index 0000000..175d6b2 --- /dev/null +++ b/arclet/entari/plugin/service.py @@ -0,0 +1,50 @@ +import asyncio +from typing import TYPE_CHECKING + +from launart import Launart, Service +from launart.status import Phase +from loguru import logger + +if TYPE_CHECKING: + from .model import Plugin + + +class PluginService(Service): + id = "arclet.entari.plugin_service" + + plugins: dict[str, "Plugin"] + + def __init__(self): + super().__init__() + self.plugins = {} + + @property + def required(self) -> set[str]: + return set() + + @property + def stages(self) -> set[Phase]: + return {"preparing", "cleanup"} + + async def launch(self, manager: Launart): + _preparing = [] + _cleanup = [] + for plug in self.plugins.values(): + _preparing.extend([func() for func in plug._preparing]) + _cleanup.extend([func() for func in plug._cleanup]) + async with self.stage("preparing"): + await asyncio.gather(*_preparing, return_exceptions=True) + async with self.stage("cleanup"): + await asyncio.gather(*_cleanup, return_exceptions=True) + ids = [*self.plugins.keys()] + for plug_id in ids: + plug = self.plugins[plug_id] + logger.debug(f"disposing plugin {plug.id}") + try: + plug.dispose() + except Exception as e: + logger.error(f"failed to dispose plugin {plug.id} caused by {e!r}") + self.plugins.pop(plug_id, None) + + +service = PluginService() diff --git a/example_plugin.py b/example_plugin.py index 1dda573..b90681a 100644 --- a/example_plugin.py +++ b/example_plugin.py @@ -13,7 +13,18 @@ ) from arclet.entari.command import Match -Plugin.current().meta(__file__) +plug = Plugin.current().meta(__file__) + + +@plug.on_prepare +async def prepare(): + print("Preparing") + + +@plug.on_cleanup +async def cleanup(): + print("Cleanup") + disp_message = MessageCreatedEvent.dispatch() @@ -26,16 +37,19 @@ async def _(msg: MessageChain): return "上传设定的帮助是..." +disp_message1 = plug.dispatch(MessageCreatedEvent) + from satori import select, Author -@disp_message.on(auxiliaries=[is_public_message]) +@disp_message1.on(auxiliaries=[is_public_message]) async def _(event: MessageCreatedEvent): print(event.content) if event.quote and (authors := select(event.quote, Author)): author = authors[0] reply_self = author.id == event.account.self_id + print(reply_self) on_alc = command.mount(Alconna("echo", Args["content?", AllParam]))