Skip to content

Commit

Permalink
✨ scheduler & access filter
Browse files Browse the repository at this point in the history
  • Loading branch information
RF-Tar-Railt committed Dec 13, 2024
1 parent c0aaa65 commit 5abd664
Show file tree
Hide file tree
Showing 5 changed files with 195 additions and 118 deletions.
161 changes: 55 additions & 106 deletions arclet/entari/filter/__init__.py
Original file line number Diff line number Diff line change
@@ -1,122 +1,71 @@
from collections.abc import Awaitable
from typing import Callable, Optional, Union
from typing_extensions import Self, TypeAlias
import asyncio
from datetime import datetime
from typing import Optional, Union

from arclet.letoderea import Interface, JudgeAuxiliary, Scope
from arclet.letoderea import bind as _bind
from arclet.letoderea.typing import run_sync
from tarina import is_async

from ..message import MessageChain
from ..session import Session
from .common import ChannelFilter, GuildFilter, PlatformFilter, SelfFilter, UserFilter
from .message import DirectMessageJudger, NoticeMeJudger, PublicMessageJudger, ReplyMeJudger, ToMeJudger
from .op import ExcludeFilter, IntersectFilter, UnionFilter
from .common import Filter as Filter

_SessionFilter: TypeAlias = Union[Callable[[Session], bool], Callable[[Session], Awaitable[bool]]]


class Filter(JudgeAuxiliary):
def __init__(self, callback: Optional[_SessionFilter] = None, priority: int = 10):
super().__init__(priority=priority)
self.steps = []
if callback:
if is_async(callback):
self.callback = callback
else:
self.callback = run_sync(callback)
else:
self.callback = None

async def __call__(self, scope: Scope, interface: Interface) -> Optional[bool]:
for step in sorted(self.steps, key=lambda x: x.priority):
if not await step(scope, interface):
return False
if self.callback:
session = await interface.query(Session, "session", force_return=True)
if not session:
return False
if not await self.callback(session): # type: ignore
return False
return True

@property
def scopes(self) -> set[Scope]:
return {Scope.prepare}

@property
def id(self) -> str:
return "entari.filter"

def user(self, *user_ids: str) -> Self:
self.steps.append(UserFilter(*user_ids, priority=6))
return self

def guild(self, *guild_ids: str) -> Self:
self.steps.append(GuildFilter(*guild_ids, priority=4))
return self

def channel(self, *channel_ids: str) -> Self:
self.steps.append(ChannelFilter(*channel_ids, priority=5))
return self

def self(self, *self_ids: str) -> Self:
self.steps.append(SelfFilter(*self_ids, priority=3))
return self

def platform(self, *platforms: str) -> Self:
self.steps.append(PlatformFilter(*platforms, priority=2))
return self
class Interval(JudgeAuxiliary):
def __init__(self, interval: float, limit_prompt: Optional[Union[str, MessageChain]] = None):
self.success = True
self.last_time = None
self.interval = interval
self.limit_prompt = limit_prompt
super().__init__(priority=20)

@property
def direct(self) -> Self:
self.steps.append(DirectMessageJudger(priority=8))
return self

private = direct
def id(self):
return "entari.filter/interval"

@property
def public(self) -> Self:
self.steps.append(PublicMessageJudger(priority=8))
return self
def scopes(self):
return {Scope.prepare, Scope.cleanup}

@property
def reply_me(self) -> Self:
self.steps.append(ReplyMeJudger(priority=9))
return self
async def __call__(self, scope: Scope, interface: Interface) -> Optional[bool]:
if scope == Scope.prepare:
if not self.last_time:
return True
# if self.condition:
# if not await self.condition(scope, interface):
# self.success = False
# return False
self.success = (datetime.now() - self.last_time).total_seconds() > self.interval
if not self.success:
session = await interface.query(Session, "session", force_return=True)
if session and self.limit_prompt:
await session.send(self.limit_prompt)
return self.success
if self.success:
self.last_time = datetime.now()
return True


