Skip to content

Commit

Permalink
reafactor: migrate from raw sql to orm (#24)
Browse files Browse the repository at this point in the history
* Migrate from raw sql (aiosqlite) to ORM (sqlalchemy)
* Adds autocomplete in extension manager command group 
* Fix music playlist autocomplete not sorting according to name
* Increase Mafic logging to warning
  • Loading branch information
KnownBlackHat authored Jul 20, 2024
1 parent c90f9aa commit 0813532
Show file tree
Hide file tree
Showing 31 changed files with 1,601 additions and 1,305 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ Servers.inf
lavalink_server/logs
ip.txt
*.db
/logs
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@ repos:
hooks:
- id: isort
exclude: '^(env)'
args: ["--profile", "black"]
2 changes: 1 addition & 1 deletion lavalink_application.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ plugins:

lavalink:
plugins:
- dependency: "dev.lavalink.youtube:youtube-plugin:1.3.0"
- dependency: "dev.lavalink.youtube:youtube-plugin:1.4.0"
snapshot: false
server:
sources:
Expand Down
77 changes: 29 additions & 48 deletions mr_robot/__main__.py
Original file line number Diff line number Diff line change
@@ -1,72 +1,51 @@
import asyncio
import atexit
import json
import logging.config
import logging.handlers
import os
import signal
import sys
from pathlib import Path

import aiosqlite
import disnake
import httpx
from dotenv import load_dotenv
from sqlalchemy import event
from sqlalchemy.engine import Engine

import mr_robot.log
from mr_robot.bot import MrRobot
from mr_robot.constants import Client
from mr_robot.constants import Client, Database

load_dotenv()


def setup_logging_modern() -> None:
with open(Client.logging_config_file, "r") as file:
config = json.load(file)
try:
os.mkdir("logs")
except FileExistsError:
...
logging.config.dictConfig(config)
queue_handler = logging.getHandlerByName("queue_handler")
if queue_handler is not None:
queue_handler.listener.start() # type: ignore[reportAttributeAccessIssue]
atexit.register(queue_handler.listener.stop) # type: ignore[reportAttributeAccessIssue]


def setup_logging() -> None:
os.makedirs("logs", exist_ok=True)
file_handler = logging.handlers.RotatingFileHandler(
Client.log_file_name, mode="a", maxBytes=(1000000 * 20), backupCount=5
)
console_handler = logging.StreamHandler()

file_handler.setLevel(logging.DEBUG)
console_handler.setLevel(logging.INFO)
logging.basicConfig(
level=logging.INFO,
format="[%(levelname)s|%(module)s|%(funcName)s|L%(lineno)d] %(asctime)s: %(message)s",
datefmt="%Y-%m-%dT%H:%M:%S%z",
handlers=[console_handler, file_handler],
)
logging.getLogger("httpx").setLevel(logging.WARNING)
logging.getLogger("httpcore").setLevel(logging.WARNING)
logging.getLogger("disnake").setLevel(logging.INFO)
logging.getLogger("aiosqlite").setLevel(logging.INFO)
logging.getLogger("streamlink").disabled = True
@event.listens_for(Engine, "connect")
def set_sqlite_pragma(dbapi_connection, _):
cursor = dbapi_connection.cursor()
cursor.execute("PRAGMA foreign_keys=ON")
cursor.close()


async def main():
setup_logging()
mr_robot.log.setup_logging()
logger = logging.getLogger(Client.name)
logger.info("Logger Initialized!")
async with httpx.AsyncClient(timeout=httpx.Timeout(None)) as session:
client = MrRobot(
intents=disnake.Intents.all(),
session=session,
db=await aiosqlite.connect(Client.db_name),
http_session=session,
)
await client.init_db()
if client.git:
logger.info("Pulling DB")
await client.git.pull(Client.db_name)
try:
if not Path(Database.db_name).exists():
logger.info("Pulling DB")
client.db_exsists = False
await client.git.pull(Database.db_name)
else:
logger.info("Db file found!")
client.db_exsists = True
except httpx.HTTPStatusError:
logger.warning(f"Failed to pull {Database.db_name} from github.")
except (httpx.ConnectError, httpx.ConnectTimeout):
logger.error("Failed to connect with github", exc_info=True)
await client.close()

try:
client.load_bot_extensions()
except Exception:
Expand All @@ -86,8 +65,10 @@ async def main():
except asyncio.CancelledError:
logger.info("Received signal to terminate bot and event loop")
finally:
logger.warning("Closing Client")
if not client.is_closed():
await client.close()
exit()


if __name__ == "__main__":
Expand Down
38 changes: 30 additions & 8 deletions mr_robot/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@
import httpx
import mafic
from aiocache import cached
from aiosqlite import Connection
from disnake.ext import commands
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine

from mr_robot.constants import Client, Lavalink
from mr_robot.constants import Client, Database, Lavalink
from mr_robot.database import Base
from mr_robot.utils.extensions import walk_extensions
from mr_robot.utils.git_api import Git

Expand All @@ -19,17 +20,20 @@
class MrRobot(commands.AutoShardedInteractionBot):
"""Mr Robot Bot"""

def __init__(self, session: httpx.AsyncClient, db: Connection, **kwargs):
def __init__(self, http_session: httpx.AsyncClient, **kwargs):
super().__init__(**kwargs)
self.pool = mafic.NodePool(self)
self.loop.create_task(self.add_nodes())
self.start_time = time.time()
self.session = session
self.db = db
self.http_session = http_session
self.token = Client.github_token
self.repo = Client.github_db_repo
self.git = None
self.db_name = Client.db_name
self.db_exsists = True
self.db_engine = create_async_engine(Database.uri)
self.db_session = async_sessionmaker(
self.db_engine, expire_on_commit=False, class_=AsyncSession
)
if self.token and self.repo:
owner, repo = self.repo.split("/")
self.git = Git(
Expand All @@ -38,13 +42,31 @@ def __init__(self, session: httpx.AsyncClient, db: Connection, **kwargs):
repo=repo,
username=Client.name,
email=f"{Client.name}@mr_robot_discord_bot.com",
client=session,
client=http_session,
)
logger.info("Mr Robot is ready")

@property
def db(self) -> async_sessionmaker[AsyncSession]:
"""Alias of bot.db_session"""
return self.db_session

async def init_db(self) -> None:
"""Initializes the database"""
async with self.db_engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)

async def close(self) -> None:
"""Close session when bot is shutting down"""
await super().close()
if self.db_engine:
await self.db_engine.dispose()
if self.http_session:
await self.http_session.aclose()

@cached(ttl=60 * 60 * 12)
async def _request(self, url: str) -> Dict | List:
resp = await self.session.get(url, headers={"User-Agent": "Magic Browser"})
resp = await self.http_session.get(url, headers={"User-Agent": "Magic Browser"})
logger.info(f"HTTP Get: {resp.status_code} {url}")
if resp.status_code == 200:
return resp.json()
Expand Down
5 changes: 3 additions & 2 deletions mr_robot/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@

def ensure_voice_connect() -> Callable:
def predicate(interaction: disnake.GuildCommandInteraction) -> bool:
if interaction.author.voice is None:
if interaction.author.voice is None or interaction.author.voice.channel is None:
raise commands.CommandError("You aren't connected to voice channel")
elif (
interaction.guild.voice_client is not None
and interaction.author.voice != interaction.guild.voice_client.channel.id
and interaction.author.voice.channel.id
!= interaction.guild.voice_client.channel.id
):
raise commands.CommandError("You must be in the same voice channel.")
return True
Expand Down
6 changes: 5 additions & 1 deletion mr_robot/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ class Client:
name = "Mr Robot"
token = getenv("BOT_TOKEN")
log_file_name = "logs/info.log"
db_name = "mr-robot.db"
github_db_repo = getenv("GIT_DB_REPO")
github_token = getenv("GIT_TOKEN")
github_bot_repo = "https://github.com/Mr-Robot-Discord-Bot/Mr-Robot/"
Expand Down Expand Up @@ -45,3 +44,8 @@ class Colors:
black = 0x000000
orange = 0xFFA500
yellow = 0xFFFF00


class Database:
db_name = "mr_robot.db"
uri = f"sqlite+aiosqlite:///{db_name}"
17 changes: 17 additions & 0 deletions mr_robot/database/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from .base import Base
from .greeter import Greeter
from .guild import Guild
from .music import Playlists, Tracks
from .temprole import TempRole
from .ticketsystem import Ticket, TicketConfig

__all__ = [
"Greeter",
"Guild",
"Base",
"TempRole",
"Playlists",
"Tracks",
"TicketConfig",
"Ticket",
]
6 changes: 6 additions & 0 deletions mr_robot/database/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from sqlalchemy.orm import DeclarativeBase


class Base(DeclarativeBase):
def __repr__(self) -> str:
return f"<{self.__class__.__name__}({', '.join([f"{col}={getattr(self, col, None)}" for col in self.__table__.columns.keys()])})"
44 changes: 44 additions & 0 deletions mr_robot/database/greeter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from typing import Optional

