Skip to content

Commit

Permalink
✨ filter from config
Browse files Browse the repository at this point in the history
  • Loading branch information
RF-Tar-Railt committed Dec 13, 2024
1 parent 5abd664 commit 304486e
Show file tree
Hide file tree
Showing 11 changed files with 186 additions and 106 deletions.
28 changes: 0 additions & 28 deletions arclet/entari/_subscriber.py

This file was deleted.

64 changes: 46 additions & 18 deletions arclet/entari/builtins/auto_reload.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,21 @@ class Config:
logger = log.wrapper("[AutoReload]")


def detect_filter_change(old: dict, new: dict):
added = set(new) - set(old)
removed = set(old) - set(new)
changed = {key for key in set(new) & set(old) if new[key] != old[key]}
if "$allow" in removed:
allow = {}
else:
allow = new["$allow"]
if "$deny" in removed:
deny = {}
else:
deny = new["$deny"]
return allow, deny, not ((added | removed | changed) - {"$allow", "$deny"})


class Watcher(Service):
id = "watcher"

Expand Down Expand Up @@ -61,7 +76,6 @@ async def watch(self):
dispose_plugin(pid)
if plugin := load_plugin(pid):
logger("INFO", f"Reloaded <blue>{plugin.id!r}</blue>")
plugin._load()
await plugin._startup()
await plugin._ready()
del plugin
Expand All @@ -72,7 +86,6 @@ async def watch(self):
logger("INFO", f"Detected change in {change[1]!r} which failed to reload, retrying...")
if plugin := load_plugin(self.fail[change[1]]):
logger("INFO", f"Reloaded <blue>{plugin.id!r}</blue>")
plugin._load()
await plugin._startup()
await plugin._ready()
del plugin
Expand Down Expand Up @@ -102,37 +115,53 @@ async def watch_config(self):
f"Basic config <y>{key!r}</y> changed from <r>{old_basic[key]!r}</r> "
f"to <g>{EntariConfig.instance.basic[key]!r}</g>",
)
await es.publish(ConfigReload("basic", key, EntariConfig.instance.basic[key]))
await es.publish(ConfigReload("basic", key, EntariConfig.instance.basic[key], old_basic[key]))
for key in set(EntariConfig.instance.basic) - set(old_basic):
logger("DEBUG", f"Basic config <y>{key!r}</y> appended")
await es.publish(ConfigReload("basic", key, EntariConfig.instance.basic[key]))
for plugin_name in old_plugin:
pid = plugin_name.replace("::", "arclet.entari.builtins.")
if (
plugin_name not in EntariConfig.instance.plugin
or EntariConfig.instance.plugin[plugin_name] is False
) and (plugin := find_plugin(pid)):
await plugin._cleanup()
del plugin
dispose_plugin(pid)
logger("INFO", f"Disposed plugin <blue>{pid!r}</blue>")
):
if plugin := find_plugin(pid):
await plugin._cleanup()
del plugin
dispose_plugin(pid)
logger("INFO", f"Disposed plugin <blue>{pid!r}</blue>")
continue
if old_plugin[plugin_name] != EntariConfig.instance.plugin[plugin_name]:
logger(
"DEBUG",
f"Plugin <y>{plugin_name!r}</y> config changed from <r>{old_plugin[plugin_name]!r}</r> "
f"to <g>{EntariConfig.instance.plugin[plugin_name]!r}</g>",
)
res = await es.post(
ConfigReload("plugin", plugin_name, EntariConfig.instance.plugin[plugin_name])
)
if res and res.value:
logger("DEBUG", f"Plugin <y>{pid!r}</y> config change handled by itself.")
continue
if isinstance(old_plugin[plugin_name], bool):
old_conf = {}
else:
old_conf: dict = old_plugin[plugin_name] # type: ignore
if isinstance(EntariConfig.instance.plugin[plugin_name], bool):
new_conf = {}
else:
new_conf: dict = EntariConfig.instance.plugin[plugin_name] # type: ignore
if plugin := find_plugin(pid):
allow, deny, only_filter = detect_filter_change(old_conf, new_conf)
plugin.update_filter(allow, deny)
if only_filter:
logger("DEBUG", f"Plugin <y>{pid!r}</y> config only changed filter.")
continue
res = await es.post(
ConfigReload("plugin", plugin_name, new_conf, old_conf),
)
if res and res.value:
logger("DEBUG", f"Plugin <y>{pid!r}</y> config change handled by itself.")
continue
logger("INFO", f"Detected <blue>{pid!r}</blue>'s config change, reloading...")
plugin_file = str(plugin.module.__file__)
await plugin._cleanup()
dispose_plugin(plugin_name)
if plugin := load_plugin(plugin_name):
plugin._load()
if plugin := load_plugin(plugin_name, new_conf):
await plugin._startup()
await plugin._ready()
logger("INFO", f"Reloaded <blue>{plugin.id!r}</blue>")
Expand All @@ -142,12 +171,11 @@ async def watch_config(self):
self.fail[plugin_file] = pid
else:
logger("INFO", f"Detected <blue>{pid!r}</blue> appended, loading...")
load_plugin(plugin_name)
load_plugin(plugin_name, new_conf)
if new := (set(EntariConfig.instance.plugin) - set(old_plugin)):
for plugin_name in new:
if not (plugin := load_plugin(plugin_name)):
continue
plugin._load()
await plugin._startup()
await plugin._ready()
del plugin
Expand Down
54 changes: 30 additions & 24 deletions arclet/entari/command/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ async def handle(self, session: Session, message: MessageChain, ctx: Contexts):
if not msg:
return
if matches := list(self.trie.prefixes(msg)):
results = await asyncio.gather(*(res.value.handle(ctx.copy(), inner=True) for res in matches if res.value))
results = await asyncio.gather(*(res.value.handle(ctx.copy()) for res in matches if res.value))
for result in results:
if result is not None:
await session.send(result)
Expand All @@ -67,7 +67,7 @@ async def handle(self, session: Session, message: MessageChain, ctx: Contexts):
command_manager.find_shortcut(get_cmd(value), data)
except ValueError:
continue
result = await value.handle(ctx.copy(), inner=True)
result = await value.handle(ctx.copy())
if result is not None:
await session.send(result)