class Semaphore(JudgeAuxiliary):
def __init__(self, count: int, limit_prompt: Optional[Union[str, MessageChain]] = None):
self.count = count
self.limit_prompt = limit_prompt
self.semaphore = asyncio.Semaphore(count)
super().__init__(priority=20)

@property
def notice_me(self) -> Self:
self.steps.append(NoticeMeJudger(priority=10))
return self
def id(self):
return "entari.filter/access"

@property
def to_me(self) -> Self:
self.steps.append(ToMeJudger(priority=11))
return self

def bind(self, func):
return _bind(self)(func)

def and_(self, other: Union["Filter", _SessionFilter]) -> "Filter":
new = Filter(priority=self.priority)
_other = other if isinstance(other, Filter) else Filter(callback=other)
new.steps.append(IntersectFilter(self, _other, priority=1))
return new
def scopes(self):
return {Scope.prepare, Scope.cleanup}

intersect = and_

def or_(self, other: Union["Filter", _SessionFilter]) -> "Filter":
new = Filter(priority=self.priority)
_other = other if isinstance(other, Filter) else Filter(callback=other)
new.steps.append(UnionFilter(self, _other, priority=1))
return new

union = or_

def not_(self, other: Union["Filter", _SessionFilter]) -> "Filter":
new = Filter(priority=self.priority)
_other = other if isinstance(other, Filter) else Filter(callback=other)
new.steps.append(ExcludeFilter(self, _other, priority=1))
return new

exclude = not_
async def __call__(self, scope: Scope, interface: Interface) -> Optional[bool]:
if scope == Scope.prepare:
if not await self.semaphore.acquire():
session = await interface.query(Session, "session", force_return=True)
if session and self.limit_prompt:
await session.send(self.limit_prompt)
return False
return True
self.semaphore.release()
return True
121 changes: 120 additions & 1 deletion arclet/entari/filter/common.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,17 @@
from typing import Optional
from collections.abc import Awaitable
from typing import Callable, Optional, Union
from typing_extensions import Self, TypeAlias

from arclet.letoderea import Interface, JudgeAuxiliary, Scope
from arclet.letoderea import bind as _bind
from arclet.letoderea.typing import run_sync
from satori import Channel, Guild, User
from satori.client import Account
from tarina import is_async

from ..session import Session
from .message import DirectMessageJudger, NoticeMeJudger, PublicMessageJudger, ReplyMeJudger, ToMeJudger
from .op import ExcludeFilter, IntersectFilter, UnionFilter


class UserFilter(JudgeAuxiliary):
Expand Down Expand Up @@ -98,3 +107,113 @@ def scopes(self) -> set[Scope]:
@property
def id(self) -> str:
return "entari.filter/platform"


_SessionFilter: TypeAlias = Union[Callable[[Session], bool], Callable[[Session], Awaitable[bool]]]


class Filter(JudgeAuxiliary):
def __init__(self, callback: Optional[_SessionFilter] = None, priority: int = 10):
super().__init__(priority=priority)
self.steps = []
if callback:
if is_async(callback):
self.callback = callback
else:
self.callback = run_sync(callback)
else:
self.callback = None

async def __call__(self, scope: Scope, interface: Interface):
for step in sorted(self.steps, key=lambda x: x.priority):
if not await step(scope, interface):
return False
if self.callback:
session = await interface.query(Session, "session", force_return=True)
if not session:
return False
if not await self.callback(session): # type: ignore
return False
return True

@property
def scopes(self) -> set[Scope]:
return {Scope.prepare}

@property
def id(self) -> str:
return "entari.filter"

def user(self, *user_ids: str) -> Self:
self.steps.append(UserFilter(*user_ids, priority=6))
return self

def guild(self, *guild_ids: str) -> Self:
self.steps.append(GuildFilter(*guild_ids, priority=4))
return self

def channel(self, *channel_ids: str) -> Self:
self.steps.append(ChannelFilter(*channel_ids, priority=5))
return self

