diff --git a/bot.py b/bot.py index dc2ce91..fafdd97 100644 --- a/bot.py +++ b/bot.py @@ -2,7 +2,7 @@ import discord from activities import activity from discord.ext import tasks -from modules.config import fetch_config +from modules.config import fetch_configs CHANGE_STATUS_INTERVAL_HOURS = 1 @@ -15,7 +15,7 @@ def __init__(self, is_prod: bool, *args, **kwargs) -> None: movie_activity = activity.get_random_activity_as_discordpy_activity() # config self.is_prod = is_prod - self.guilds_dict = fetch_config(is_prod) + self.guilds_dict = fetch_configs(is_prod) super().__init__(intents=intents, activity=movie_activity, *args, **kwargs) @tasks.loop(hours=CHANGE_STATUS_INTERVAL_HOURS) @@ -32,7 +32,7 @@ async def setup_hook(self) -> None: self.set_activity.start() def fetch_config(self): - self.guilds_dict = fetch_config(self.is_prod) + self.guilds_dict = fetch_configs(self.is_prod) def get_bot(is_prod: bool): diff --git a/main.py b/main.py index d91b3b7..8b3318a 100644 --- a/main.py +++ b/main.py @@ -1,7 +1,6 @@ import os import discord -from discord import app_commands # from discord.ext import commands from dotenv import load_dotenv @@ -37,7 +36,6 @@ from mentioned import mention_responses -from modules.supabase import supabaseClient from modules import config import bot import sys @@ -55,7 +53,7 @@ class BotInitialiser: def __init__(self): self.client = bot.get_bot(IS_PROD) - self.guilds = [discord.Object(id=int(server_id)) for server_id in self.client.guilds_dict.keys()] + self.guilds = [discord.Object(id=guild_id) for guild_id, guild_config in self.client.guilds_dict.items() if guild_config.get('prod_config') == IS_PROD] # CommandTree is where all our defined commands are stored self.tree = discord.app_commands.CommandTree(self.client) self.placetw_guild = discord.Object(id=os.getenv("PLACETW_SERVER_ID")) # basically refers to this server @@ -117,8 +115,10 @@ def register_event_callbacks(self): # sync the slash commands servers when the bot is ready @self.client.event async def on_ready(): - for guild_id in self.client.guilds_dict.keys(): - guild = discord.Object(id=guild_id) + self.tree.clear_commands(guild=None) + await self.tree.sync() + + for guild in self.guilds: await self.tree.sync(guild=guild) # Enable logging logging.init(self.client, DEPLOYMENT_DATE) @@ -151,15 +151,17 @@ async def on_message(message: discord.Message): await logging.log_message_event(message, events) @self.client.event - async def on_guild_join(guild): - supabaseClient.table("server_config").insert( - { - "guild_id": str(guild.id), - "server_name": guild.name, - } - ).execute() - self.register_commands(self.tree, self.client, [guild]) - await self.tree.sync(guild=guild) + async def on_guild_join(guild: discord.Guild): + print(f"Guild {guild.name} ({guild.id}) joined") + config.create_new_config(guild.id, guild.name, IS_PROD) + await self.tree.sync(guild=guild) + + @self.client.event + async def on_guild_remove(guild: discord.Guild): + print(f"Guild {guild.name} ({guild.id}) removed") + self.guilds.remove(discord.Object(id=guild.id)) + del self.client.guilds_dict[guild.id] + config.remove_config(guild.id, IS_PROD) def run(self): self.client.run(TOKEN) diff --git a/modules/config.py b/modules/config.py index c57b01e..ec90fec 100644 --- a/modules/config.py +++ b/modules/config.py @@ -1,7 +1,7 @@ from modules.supabase import supabaseClient -def fetch_config(is_prod: bool): +def fetch_configs(is_prod: bool): supabase_data = supabaseClient.table("server_config").select("*").eq('prod_config', is_prod).execute().data transformed_dict = { item['guild_id']: {key: value for key, value in item.items() if key != 'guild_id'} for item in supabase_data @@ -9,7 +9,44 @@ def fetch_config(is_prod: bool): return transformed_dict -def set_config(guild_id: int, key: str, value: str, is_prod: bool): - supabaseClient.table("server_config").update({key: value}).eq("guild_id", guild_id).eq( - "prod_config", is_prod - ).execute() +def fetch_guild_config(guild_id: int, is_prod: bool): + response = ( + supabaseClient.table("server_config").select("*").eq("guild_id", guild_id).eq("prod_config", is_prod).execute() + ) + # if there is no data, return an empty dict + if len(response.data) == 0: + return {} + return response.data[0] + + +def create_new_config(guild_id: int, server_name: str, is_prod: bool): + response = ( + supabaseClient.table("server_config") + .insert( + { + "guild_id": guild_id, + "server_name": server_name, + "prod_config": is_prod, + } + ) + .execute() + ) + return response.data + + +def set_config(guild_id: int, key: str, value: str, is_prod: bool) -> list: + response = ( + supabaseClient.table("server_config") + .update({key: value}) + .eq("guild_id", guild_id) + .eq("prod_config", is_prod) + .execute() + ) + return response.data + + +def remove_config(guild_id: int, is_prod: bool) -> list: + response = ( + supabaseClient.table("server_config").delete().eq("guild_id", guild_id).eq("prod_config", is_prod).execute() + ) + return response.data diff --git a/tests/test_deploy.py b/tests/test_deploy.py index f7bdba8..c8b1e0c 100644 --- a/tests/test_deploy.py +++ b/tests/test_deploy.py @@ -21,11 +21,11 @@ def mock_create_client(url: str, private_key: str): # patch the modules.config module, which uses supabaseClient import modules.config - # patch the module.config.fetch_config to return a dict instead of fetching from supabase - def mock_fetch_config(*args, **kwargs): + # patch the module.config.fetch_configs to return a dict instead of fetching from supabase + def mock_fetch_configs(*args, **kwargs): return {0: {"key": "value"}} - monkeypatch.setattr(modules.config, "fetch_config", mock_fetch_config) + monkeypatch.setattr(modules.config, "fetch_configs", mock_fetch_configs) # patch the module.config.set_config to do nothing instead of setting the config in supabase def mock_set_config(*args, **kwargs):