Expand Down Expand Up @@ -143,7 +143,7 @@ def wrapper(func: TTarget[Optional[TM]]) -> Subscriber[Optional[TM]]:
f" {arg.value.target}" for arg in _command.args if isinstance(arg.value, DirectPattern)
)
auxiliaries.insert(0, AlconnaSuppiler(_command))
target = self.publisher.register(auxiliaries=auxiliaries, providers=providers)(func)
target = self.publisher.register(func, auxiliaries=auxiliaries, providers=providers)
self.publisher.remove_subscriber(target)
self.trie[key] = target

Expand All @@ -152,30 +152,34 @@ def _remove(_):
self.trie.pop(key, None) # type: ignore

target._dispose = _remove
return target

_command = cast(Alconna, command)
if not isinstance(command.command, str):
raise TypeError("Command name must be a string.")
_command.reset_namespace(self.__namespace__)
auxiliaries.insert(0, AlconnaSuppiler(_command))
keys = []
if not _command.prefixes:
keys.append(_command.command)
elif not all(isinstance(i, str) for i in _command.prefixes):
raise TypeError("Command prefixes must be a list of string.")
else:
auxiliaries.insert(0, AlconnaSuppiler(command))
target = self.publisher.register(auxiliaries=auxiliaries, providers=providers)(func)
self.publisher.remove_subscriber(target)
if not isinstance(command.command, str):
raise TypeError("Command name must be a string.")
keys = []
if not command.prefixes:
self.trie[command.command] = target
keys.append(command.command)
elif not all(isinstance(i, str) for i in command.prefixes):
raise TypeError("Command prefixes must be a list of string.")
else:
for prefix in cast(list[str], command.prefixes):
self.trie[prefix + command.command] = target
keys.append(prefix + command.command)
for prefix in cast(list[str], _command.prefixes):
keys.append(prefix + _command.command)

def _remove(_):
command_manager.delete(get_cmd(_))
for key in keys:
self.trie.pop(key, None) # type: ignore
target = self.publisher.register(func, auxiliaries=auxiliaries, providers=providers)
self.publisher.remove_subscriber(target)

target._dispose = _remove
command.reset_namespace(self.__namespace__)
for _key in keys:
self.trie[_key] = target

def _remove(_):
command_manager.delete(get_cmd(_))
for _key in keys:
self.trie.pop(_key, None) # type: ignore

target._dispose = _remove
return target

