Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix isolate db sessions #150

Merged
merged 8 commits into from
Oct 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions app/authentication/dependencies.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,16 @@
from functools import lru_cache
from typing import AsyncGenerator

from fastapi import Depends
from fastapi_users.db import SQLAlchemyUserDatabase
from fastapi_users_db_sqlalchemy import SQLAlchemyUserDatabase

from app import models
from app.authentication.management import UserManager
from app.authentication.strategies import JWTAccessRefreshStrategy
from app.config import settings
from app.database import get_user_db
from app.database import SessionLocal


async def get_user_manager(
user_db: SQLAlchemyUserDatabase = Depends(get_user_db),
) -> AsyncGenerator[UserManager, None]:
async def get_user_manager() -> AsyncGenerator[UserManager, None]:
"""
Asynchronously yields a UserManager for the provided user.

Expand All @@ -26,7 +24,9 @@ async def get_user_manager(
UserInactive: If the user is inactive.
UserAlreadyVerified: If the user is already verified.
"""
yield UserManager(user_db)
async with SessionLocal() as session:
user_db = SQLAlchemyUserDatabase(session, models.User, models.OAuthAccount)
yield UserManager(user_db)


@lru_cache
Expand Down
90 changes: 3 additions & 87 deletions app/database.py
Original file line number Diff line number Diff line change
@@ -1,90 +1,6 @@
from typing import Optional

from fastapi_users.db import SQLAlchemyUserDatabase
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine

from app.config import settings


class Database:
def __init__(self, url: str) -> None:
"""
Initialize the database connection.

Args:
url (str): The URL of the database.

Raises:
None
"""

self.url = url
self.engine = create_async_engine(url, future=True)
self._session: Optional[AsyncSession] = None

async def init(self):
"""
Initialize the database connection.

Raises:
None
"""

self._session = self.get_session()

def get_session(self) -> AsyncSession:
"""
Get the asynchronous session for the database.

Returns:
AsyncSession: The asynchronous session for the database.

Raises:
RuntimeError: If the engine has not been initialized.
"""

if self.engine is None:
raise RuntimeError("Engine has not been initialized")

session_factory = sessionmaker(
self.engine, expire_on_commit=False, class_=AsyncSession
)

return session_factory()

@property
def session(self) -> AsyncSession:
"""
Get the asynchronous session for the database.

Returns:
AsyncSession: The asynchronous session for the database.

Raises:
RuntimeError: If the database session has not been initialized.
"""

if self._session is None:
raise RuntimeError("Database session has not been initialized.")
return self._session


async def get_user_db():
"""Get the user database.

Args:
None

Returns:
SQLAlchemyUserDatabase: The user database.

Raises:
None
"""
from app.models import OAuthAccount, User # pylint: disable=import-outside-toplevel

yield SQLAlchemyUserDatabase(db.session, User, OAuthAccount)


