Skip to content

Commit

Permalink
Merge pull request #150 from EduardSchwarzkopf/fix-isolate-db-sessions
Browse files Browse the repository at this point in the history
Fix isolate db sessions
  • Loading branch information
EduardSchwarzkopf authored Oct 8, 2024
2 parents 9fbb24e + a5694ce commit c6b1f14
Show file tree
Hide file tree
Showing 25 changed files with 279 additions and 400 deletions.
14 changes: 7 additions & 7 deletions app/authentication/dependencies.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,16 @@
from functools import lru_cache
from typing import AsyncGenerator

from fastapi import Depends
from fastapi_users.db import SQLAlchemyUserDatabase
from fastapi_users_db_sqlalchemy import SQLAlchemyUserDatabase

from app import models
from app.authentication.management import UserManager
from app.authentication.strategies import JWTAccessRefreshStrategy
from app.config import settings
from app.database import get_user_db
from app.database import SessionLocal


async def get_user_manager(
user_db: SQLAlchemyUserDatabase = Depends(get_user_db),
) -> AsyncGenerator[UserManager, None]:
async def get_user_manager() -> AsyncGenerator[UserManager, None]:
"""
Asynchronously yields a UserManager for the provided user.
Expand All @@ -26,7 +24,9 @@ async def get_user_manager(
UserInactive: If the user is inactive.
UserAlreadyVerified: If the user is already verified.
"""
yield UserManager(user_db)
async with SessionLocal() as session:
user_db = SQLAlchemyUserDatabase(session, models.User, models.OAuthAccount)
yield UserManager(user_db)


@lru_cache
Expand Down
90 changes: 3 additions & 87 deletions app/database.py
Original file line number Diff line number Diff line change
@@ -1,90 +1,6 @@
from typing import Optional

from fastapi_users.db import SQLAlchemyUserDatabase
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine

from app.config import settings


class Database:
def __init__(self, url: str) -> None:
"""
Initialize the database connection.
Args:
url (str): The URL of the database.
Raises:
None
"""

self.url = url
self.engine = create_async_engine(url, future=True)
self._session: Optional[AsyncSession] = None

async def init(self):
"""
Initialize the database connection.
Raises:
None
"""

self._session = self.get_session()

def get_session(self) -> AsyncSession:
"""
Get the asynchronous session for the database.
Returns:
AsyncSession: The asynchronous session for the database.
Raises:
RuntimeError: If the engine has not been initialized.
"""

if self.engine is None:
raise RuntimeError("Engine has not been initialized")

session_factory = sessionmaker(
self.engine, expire_on_commit=False, class_=AsyncSession
)

return session_factory()

@property
def session(self) -> AsyncSession:
"""
Get the asynchronous session for the database.
Returns:
AsyncSession: The asynchronous session for the database.
Raises:
RuntimeError: If the database session has not been initialized.
"""

if self._session is None:
raise RuntimeError("Database session has not been initialized.")
return self._session


async def get_user_db():
"""Get the user database.
Args:
None
Returns:
SQLAlchemyUserDatabase: The user database.
Raises:
None
"""
from app.models import OAuthAccount, User # pylint: disable=import-outside-toplevel

yield SQLAlchemyUserDatabase(db.session, User, OAuthAccount)