return wrapper
Expand Down Expand Up @@ -209,6 +213,8 @@ def _(plg: RootlessPlugin):
if "use_config_prefix" in plg.config:
_commands.judge.use_config_prefix = plg.config["use_config_prefix"]

plg.dispatch(MessageCreatedEvent).handle(_commands.handle, auxiliaries=[_commands.judge])

@plg.use(ConfigReload)
def update(event: ConfigReload):
if event.scope != "plugin":
Expand Down
4 changes: 1 addition & 3 deletions arclet/entari/command/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from arclet.alconna import Alconna, command_manager
from arclet.letoderea import BaseAuxiliary, Provider, ProviderFactory

from .._subscriber import SubscribeLoader
from ..event import MessageCreatedEvent
from ..event.command import pub as execute_handles
from ..plugin.model import Plugin, PluginDispatcher
Expand Down Expand Up @@ -60,8 +59,7 @@ def on_execute(
_auxiliaries.append(self.supplier)

def wrapper(func):
caller = execute_handles.register(priority=priority, auxiliaries=_auxiliaries, providers=providers)
sub = SubscribeLoader(func, caller)
sub = execute_handles.register(func, priority=priority, auxiliaries=_auxiliaries, providers=providers)
self._subscribers.append(sub)
return sub

Expand Down
20 changes: 15 additions & 5 deletions arclet/entari/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from contextlib import suppress
import os

from arclet.letoderea import BaseAuxiliary, Contexts, Param, Provider, ProviderFactory, es, global_providers
from arclet.letoderea import BaseAuxiliary, Contexts, Param, Provider, ProviderFactory, Subscriber, es, global_providers
from creart import it
from launart import Launart, Service
from satori import LoginStatus
Expand All @@ -14,15 +14,14 @@
from satori.model import Event
from tarina.generic import get_origin

from .command import _commands
from .config import EntariConfig
from .event.config import ConfigReload
from .event.lifespan import AccountUpdate
from .event.protocol import MessageCreatedEvent, event_parse
from .event.send import SendResponse
from .logger import log
from .plugin import load_plugin, plugin_config, requires
from .plugin.model import RootlessPlugin
from .plugin.model import Plugin, RootlessPlugin
from .plugin.service import plugin_service
from .session import EntariProtocol, Session

Expand Down Expand Up @@ -55,7 +54,17 @@ async def __call__(self, context: Contexts):
return context["account"]


global_providers.extend([ApiProtocolProvider(), SessionProvider(), AccountProvider()])
class PluginProvider(Provider[Plugin]):
async def __call__(self, context: Contexts):
subscriber: Subscriber = context["$subscriber"]
func = subscriber.callable_target
if hasattr(func, "__globals__") and "__plugin__" in func.__globals__: # type: ignore
return func.__globals__["__plugin__"]
if hasattr(func, "__module__"):
return plugin_service.plugins.get(func.__module__)


global_providers.extend([ApiProtocolProvider(), SessionProvider(), AccountProvider(), PluginProvider()])


@RootlessPlugin.apply("record_message")
Expand Down Expand Up @@ -117,6 +126,8 @@ def __init__(
super().__init__(*configs, default_api_cls=EntariProtocol)
if not hasattr(EntariConfig, "instance"):
EntariConfig.load()
if "~commands" not in EntariConfig.instance.plugin:
EntariConfig.instance.plugin["~commands"] = True
log.set_level(log_level)
log.core.opt(colors=True).debug(f"Log level set to <y><c>{log_level}</c></y>")
requires(*EntariConfig.instance.plugin)
Expand All @@ -128,7 +139,6 @@ def __init__(
self._ref_tasks = set()

es.on(ConfigReload, self.reset_self)
es.on(MessageCreatedEvent, _commands.handle, auxiliaries=[_commands.judge])

def reset_self(self, scope, key, value):
if scope != "basic":
Expand Down
3 changes: 2 additions & 1 deletion arclet/entari/event/config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Any
from typing import Any, Optional

from arclet.letoderea import es

Expand All @@ -11,6 +11,7 @@ class ConfigReload(BasedEvent):
scope: str
key: str
value: Any
old: Optional[Any] = None

__publisher__ = "entari.event/config_reload"
__result_type__: type[bool] = bool
Expand Down
Loading

0 comments on commit 304486e

Please sign in to comment.