db: Database = Database(settings.db_url)
engine = create_async_engine(settings.db_url, future=True)
SessionLocal = async_sessionmaker(expire_on_commit=False, bind=engine)
4 changes: 0 additions & 4 deletions app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from app.authentication.dependencies import get_strategy
from app.authentication.strategies import JWTAccessRefreshStrategy
from app.config import settings
from app.database import db
from app.exception_handler import (
bad_request_exception_handler,
entity_access_denied_exception_handler,
Expand Down Expand Up @@ -69,13 +68,10 @@ async def lifespan(_api_app: FastAPI):
"""

try:
await db.init()
add_jobs_to_scheduler(scheduler)
yield
finally:
scheduler.shutdown()
if db.session is not None:
await db.session.close()


app = FastAPI(lifespan=lifespan)
Expand Down
63 changes: 30 additions & 33 deletions app/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,12 @@

from sqlalchemy import Select, exists, func, text
from sqlalchemy import update as sql_update
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import joinedload, selectinload
from sqlalchemy.orm.attributes import InstrumentedAttribute

from app import models
from app.database import db
from app.database import SessionLocal
from app.exceptions.base_service_exception import EntityNotFoundException
from app.utils.enums import DatabaseFilterOperator, Frequency
from app.utils.fields import IdField
Expand All @@ -18,9 +17,6 @@

class Repository:

def __init__(self, session: Optional[AsyncSession] = None):
self.session = session or db.session

def _load_relationships(
self, query: Select, relationships: InstrumentedAttribute = None
) -> Select:
Expand Down Expand Up @@ -54,7 +50,8 @@ async def get_all(
"""
q = select(cls)
q = self._load_relationships(q, load_relationships_list)
result = await self.session.execute(q)
async with SessionLocal() as session:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue (complexity): Consider implementing a helper method for session management.

While the changes introduce more explicit session management, they also add unnecessary repetition. Consider introducing a helper method to encapsulate the session creation and execution pattern. This will maintain explicit control over sessions while reducing code duplication. Here's an example:

from contextlib import asynccontextmanager

class Repository:
    @asynccontextmanager
    async def session_scope(self):
        async with SessionLocal() as session:
            try:
                yield session
                await session.commit()
            except Exception:
                await session.rollback()
                raise

    async def get_all(self, cls: Type[ModelT], load_relationships_list: Optional[list[InstrumentedAttribute]] = None) -> list[ModelT]:
        q = select(cls)
        q = self._load_relationships(q, load_relationships_list)
        async with self.session_scope() as session:
            result = await session.execute(q)
        return result.unique().scalars().all()

    # Apply similar pattern to other methods

This approach:

  1. Centralizes session management logic
  2. Ensures consistent error handling and session cleanup
  3. Reduces code duplication
  4. Maintains explicit session control
  5. Improves readability and maintainability

Apply this pattern across all methods in the Repository class to significantly reduce complexity while preserving the benefits of explicit session management.

result = await session.execute(q)
return result.unique().scalars().all()

async def get(
Expand All @@ -77,7 +74,10 @@ async def get(
"""
q = select(cls).where(cls.id == instance_id)
q = self._load_relationships(q, load_relationships_list)
result = await self.session.execute(q)

async with SessionLocal() as session:
result = await session.execute(q)

model = result.scalars().first()

if model is None:
Expand Down Expand Up @@ -118,7 +118,8 @@ async def filter_by(
q = select(cls).where(condition).params(val=value)
q = self._load_relationships(q, load_relationships_list)

result = await self.session.execute(q)
async with SessionLocal() as session:
result = await session.execute(q)

return result.unique().scalars().all()

Expand Down Expand Up @@ -165,7 +166,8 @@ async def filter_by_multiple(
q = q.params(**params)
q = self._load_relationships(q, load_relationships_list)

result = await self.session.execute(q)
async with SessionLocal() as session:
result = await session.execute(q)

return result.scalars().unique().all()

Expand Down Expand Up @@ -220,7 +222,8 @@ def get_period_start_date(frequency_id: int) -> datetime:
transaction_exists_condition,
)

result = await self.session.execute(query)
async with SessionLocal() as session:
result = await session.execute(query)
return result.scalars().all()

async def get_transactions_from_period(
Expand Down Expand Up @@ -252,7 +255,8 @@ async def get_transactions_from_period(
.filter(wallet_id == transaction.wallet_id)
)

result = await self.session.execute(query)
async with SessionLocal() as session:
result = await session.execute(query)
return result.scalars().all()

async def save(self, obj: Union[ModelT, List[ModelT]]) -> None:
Expand All @@ -268,13 +272,14 @@ async def save(self, obj: Union[ModelT, List[ModelT]]) -> None:
None
"""

if isinstance(obj, list):
self.session.add_all(obj)
return
async with SessionLocal() as session, session.begin():
if isinstance(obj, list):
session.add_all(obj)
return

self.session.add(obj)
session.add(obj)

async def update(self, cls: Type[ModelT], instance_id: int, **kwargs) -> None:
async def update(self, cls: Type[ModelT], instance_id: int, **kwargs):
"""Update an instance of the specified model with the given ID.

Args:
Expand All @@ -294,29 +299,19 @@ async def update(self, cls: Type[ModelT], instance_id: int, **kwargs) -> None:
.values(**kwargs)
.execution_options(synchronize_session="fetch")
)
await self.session.execute(query)
async with SessionLocal() as session:
await session.execute(query)

async def delete(self, obj: Type[ModelT]) -> None:
"""Delete an object from the database.

Args:
obj: The object to delete.

Returns:
None

Raises:
None
"""

# TODO: Test this
# if isinstance(obj, list):
# for object in obj:
# self.session.delete(object)
# return

# TODO: Await needed?
await self.session.delete(obj)
async with SessionLocal() as session, session.begin():
await session.delete(obj)

async def refresh(self, obj: Type[ModelT]) -> None:
"""Refresh the state of an object from the database.
Expand All @@ -330,7 +325,8 @@ async def refresh(self, obj: Type[ModelT]) -> None:
Raises:
None
"""
return await self.session.refresh(obj)
async with SessionLocal() as session:
return await session.refresh(obj)

async def refresh_all(self, object_list: List[ModelT]) -> None:
"""Refresh the state of multiple objects from the database.
Expand All @@ -344,5 +340,6 @@ async def refresh_all(self, object_list: List[ModelT]) -> None:
Raises:
None
"""
for obj in object_list:
await self.session.refresh(obj)
async with SessionLocal() as session:
for obj in object_list:
await session.refresh(obj)
8 changes: 3 additions & 5 deletions app/routers/api/scheduled_transactions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from fastapi import Depends, status

from app import schemas
from app import session_transaction_manager as tm
from app.models import User
from app.routers.api.users import current_active_verified_user
from app.services.scheduled_transactions import ScheduledTransactionService
Expand Down Expand Up @@ -89,10 +88,9 @@ async def api_create_scheduled_transaction(
HTTPException: If the scheduled transaction is not created.
"""

async with tm.transaction():
return await service.create_scheduled_transaction(
current_user, transaction_information
)
return await service.create_scheduled_transaction(
current_user, transaction_information
)


@router.post("/{transaction_id}", response_model=ResponseModel)
Expand Down
20 changes: 7 additions & 13 deletions app/routers/api/transactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from fastapi import Depends, status

from app import schemas
from app import session_transaction_manager as tm
from app.models import User
from app.routers.api.users import current_active_verified_user
from app.services.transactions import TransactionService
Expand Down Expand Up @@ -90,8 +89,7 @@ async def api_create_transaction(

transaction_data = schemas.TransactionData(**transaction_information.model_dump())

async with tm.transaction():
return await service.create_transaction(current_user, transaction_data)
return await service.create_transaction(current_user, transaction_data)


@router.post("/{transaction_id}", response_model=schemas.TransactionResponse)
Expand All @@ -116,14 +114,11 @@ async def api_update_transaction(
HTTPException: If the transaction is not found.
"""

async with tm.transaction():
transaction = await service.update_transaction(
current_user,
transaction_id,
transaction_information,
)

return transaction
return await service.update_transaction(
current_user,
transaction_id,
transaction_information,
)


@router.delete("/{transaction_id}", status_code=status.HTTP_204_NO_CONTENT)
Expand All @@ -146,5 +141,4 @@ async def api_delete_transaction(
HTTPException: If the transaction is not found.
"""

async with tm.transaction():
return await service.delete_transaction(current_user, transaction_id)
return await service.delete_transaction(current_user, transaction_id)
Loading