db: Database = Database(settings.db_url)
engine = create_async_engine(settings.db_url, future=True)
SessionLocal = async_sessionmaker(expire_on_commit=False, bind=engine)
4 changes: 0 additions & 4 deletions app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from app.authentication.dependencies import get_strategy
from app.authentication.strategies import JWTAccessRefreshStrategy
from app.config import settings
from app.database import db
from app.exception_handler import (
bad_request_exception_handler,
entity_access_denied_exception_handler,
Expand Down Expand Up @@ -69,13 +68,10 @@ async def lifespan(_api_app: FastAPI):
"""

try:
await db.init()
add_jobs_to_scheduler(scheduler)
yield
finally:
scheduler.shutdown()
if db.session is not None:
await db.session.close()


app = FastAPI(lifespan=lifespan)
Expand Down
63 changes: 30 additions & 33 deletions app/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,12 @@

from sqlalchemy import Select, exists, func, text
from sqlalchemy import update as sql_update
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import joinedload, selectinload
from sqlalchemy.orm.attributes import InstrumentedAttribute

from app import models
from app.database import db
from app.database import SessionLocal
from app.exceptions.base_service_exception import EntityNotFoundException
from app.utils.enums import DatabaseFilterOperator, Frequency
from app.utils.fields import IdField
Expand All @@ -18,9 +17,6 @@

class Repository:

def __init__(self, session: Optional[AsyncSession] = None):
self.session = session or db.session

def _load_relationships(
self, query: Select, relationships: InstrumentedAttribute = None
) -> Select:
Expand Down Expand Up @@ -54,7 +50,8 @@ async def get_all(
"""
q = select(cls)
q = self._load_relationships(q, load_relationships_list)
result = await self.session.execute(q)
async with SessionLocal() as session:
result = await session.execute(q)
return result.unique().scalars().all()

async def get(
Expand All @@ -77,7 +74,10 @@ async def get(
"""
q = select(cls).where(cls.id == instance_id)
q = self._load_relationships(q, load_relationships_list)
result = await self.session.execute(q)

async with SessionLocal() as session:
result = await session.execute(q)

model = result.scalars().first()

if model is None:
Expand Down Expand Up @@ -118,7 +118,8 @@ async def filter_by(
q = select(cls).where(condition).params(val=value)
q = self._load_relationships(q, load_relationships_list)

result = await self.session.execute(q)
async with SessionLocal() as session:
result = await session.execute(q)

return result.unique().scalars().all()

Expand Down Expand Up @@ -165,7 +166,8 @@ async def filter_by_multiple(
q = q.params(**params)
q = self._load_relationships(q, load_relationships_list)

result = await self.session.execute(q)
async with SessionLocal() as session:
result = await session.execute(q)

return result.scalars().unique().all()

Expand Down Expand Up @@ -220,7 +222,8 @@ def get_period_start_date(frequency_id: int) -> datetime:
transaction_exists_condition,
)

result = await self.session.execute(query)
async with SessionLocal() as session:
result = await session.execute(query)
return result.scalars().all()

async def get_transactions_from_period(
Expand Down Expand Up @@ -252,7 +255,8 @@ async def get_transactions_from_period(
.filter(wallet_id == transaction.wallet_id)
)

result = await self.session.execute(query)
async with SessionLocal() as session:
result = await session.execute(query)
return result.scalars().all()

async def save(self, obj: Union[ModelT, List[ModelT]]) -> None:
Expand All @@ -268,13 +272,14 @@ async def save(self, obj: Union[ModelT, List[ModelT]]) -> None:
None
"""

if isinstance(obj, list):
self.session.add_all(obj)
return
async with SessionLocal() as session, session.begin():
if isinstance(obj, list):
session.add_all(obj)
return

self.session.add(obj)
session.add(obj)

async def update(self, cls: Type[ModelT], instance_id: int, **kwargs) -> None:
async def update(self, cls: Type[ModelT], instance_id: int, **kwargs):
"""Update an instance of the specified model with the given ID.
Args:
Expand All @@ -294,29 +299,19 @@ async def update(self, cls: Type[ModelT], instance_id: int, **kwargs) -> None:
.values(**kwargs)
.execution_options(synchronize_session="fetch")
)
await self.session.execute(query)
async with SessionLocal() as session:
await session.execute(query)

async def delete(self, obj: Type[ModelT]) -> None:
"""Delete an object from the database.
Args:
obj: The object to delete.
Returns:
None
Raises:
None
"""

# TODO: Test this
# if isinstance(obj, list):
# for object in obj:
# self.session.delete(object)
# return

# TODO: Await needed?
await self.session.delete(obj)
async with SessionLocal() as session, session.begin():
await session.delete(obj)

async def refresh(self, obj: Type[ModelT]) -> None:
"""Refresh the state of an object from the database.
Expand All @@ -330,7 +325,8 @@ async def refresh(self, obj: Type[ModelT]) -> None:
Raises:
None
"""
return await self.session.refresh(obj)
async with SessionLocal() as session:
return await session.refresh(obj)

async def refresh_all(self, object_list: List[ModelT]) -> None:
"""Refresh the state of multiple objects from the database.
Expand All @@ -344,5 +340,6 @@ async def refresh_all(self, object_list: List[ModelT]) -> None:
Raises:
None
"""
for obj in object_list:
await self.session.refresh(obj)
async with SessionLocal() as session:
for obj in object_list:
await session.refresh(obj)
8 changes: 3 additions & 5 deletions app/routers/api/scheduled_transactions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from fastapi import Depends, status

from app import schemas
from app import session_transaction_manager as tm
from app.models import User
from app.routers.api.users import current_active_verified_user
from app.services.scheduled_transactions import ScheduledTransactionService
Expand Down Expand Up @@ -89,10 +88,9 @@ async def api_create_scheduled_transaction(
HTTPException: If the scheduled transaction is not created.
"""

async with tm.transaction():
return await service.create_scheduled_transaction(
current_user, transaction_information
)
return await service.create_scheduled_transaction(
current_user, transaction_information
)


@router.post("/{transaction_id}", response_model=ResponseModel)
Expand Down
20 changes: 7 additions & 13 deletions app/routers/api/transactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from fastapi import Depends, status

from app import schemas
from app import session_transaction_manager as tm
from app.models import User
from app.routers.api.users import current_active_verified_user
from app.services.transactions import TransactionService
Expand Down Expand Up @@ -90,8 +89,7 @@ async def api_create_transaction(

transaction_data = schemas.TransactionData(**transaction_information.model_dump())

async with tm.transaction():
return await service.create_transaction(current_user, transaction_data)
return await service.create_transaction(current_user, transaction_data)


@router.post("/{transaction_id}", response_model=schemas.TransactionResponse)
Expand All @@ -116,14 +114,11 @@ async def api_update_transaction(
HTTPException: If the transaction is not found.
"""

async with tm.transaction():
transaction = await service.update_transaction(
current_user,
transaction_id,
transaction_information,
)

return transaction
return await service.update_transaction(
current_user,
transaction_id,
transaction_information,
)


@router.delete("/{transaction_id}", status_code=status.HTTP_204_NO_CONTENT)
Expand All @@ -146,5 +141,4 @@ async def api_delete_transaction(
HTTPException: If the transaction is not found.
"""

async with tm.transaction():
return await service.delete_transaction(current_user, transaction_id)
return await service.delete_transaction(current_user, transaction_id)
Loading

0 comments on commit c6b1f14

Please sign in to comment.