def self(self, *self_ids: str) -> Self:
self.steps.append(SelfFilter(*self_ids, priority=3))
return self

def platform(self, *platforms: str) -> Self:
self.steps.append(PlatformFilter(*platforms, priority=2))
return self

@property
def direct(self) -> Self:
self.steps.append(DirectMessageJudger(priority=8))
return self

private = direct

@property
def public(self) -> Self:
self.steps.append(PublicMessageJudger(priority=8))
return self

@property
def reply_me(self) -> Self:
self.steps.append(ReplyMeJudger(priority=9))
return self

@property
def notice_me(self) -> Self:
self.steps.append(NoticeMeJudger(priority=10))
return self

@property
def to_me(self) -> Self:
self.steps.append(ToMeJudger(priority=11))
return self

def bind(self, func):
return _bind(self)(func)

def and_(self, other: Union["Filter", _SessionFilter]) -> "Filter":
new = Filter(priority=self.priority)
_other = other if isinstance(other, Filter) else Filter(callback=other)
new.steps.append(IntersectFilter(self, _other, priority=1))
return new

intersect = and_

def or_(self, other: Union["Filter", _SessionFilter]) -> "Filter":
new = Filter(priority=self.priority)
_other = other if isinstance(other, Filter) else Filter(callback=other)
new.steps.append(UnionFilter(self, _other, priority=1))
return new

union = or_

def not_(self, other: Union["Filter", _SessionFilter]) -> "Filter":
new = Filter(priority=self.priority)
_other = other if isinstance(other, Filter) else Filter(callback=other)
new.steps.append(ExcludeFilter(self, _other, priority=1))
return new

exclude = not_
15 changes: 6 additions & 9 deletions arclet/entari/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,12 @@ async def launch(self, manager: Launart):
id = "entari.scheduler"


scheduler = Scheduler()
scheduler = service = Scheduler()


@RootlessPlugin.apply("scheduler")
def _(plg: RootlessPlugin):
plg.service(scheduler)
plg.service(service)


def every_second():
Expand Down Expand Up @@ -186,16 +186,13 @@ def crontab(cron_str: str):
cron_str (str): cron 表达式
"""

def _():
now = datetime.now()
it = croniter(cron_str, now)
return it.get_next(datetime) - now
it = croniter(cron_str, datetime.now())

return _
return lambda iter=it: iter.get_next(datetime) - datetime.now()


def cron(pattern: str):
return scheduler.schedule(crontab(pattern))
return service.schedule(crontab(pattern))


def every(
Expand All @@ -207,4 +204,4 @@ def every(
"minute": every_minutes,
"hour": every_hours,
}
return scheduler.schedule(_TIMER_MAPPING[mode](value))
return service.schedule(_TIMER_MAPPING[mode](value))
3 changes: 2 additions & 1 deletion example.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@ plugins:
watch_dirs: ["."]
::echo: true
example_plugin: true
~record_message: true
~record_message: true
~scheduler: true
13 changes: 12 additions & 1 deletion example_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@
command,
metadata,
keeping,
scheduler,
Entari,
)
from arclet.entari.filter import Interval

metadata(__file__)

Expand Down Expand Up @@ -50,7 +53,7 @@ async def _(session: Session):
return await session.send("Filter: public message, to me, but content is not 'aaa'")


@command.on("add {a} {b}")
@command.on("add {a} {b}", [Interval(2, limit_prompt="太快了")])
def add(a: int, b: int):
return f"{a + b =}"

Expand Down Expand Up @@ -80,3 +83,11 @@ async def show(session: Session):
@plug.use("::before_send")
async def send_hook(message: MessageChain):
return message + "喵"


@scheduler.cron("* * * * *")
async def broadcast(app: Entari):
for account in app.accounts.values():
channels = [channel for guild in (await account.guild_list()).data for channel in (await account.channel_list(guild.id)).data]
for channel in channels:
await account.send_message(channel, "Hello, World!")

0 comments on commit 5abd664

Please sign in to comment.