From 304486ea93fb29e96f96619db5931b90a82b382a Mon Sep 17 00:00:00 2001 From: RF-Tar-Railt Date: Fri, 13 Dec 2024 17:49:45 +0800 Subject: [PATCH] :sparkles: filter from config --- arclet/entari/_subscriber.py | 28 ------------ arclet/entari/builtins/auto_reload.py | 64 +++++++++++++++++++-------- arclet/entari/command/__init__.py | 54 ++++++++++++---------- arclet/entari/command/plugin.py | 4 +- arclet/entari/core.py | 20 ++++++--- arclet/entari/event/config.py | 3 +- arclet/entari/filter/common.py | 46 +++++++++++++++---- arclet/entari/plugin/model.py | 28 +++++++----- arclet/entari/plugin/module.py | 14 +++++- arclet/entari/plugin/service.py | 25 ++++++++++- example_plugin.py | 6 +-- 11 files changed, 186 insertions(+), 106 deletions(-) delete mode 100644 arclet/entari/_subscriber.py diff --git a/arclet/entari/_subscriber.py b/arclet/entari/_subscriber.py deleted file mode 100644 index 774e801..0000000 --- a/arclet/entari/_subscriber.py +++ /dev/null @@ -1,28 +0,0 @@ -from typing import Callable, TypeVar - -from arclet.letoderea import Subscriber -from arclet.letoderea.typing import TTarget - -T = TypeVar("T") - - -class SubscribeLoader: - sub: Subscriber - - def __init__(self, func: TTarget[T], caller: Callable[[TTarget[T]], Subscriber[T]]): - self.func = func - self.caller = caller - self.loaded = False - - def load(self): - if not self.loaded: - self.sub = self.caller(self.func) - self.loaded = True - return self.sub - - def dispose(self): - if self.loaded: - self.sub.dispose() - self.loaded = False - del self.func - del self.caller diff --git a/arclet/entari/builtins/auto_reload.py b/arclet/entari/builtins/auto_reload.py index 04ee104..f657ec6 100644 --- a/arclet/entari/builtins/auto_reload.py +++ b/arclet/entari/builtins/auto_reload.py @@ -30,6 +30,21 @@ class Config: logger = log.wrapper("[AutoReload]") +def detect_filter_change(old: dict, new: dict): + added = set(new) - set(old) + removed = set(old) - set(new) + changed = {key for key in set(new) & set(old) if new[key] != old[key]} + if "$allow" in removed: + allow = {} + else: + allow = new["$allow"] + if "$deny" in removed: + deny = {} + else: + deny = new["$deny"] + return allow, deny, not ((added | removed | changed) - {"$allow", "$deny"}) + + class Watcher(Service): id = "watcher" @@ -61,7 +76,6 @@ async def watch(self): dispose_plugin(pid) if plugin := load_plugin(pid): logger("INFO", f"Reloaded {plugin.id!r}") - plugin._load() await plugin._startup() await plugin._ready() del plugin @@ -72,7 +86,6 @@ async def watch(self): logger("INFO", f"Detected change in {change[1]!r} which failed to reload, retrying...") if plugin := load_plugin(self.fail[change[1]]): logger("INFO", f"Reloaded {plugin.id!r}") - plugin._load() await plugin._startup() await plugin._ready() del plugin @@ -102,17 +115,21 @@ async def watch_config(self): f"Basic config {key!r} changed from {old_basic[key]!r} " f"to {EntariConfig.instance.basic[key]!r}", ) - await es.publish(ConfigReload("basic", key, EntariConfig.instance.basic[key])) + await es.publish(ConfigReload("basic", key, EntariConfig.instance.basic[key], old_basic[key])) + for key in set(EntariConfig.instance.basic) - set(old_basic): + logger("DEBUG", f"Basic config {key!r} appended") + await es.publish(ConfigReload("basic", key, EntariConfig.instance.basic[key])) for plugin_name in old_plugin: pid = plugin_name.replace("::", "arclet.entari.builtins.") if ( plugin_name not in EntariConfig.instance.plugin or EntariConfig.instance.plugin[plugin_name] is False - ) and (plugin := find_plugin(pid)): - await plugin._cleanup() - del plugin - dispose_plugin(pid) - logger("INFO", f"Disposed plugin {pid!r}") + ): + if plugin := find_plugin(pid): + await plugin._cleanup() + del plugin + dispose_plugin(pid) + logger("INFO", f"Disposed plugin {pid!r}") continue if old_plugin[plugin_name] != EntariConfig.instance.plugin[plugin_name]: logger( @@ -120,19 +137,31 @@ async def watch_config(self): f"Plugin {plugin_name!r} config changed from {old_plugin[plugin_name]!r} " f"to {EntariConfig.instance.plugin[plugin_name]!r}", ) - res = await es.post( - ConfigReload("plugin", plugin_name, EntariConfig.instance.plugin[plugin_name]) - ) - if res and res.value: - logger("DEBUG", f"Plugin {pid!r} config change handled by itself.") - continue + if isinstance(old_plugin[plugin_name], bool): + old_conf = {} + else: + old_conf: dict = old_plugin[plugin_name] # type: ignore + if isinstance(EntariConfig.instance.plugin[plugin_name], bool): + new_conf = {} + else: + new_conf: dict = EntariConfig.instance.plugin[plugin_name] # type: ignore if plugin := find_plugin(pid): + allow, deny, only_filter = detect_filter_change(old_conf, new_conf) + plugin.update_filter(allow, deny) + if only_filter: + logger("DEBUG", f"Plugin {pid!r} config only changed filter.") + continue + res = await es.post( + ConfigReload("plugin", plugin_name, new_conf, old_conf), + ) + if res and res.value: + logger("DEBUG", f"Plugin {pid!r} config change handled by itself.") + continue logger("INFO", f"Detected {pid!r}'s config change, reloading...") plugin_file = str(plugin.module.__file__) await plugin._cleanup() dispose_plugin(plugin_name) - if plugin := load_plugin(plugin_name): - plugin._load() + if plugin := load_plugin(plugin_name, new_conf): await plugin._startup() await plugin._ready() logger("INFO", f"Reloaded {plugin.id!r}") @@ -142,12 +171,11 @@ async def watch_config(self): self.fail[plugin_file] = pid else: logger("INFO", f"Detected {pid!r} appended, loading...") - load_plugin(plugin_name) + load_plugin(plugin_name, new_conf) if new := (set(EntariConfig.instance.plugin) - set(old_plugin)): for plugin_name in new: if not (plugin := load_plugin(plugin_name)): continue - plugin._load() await plugin._startup() await plugin._ready() del plugin diff --git a/arclet/entari/command/__init__.py b/arclet/entari/command/__init__.py index f6a37d4..61e4852 100644 --- a/arclet/entari/command/__init__.py +++ b/arclet/entari/command/__init__.py @@ -55,7 +55,7 @@ async def handle(self, session: Session, message: MessageChain, ctx: Contexts): if not msg: return if matches := list(self.trie.prefixes(msg)): - results = await asyncio.gather(*(res.value.handle(ctx.copy(), inner=True) for res in matches if res.value)) + results = await asyncio.gather(*(res.value.handle(ctx.copy()) for res in matches if res.value)) for result in results: if result is not None: await session.send(result) @@ -67,7 +67,7 @@ async def handle(self, session: Session, message: MessageChain, ctx: Contexts): command_manager.find_shortcut(get_cmd(value), data) except ValueError: continue - result = await value.handle(ctx.copy(), inner=True) + result = await value.handle(ctx.copy()) if result is not None: await session.send(result) @@ -143,7 +143,7 @@ def wrapper(func: TTarget[Optional[TM]]) -> Subscriber[Optional[TM]]: f" {arg.value.target}" for arg in _command.args if isinstance(arg.value, DirectPattern) ) auxiliaries.insert(0, AlconnaSuppiler(_command)) - target = self.publisher.register(auxiliaries=auxiliaries, providers=providers)(func) + target = self.publisher.register(func, auxiliaries=auxiliaries, providers=providers) self.publisher.remove_subscriber(target) self.trie[key] = target @@ -152,30 +152,34 @@ def _remove(_): self.trie.pop(key, None) # type: ignore target._dispose = _remove + return target + + _command = cast(Alconna, command) + if not isinstance(command.command, str): + raise TypeError("Command name must be a string.") + _command.reset_namespace(self.__namespace__) + auxiliaries.insert(0, AlconnaSuppiler(_command)) + keys = [] + if not _command.prefixes: + keys.append(_command.command) + elif not all(isinstance(i, str) for i in _command.prefixes): + raise TypeError("Command prefixes must be a list of string.") else: - auxiliaries.insert(0, AlconnaSuppiler(command)) - target = self.publisher.register(auxiliaries=auxiliaries, providers=providers)(func) - self.publisher.remove_subscriber(target) - if not isinstance(command.command, str): - raise TypeError("Command name must be a string.") - keys = [] - if not command.prefixes: - self.trie[command.command] = target - keys.append(command.command) - elif not all(isinstance(i, str) for i in command.prefixes): - raise TypeError("Command prefixes must be a list of string.") - else: - for prefix in cast(list[str], command.prefixes): - self.trie[prefix + command.command] = target - keys.append(prefix + command.command) + for prefix in cast(list[str], _command.prefixes): + keys.append(prefix + _command.command) - def _remove(_): - command_manager.delete(get_cmd(_)) - for key in keys: - self.trie.pop(key, None) # type: ignore + target = self.publisher.register(func, auxiliaries=auxiliaries, providers=providers) + self.publisher.remove_subscriber(target) - target._dispose = _remove - command.reset_namespace(self.__namespace__) + for _key in keys: + self.trie[_key] = target + + def _remove(_): + command_manager.delete(get_cmd(_)) + for _key in keys: + self.trie.pop(_key, None) # type: ignore + + target._dispose = _remove return target return wrapper @@ -209,6 +213,8 @@ def _(plg: RootlessPlugin): if "use_config_prefix" in plg.config: _commands.judge.use_config_prefix = plg.config["use_config_prefix"] + plg.dispatch(MessageCreatedEvent).handle(_commands.handle, auxiliaries=[_commands.judge]) + @plg.use(ConfigReload) def update(event: ConfigReload): if event.scope != "plugin": diff --git a/arclet/entari/command/plugin.py b/arclet/entari/command/plugin.py index c3bc1d6..bb2cc92 100644 --- a/arclet/entari/command/plugin.py +++ b/arclet/entari/command/plugin.py @@ -5,7 +5,6 @@ from arclet.alconna import Alconna, command_manager from arclet.letoderea import BaseAuxiliary, Provider, ProviderFactory -from .._subscriber import SubscribeLoader from ..event import MessageCreatedEvent from ..event.command import pub as execute_handles from ..plugin.model import Plugin, PluginDispatcher @@ -60,8 +59,7 @@ def on_execute( _auxiliaries.append(self.supplier) def wrapper(func): - caller = execute_handles.register(priority=priority, auxiliaries=_auxiliaries, providers=providers) - sub = SubscribeLoader(func, caller) + sub = execute_handles.register(func, priority=priority, auxiliaries=_auxiliaries, providers=providers) self._subscribers.append(sub) return sub diff --git a/arclet/entari/core.py b/arclet/entari/core.py index 0687f27..fcfa92f 100644 --- a/arclet/entari/core.py +++ b/arclet/entari/core.py @@ -3,7 +3,7 @@ from contextlib import suppress import os -from arclet.letoderea import BaseAuxiliary, Contexts, Param, Provider, ProviderFactory, es, global_providers +from arclet.letoderea import BaseAuxiliary, Contexts, Param, Provider, ProviderFactory, Subscriber, es, global_providers from creart import it from launart import Launart, Service from satori import LoginStatus @@ -14,7 +14,6 @@ from satori.model import Event from tarina.generic import get_origin -from .command import _commands from .config import EntariConfig from .event.config import ConfigReload from .event.lifespan import AccountUpdate @@ -22,7 +21,7 @@ from .event.send import SendResponse from .logger import log from .plugin import load_plugin, plugin_config, requires -from .plugin.model import RootlessPlugin +from .plugin.model import Plugin, RootlessPlugin from .plugin.service import plugin_service from .session import EntariProtocol, Session @@ -55,7 +54,17 @@ async def __call__(self, context: Contexts): return context["account"] -global_providers.extend([ApiProtocolProvider(), SessionProvider(), AccountProvider()]) +class PluginProvider(Provider[Plugin]): + async def __call__(self, context: Contexts): + subscriber: Subscriber = context["$subscriber"] + func = subscriber.callable_target + if hasattr(func, "__globals__") and "__plugin__" in func.__globals__: # type: ignore + return func.__globals__["__plugin__"] + if hasattr(func, "__module__"): + return plugin_service.plugins.get(func.__module__) + + +global_providers.extend([ApiProtocolProvider(), SessionProvider(), AccountProvider(), PluginProvider()]) @RootlessPlugin.apply("record_message") @@ -117,6 +126,8 @@ def __init__( super().__init__(*configs, default_api_cls=EntariProtocol) if not hasattr(EntariConfig, "instance"): EntariConfig.load() + if "~commands" not in EntariConfig.instance.plugin: + EntariConfig.instance.plugin["~commands"] = True log.set_level(log_level) log.core.opt(colors=True).debug(f"Log level set to {log_level}") requires(*EntariConfig.instance.plugin) @@ -128,7 +139,6 @@ def __init__( self._ref_tasks = set() es.on(ConfigReload, self.reset_self) - es.on(MessageCreatedEvent, _commands.handle, auxiliaries=[_commands.judge]) def reset_self(self, scope, key, value): if scope != "basic": diff --git a/arclet/entari/event/config.py b/arclet/entari/event/config.py index 229a2dc..df75c9c 100644 --- a/arclet/entari/event/config.py +++ b/arclet/entari/event/config.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Any +from typing import Any, Optional from arclet.letoderea import es @@ -11,6 +11,7 @@ class ConfigReload(BasedEvent): scope: str key: str value: Any + old: Optional[Any] = None __publisher__ = "entari.event/config_reload" __result_type__: type[bool] = bool diff --git a/arclet/entari/filter/common.py b/arclet/entari/filter/common.py index 959c206..5b0e5b0 100644 --- a/arclet/entari/filter/common.py +++ b/arclet/entari/filter/common.py @@ -22,7 +22,7 @@ def __init__(self, *user_ids: str, priority: int = 10): async def __call__(self, scope: Scope, interface: Interface) -> Optional[bool]: if not (user := await interface.query(User, "user", force_return=True)): return False - return user.id in self.user_ids + return user.id in self.user_ids if self.user_ids else True @property def scopes(self) -> set[Scope]: @@ -41,7 +41,7 @@ def __init__(self, *guild_ids: str, priority: int = 10): async def __call__(self, scope: Scope, interface: Interface) -> Optional[bool]: if not (guild := await interface.query(Guild, "guild", force_return=True)): return False - return guild.id in self.guild_ids + return guild.id in self.guild_ids if self.guild_ids else True @property def scopes(self) -> set[Scope]: @@ -60,7 +60,7 @@ def __init__(self, *channel_ids: str, priority: int = 10): async def __call__(self, scope: Scope, interface: Interface) -> Optional[bool]: if not (channel := await interface.query(Channel, "channel", force_return=True)): return False - return channel.id in self.channel_ids + return channel.id in self.channel_ids if self.channel_ids else True @property def scopes(self) -> set[Scope]: @@ -110,6 +110,21 @@ def id(self) -> str: _SessionFilter: TypeAlias = Union[Callable[[Session], bool], Callable[[Session], Awaitable[bool]]] +_keys = { + "user", + "guild", + "channel", + "self", + "platform", + "direct", + "private", + "public", + "reply_me", + "notice_me", + "to_me", +} + +PATTERNS: TypeAlias = dict[str, Union[list[str], bool, "PATTERNS"]] class Filter(JudgeAuxiliary): @@ -164,29 +179,24 @@ def platform(self, *platforms: str) -> Self: self.steps.append(PlatformFilter(*platforms, priority=2)) return self - @property def direct(self) -> Self: self.steps.append(DirectMessageJudger(priority=8)) return self private = direct - @property def public(self) -> Self: self.steps.append(PublicMessageJudger(priority=8)) return self - @property def reply_me(self) -> Self: self.steps.append(ReplyMeJudger(priority=9)) return self - @property def notice_me(self) -> Self: self.steps.append(NoticeMeJudger(priority=10)) return self - @property def to_me(self) -> Self: self.steps.append(ToMeJudger(priority=11)) return self @@ -217,3 +227,23 @@ def not_(self, other: Union["Filter", _SessionFilter]) -> "Filter": return new exclude = not_ + + @classmethod + def parse(cls, patterns: PATTERNS) -> Self: + fter = cls(priority=10) + for key, value in patterns.items(): + if key in _keys: + if isinstance(value, list): + getattr(fter, key)(*value) + elif isinstance(value, bool) and value: + getattr(fter, key)() + elif key in ("$and", "$or", "$not", "$intersect", "$union", "$exclude"): + op = key[1:] + if op in ("and", "or", "not"): + op += "_" + if not isinstance(value, dict): + raise ValueError(f"Expect a dict for operator {key}") + fter = getattr(fter, op)(cls.parse(value)) + else: + raise ValueError(f"Unknown key: {key}") + return fter diff --git a/arclet/entari/plugin/model.py b/arclet/entari/plugin/model.py index bb0808f..0098a02 100644 --- a/arclet/entari/plugin/model.py +++ b/arclet/entari/plugin/model.py @@ -15,8 +15,8 @@ from launart import Launart, Service from tarina import ContextModel -from .._subscriber import SubscribeLoader from ..event.lifespan import Cleanup, Ready, Startup +from ..filter import Filter from ..logger import log from .service import plugin_service @@ -53,7 +53,7 @@ def __init__( es.register(self.publisher) self.plugin = plugin self._events = events - self._subscribers: list[SubscribeLoader] = [] + self._subscribers: list[Subscriber] = [] def waiter( self, @@ -71,10 +71,6 @@ def wrapper(func: TTarget[R]): return wrapper - def _load(self): - for sub in self._subscribers: - sub.load() - def dispose(self): for sub in self._subscribers: sub.dispose() @@ -88,13 +84,13 @@ def register(self, func: Callable | None = None, **kwargs) -> Any: wrapper = self.publisher.register(**kwargs) if func: self.plugin.validate(func) # type: ignore - sub = SubscribeLoader(func, wrapper) + sub = wrapper(func) self._subscribers.append(sub) return sub def decorator(func1): self.plugin.validate(func1) - sub1 = SubscribeLoader(func1, wrapper) + sub1 = wrapper(func1) self._subscribers.append(sub1) return sub1 @@ -162,10 +158,6 @@ def inject(self, *requires: str): plugin._metadata.requirements.extend(requires) return self - def _load(self): - for disp in self.dispatchers.values(): - disp._load() - async def _startup(self): if Startup.__publisher__ in self.dispatchers: await self.dispatchers[Startup.__publisher__].publisher.emit(Startup()) @@ -178,8 +170,20 @@ async def _cleanup(self): if Cleanup.__publisher__ in self.dispatchers: await self.dispatchers[Cleanup.__publisher__].publisher.emit(Cleanup()) + def update_filter(self, allow: dict, deny: dict): + if not allow and not deny: + return + fter = Filter() + if allow: + fter = fter.and_(Filter.parse(allow)) + if deny: + fter = fter.not_(Filter.parse(deny)) + if fter.steps: + plugin_service.filters[self.id] = fter + def __post_init__(self): plugin_service.plugins[self.id] = self + self.update_filter(self.config.pop("$allow", {}), self.config.pop("$deny", {})) if self.id not in plugin_service._keep_values: plugin_service._keep_values[self.id] = {} if self.id not in plugin_service._referents: diff --git a/arclet/entari/plugin/module.py b/arclet/entari/plugin/module.py index b6bd109..2364ed0 100644 --- a/arclet/entari/plugin/module.py +++ b/arclet/entari/plugin/module.py @@ -10,10 +10,12 @@ from types import ModuleType from typing import Optional +from arclet.letoderea import global_auxiliaries + from ..config import EntariConfig from ..logger import log from .model import Plugin, PluginMetadata, _current_plugin -from .service import plugin_service +from .service import AccessAuxiliary, plugin_service _SUBMODULE_WAITLIST: dict[str, set[str]] = {} _ENSURE_IS_PLUGIN: set[str] = set() @@ -219,15 +221,17 @@ def create_module(self, spec) -> Optional[ModuleType]: return super().create_module(spec) def exec_module(self, module: ModuleType, config: Optional[dict[str, str]] = None) -> None: + is_sub = False if plugin := plugin_service.plugins.get(self.parent_plugin_id) if self.parent_plugin_id else None: plugin.subplugins.add(module.__name__) plugin_service._subplugined[module.__name__] = plugin.id + is_sub = True if self.loaded: return # create plugin before executing - plugin = Plugin(module.__name__, module, config=config or {}) + plugin = Plugin(module.__name__, module, config=(config or {}).copy()) # for `dataclasses` module sys.modules[module.__name__] = plugin.proxy() # type: ignore setattr(module, "__plugin__", plugin) @@ -235,14 +239,20 @@ def exec_module(self, module: ModuleType, config: Optional[dict[str, str]] = Non setattr(module, "__getattr_or_import__", getattr_or_import) setattr(module, "__plugin_service__", plugin_service) + aux = AccessAuxiliary(plugin.id) + # enter plugin context with _current_plugin.use(plugin): try: + if not is_sub: + global_auxiliaries.append(aux) super().exec_module(module) except Exception: plugin.dispose() raise finally: + if not is_sub: + global_auxiliaries.remove(aux) # leave plugin context delattr(module, "__cached__") delattr(module, "__plugin_service__") diff --git a/arclet/entari/plugin/service.py b/arclet/entari/plugin/service.py index 0934a07..fad1324 100644 --- a/arclet/entari/plugin/service.py +++ b/arclet/entari/plugin/service.py @@ -1,10 +1,11 @@ from typing import TYPE_CHECKING, Any, Callable -from arclet.letoderea import es +from arclet.letoderea import JudgeAuxiliary, Scope, es from launart import Launart, Service from launart.status import Phase from ..event.lifespan import Cleanup, Ready, Startup +from ..filter import Filter from ..logger import log if TYPE_CHECKING: @@ -15,6 +16,7 @@ class PluginManagerService(Service): id = "entari.plugin.manager" plugins: dict[str, "Plugin"] + filters: dict[str, Filter] _keep_values: dict[str, dict[str, "KeepingVariable"]] _referents: dict[str, set[str]] _unloaded: set[str] @@ -29,6 +31,7 @@ def __init__(self): self._unloaded = set() self._subplugined = {} self._apply = {} + self.filters = {} @property def required(self) -> set[str]: @@ -43,7 +46,6 @@ async def launch(self, manager: Launart): for plug in self.plugins.values(): for serv in plug._services.values(): manager.add_component(serv) - plug._load() async with self.stage("preparing"): await es.publish(Startup()) @@ -69,3 +71,22 @@ async def launch(self, manager: Launart): plugin_service = PluginManagerService() + + +class AccessAuxiliary(JudgeAuxiliary): + def __init__(self, plugin_id: str): + super().__init__(priority=0) + self.plugin_id = plugin_id + + @property + def id(self): + return f"entari.plugin.access:{self.plugin_id}" + + @property + def scopes(self): + return {Scope.prepare} + + async def __call__(self, scope: Scope, interface): + if self.plugin_id in plugin_service.filters: + return await plugin_service.filters[self.plugin_id](scope, interface) + return True diff --git a/example_plugin.py b/example_plugin.py index 8af10e2..edc3654 100644 --- a/example_plugin.py +++ b/example_plugin.py @@ -33,7 +33,7 @@ async def cleanup(): @disp_message -@Filter().public.bind +@Filter().public().bind async def _(msg: MessageChain, session: Session): content = msg.extract_plain_text() if re.match(r"(.{0,3})(上传|设定)(.{0,3})(上传|设定)(.{0,3})", content): @@ -43,12 +43,12 @@ async def _(msg: MessageChain, session: Session): disp_message1 = plug.dispatch(MessageCreatedEvent) -@disp_message1.on(auxiliaries=[Filter().public.to_me.and_(lambda sess: str(sess.content) == "aaa")]) +@disp_message1.on(auxiliaries=[Filter().public().to_me().and_(lambda sess: str(sess.content) == "aaa")]) async def _(session: Session): return await session.send("Filter: public message, to me, and content is 'aaa'") -@disp_message1.on(auxiliaries=[Filter().public.to_me.not_(lambda sess: str(sess.content) == "aaa")]) +@disp_message1.on(auxiliaries=[Filter().public().to_me().not_(lambda sess: str(sess.content) == "aaa")]) async def _(session: Session): return await session.send("Filter: public message, to me, but content is not 'aaa'")