import sqlalchemy
from sqlalchemy.orm import Mapped, mapped_column, relationship

from .base import Base
from .guild import Guild


class Greeter(Base):
__tablename__ = "greeters"

id: Mapped[int] = mapped_column(
sqlalchemy.Integer, primary_key=True, autoincrement=False, index=True
)
guild_id: Mapped[int] = mapped_column(
sqlalchemy.ForeignKey("guilds.id", ondelete="CASCADE"), index=True
)
wlcm_channel: Mapped[Optional[int]] = mapped_column(
sqlalchemy.BigInteger, nullable=True
)
wlcm_image: Mapped[Optional[str]] = mapped_column(sqlalchemy.String, nullable=True)
wlcm_theme: Mapped[Optional[str]] = mapped_column(sqlalchemy.String, nullable=True)
wlcm_fontstyle: Mapped[Optional[str]] = mapped_column(
sqlalchemy.String, nullable=True
)
wlcm_outline: Mapped[Optional[int]] = mapped_column(
sqlalchemy.Integer, nullable=True
)
wlcm_msg: Mapped[Optional[str]] = mapped_column(sqlalchemy.String, nullable=True)
bye_channel: Mapped[Optional[int]] = mapped_column(
sqlalchemy.BigInteger, nullable=True
)
bye_image: Mapped[Optional[str]] = mapped_column(sqlalchemy.String, nullable=True)
bye_theme: Mapped[Optional[str]] = mapped_column(sqlalchemy.String, nullable=True)
bye_fontstyle: Mapped[Optional[str]] = mapped_column(
sqlalchemy.String, nullable=True
)
bye_outline: Mapped[Optional[int]] = mapped_column(
sqlalchemy.Integer, nullable=True
)
bye_msg: Mapped[Optional[str]] = mapped_column(sqlalchemy.String, nullable=True)

guild: Mapped[Guild] = relationship("Guild", passive_deletes=True)
13 changes: 13 additions & 0 deletions mr_robot/database/guild.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import sqlalchemy
from sqlalchemy.orm import Mapped, mapped_column

from .base import Base


class Guild(Base):
__tablename__ = "guilds"

id: Mapped[int] = mapped_column(
sqlalchemy.BigInteger, primary_key=True, autoincrement=False, index=True
)
name: Mapped[str] = mapped_column(sqlalchemy.String(length=50))
33 changes: 33 additions & 0 deletions mr_robot/database/music.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from __future__ import annotations

from typing import List

import sqlalchemy
from sqlalchemy.orm import Mapped, mapped_column, relationship

from .base import Base


class Playlists(Base):
__tablename__ = "playlists"

id: Mapped[int] = mapped_column(sqlalchemy.Integer, primary_key=True, index=True)
name: Mapped[str] = mapped_column(sqlalchemy.String(length=50))
user_id: Mapped[int] = mapped_column(sqlalchemy.BigInteger, index=True)
tracks: Mapped[List[Tracks]] = relationship(
"Tracks", back_populates="playlists", cascade="all, delete", lazy="selectin"
)


class Tracks(Base):
__tablename__ = "tracks"
id: Mapped[int] = mapped_column(sqlalchemy.Integer, primary_key=True, index=True)

playlist_id: Mapped[int] = mapped_column(
sqlalchemy.ForeignKey("playlists.id", ondelete="CASCADE"), index=True
)
track: Mapped[str] = mapped_column(sqlalchemy.Text, index=True)

playlists: Mapped[Playlists] = relationship(
"Playlists", back_populates="tracks", passive_deletes=True, lazy="selectin"
)
20 changes: 20 additions & 0 deletions mr_robot/database/temprole.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import sqlalchemy
from sqlalchemy.orm import Mapped, mapped_column, relationship

from .base import Base
from .guild import Guild


class TempRole(Base):
__tablename__ = "temprole"

id: Mapped[int] = mapped_column(sqlalchemy.Integer, primary_key=True, index=True)
guild_id: Mapped[int] = mapped_column(
sqlalchemy.ForeignKey("guilds.id", ondelete="CASCADE"), index=True
)
user_id: Mapped[int] = mapped_column(sqlalchemy.BigInteger, index=True)
role_id: Mapped[int] = mapped_column(sqlalchemy.BigInteger)
expiration: Mapped[str] = mapped_column(sqlalchemy.Text)

# Relationship
guild: Mapped[Guild] = relationship("Guild", passive_deletes=True)
Loading

0 comments on commit 0813532

Please sign in to comment.