-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
c0aaa65
commit 5abd664
Showing
5 changed files
with
195 additions
and
118 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,122 +1,71 @@ | ||
from collections.abc import Awaitable | ||
from typing import Callable, Optional, Union | ||
from typing_extensions import Self, TypeAlias | ||
import asyncio | ||
from datetime import datetime | ||
from typing import Optional, Union | ||
|
||
from arclet.letoderea import Interface, JudgeAuxiliary, Scope | ||
from arclet.letoderea import bind as _bind | ||
from arclet.letoderea.typing import run_sync | ||
from tarina import is_async | ||
|
||
from ..message import MessageChain | ||
from ..session import Session | ||
from .common import ChannelFilter, GuildFilter, PlatformFilter, SelfFilter, UserFilter | ||
from .message import DirectMessageJudger, NoticeMeJudger, PublicMessageJudger, ReplyMeJudger, ToMeJudger | ||
from .op import ExcludeFilter, IntersectFilter, UnionFilter | ||
from .common import Filter as Filter | ||
|
||
_SessionFilter: TypeAlias = Union[Callable[[Session], bool], Callable[[Session], Awaitable[bool]]] | ||
|
||
|
||
class Filter(JudgeAuxiliary): | ||
def __init__(self, callback: Optional[_SessionFilter] = None, priority: int = 10): | ||
super().__init__(priority=priority) | ||
self.steps = [] | ||
if callback: | ||
if is_async(callback): | ||
self.callback = callback | ||
else: | ||
self.callback = run_sync(callback) | ||
else: | ||
self.callback = None | ||
|
||
async def __call__(self, scope: Scope, interface: Interface) -> Optional[bool]: | ||
for step in sorted(self.steps, key=lambda x: x.priority): | ||
if not await step(scope, interface): | ||
return False | ||
if self.callback: | ||
session = await interface.query(Session, "session", force_return=True) | ||
if not session: | ||
return False | ||
if not await self.callback(session): # type: ignore | ||
return False | ||
return True | ||
|
||
@property | ||
def scopes(self) -> set[Scope]: | ||
return {Scope.prepare} | ||
|
||
@property | ||
def id(self) -> str: | ||
return "entari.filter" | ||
|
||
def user(self, *user_ids: str) -> Self: | ||
self.steps.append(UserFilter(*user_ids, priority=6)) | ||
return self | ||
|
||
def guild(self, *guild_ids: str) -> Self: | ||
self.steps.append(GuildFilter(*guild_ids, priority=4)) | ||
return self | ||
|
||
def channel(self, *channel_ids: str) -> Self: | ||
self.steps.append(ChannelFilter(*channel_ids, priority=5)) | ||
return self | ||
|
||
def self(self, *self_ids: str) -> Self: | ||
self.steps.append(SelfFilter(*self_ids, priority=3)) | ||
return self | ||
|
||
def platform(self, *platforms: str) -> Self: | ||
self.steps.append(PlatformFilter(*platforms, priority=2)) | ||
return self | ||
class Interval(JudgeAuxiliary): | ||
def __init__(self, interval: float, limit_prompt: Optional[Union[str, MessageChain]] = None): | ||
self.success = True | ||
self.last_time = None | ||
self.interval = interval | ||
self.limit_prompt = limit_prompt | ||
super().__init__(priority=20) | ||
|
||
@property | ||
def direct(self) -> Self: | ||
self.steps.append(DirectMessageJudger(priority=8)) | ||
return self | ||
|
||
private = direct | ||
def id(self): | ||
return "entari.filter/interval" | ||
|
||
@property | ||
def public(self) -> Self: | ||
self.steps.append(PublicMessageJudger(priority=8)) | ||
return self | ||
def scopes(self): | ||
return {Scope.prepare, Scope.cleanup} | ||
|
||
@property | ||
def reply_me(self) -> Self: | ||
self.steps.append(ReplyMeJudger(priority=9)) | ||
return self | ||
async def __call__(self, scope: Scope, interface: Interface) -> Optional[bool]: | ||
if scope == Scope.prepare: | ||
if not self.last_time: | ||
return True | ||
# if self.condition: | ||
# if not await self.condition(scope, interface): | ||
# self.success = False | ||
# return False | ||
self.success = (datetime.now() - self.last_time).total_seconds() > self.interval | ||
if not self.success: | ||
session = await interface.query(Session, "session", force_return=True) | ||
if session and self.limit_prompt: | ||
await session.send(self.limit_prompt) | ||
return self.success | ||
if self.success: | ||
self.last_time = datetime.now() | ||
return True | ||
|
||
|
||
class Semaphore(JudgeAuxiliary): | ||
def __init__(self, count: int, limit_prompt: Optional[Union[str, MessageChain]] = None): | ||
self.count = count | ||
self.limit_prompt = limit_prompt | ||
self.semaphore = asyncio.Semaphore(count) | ||
super().__init__(priority=20) | ||
|
||
@property | ||
def notice_me(self) -> Self: | ||
self.steps.append(NoticeMeJudger(priority=10)) | ||
return self | ||
def id(self): | ||
return "entari.filter/access" | ||
|
||
@property | ||
def to_me(self) -> Self: | ||
self.steps.append(ToMeJudger(priority=11)) | ||
return self | ||
|
||
def bind(self, func): | ||
return _bind(self)(func) | ||
|
||
def and_(self, other: Union["Filter", _SessionFilter]) -> "Filter": | ||
new = Filter(priority=self.priority) | ||
_other = other if isinstance(other, Filter) else Filter(callback=other) | ||
new.steps.append(IntersectFilter(self, _other, priority=1)) | ||
return new | ||
def scopes(self): | ||
return {Scope.prepare, Scope.cleanup} | ||
|
||
intersect = and_ | ||
|
||
def or_(self, other: Union["Filter", _SessionFilter]) -> "Filter": | ||
new = Filter(priority=self.priority) | ||
_other = other if isinstance(other, Filter) else Filter(callback=other) | ||
new.steps.append(UnionFilter(self, _other, priority=1)) | ||
return new | ||
|
||
union = or_ | ||
|
||
def not_(self, other: Union["Filter", _SessionFilter]) -> "Filter": | ||
new = Filter(priority=self.priority) | ||
_other = other if isinstance(other, Filter) else Filter(callback=other) | ||
new.steps.append(ExcludeFilter(self, _other, priority=1)) | ||
return new | ||
|
||
exclude = not_ | ||
async def __call__(self, scope: Scope, interface: Interface) -> Optional[bool]: | ||
if scope == Scope.prepare: | ||
if not await self.semaphore.acquire(): | ||
session = await interface.query(Session, "session", force_return=True) | ||
if session and self.limit_prompt: | ||
await session.send(self.limit_prompt) | ||
return False | ||
return True | ||
self.semaphore.release() | ||
return True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters