From 4b2c3dff09cfb3fca504be7b5a55d85b7640e997 Mon Sep 17 00:00:00 2001 From: robbertuittenbroek Date: Fri, 28 Jun 2024 12:40:44 +0200 Subject: [PATCH] Add more demo objects to the demo suite --- BUILD.md | 7 ++ tad/core/config.py | 1 + tad/core/db.py | 65 ++++++++++---- ...468e4_create_status_user_and_task_table.py | 13 +-- tests/constants.py | 16 ++++ tests/core/test_db.py | 84 ++++++++++++++----- 6 files changed, 138 insertions(+), 48 deletions(-) diff --git a/BUILD.md b/BUILD.md index 5d5420684..5a0152368 100644 --- a/BUILD.md +++ b/BUILD.md @@ -23,6 +23,13 @@ When poetry is done installing all dependencies you can start using the tool. poetry run python -m uvicorn tad.main:app --log-level warning ``` +### Suggested development ENVIRONMENT settings +To use a demo environment during local development, you can use the following environment options. You can leave out the TRUNCATE_TABLES option if you wish to keep the state between runs. +```shell +export ENVIRONMENT=demo AUTO_CREATE_SCHEMA=true TRUNCATE_TABLES=true +``` + + ## Database We support most SQL database types. You can use the variable `APP_DATABASE_SCHEME` to change the database. The default scheme is sqlite. diff --git a/tad/core/config.py b/tad/core/config.py index 71b802014..4b1113b8d 100644 --- a/tad/core/config.py +++ b/tad/core/config.py @@ -33,6 +33,7 @@ class Settings(BaseSettings): DEBUG: bool = False AUTO_CREATE_SCHEMA: bool = False + TRUNCATE_TABLES: bool = False # todo(berry): create submodel for database settings APP_DATABASE_SCHEME: DatabaseSchemaType = "sqlite" diff --git a/tad/core/db.py b/tad/core/db.py index e07be8b13..68129038a 100644 --- a/tad/core/db.py +++ b/tad/core/db.py @@ -3,7 +3,7 @@ from sqlalchemy.engine import Engine from sqlalchemy.pool import QueuePool, StaticPool -from sqlmodel import Session, SQLModel, create_engine, select +from sqlmodel import Session, SQLModel, create_engine, delete, select from tad.core.config import get_settings from tad.models import Status, Task, User @@ -43,21 +43,52 @@ def init_db(): with Session(get_engine()) as session: if get_settings().ENVIRONMENT == "demo": - logger.info("Creating demo data") + if get_settings().TRUNCATE_TABLES: + truncate_tables(session) - user = session.exec(select(User).where(User.name == "Robbert")).first() - if not user: - user = User(name="Robbert", avatar=None) - session.add(user) - - status = session.exec(select(Status).where(Status.name == "Todo")).first() - if not status: - status = Status(name="Todo", sort_order=1) - session.add(status) - - task = session.exec(select(Task).where(Task.title == "First task")).first() - if not task: - task = Task(title="First task", description="This is the first task", sort_order=1, status_id=status.id) - session.add(task) - session.commit() + logger.info("Creating demo data") + add_demo_users(session, ["default user"]) + add_demo_statuses(session, ["todo", "review", "in_progress", "done"]) + todo_status = session.exec(select(Status).where(Status.name == "todo")).first() + if todo_status is not None: + add_demo_tasks(session, todo_status, 3) logger.info("Finished initializing database") + + +def truncate_tables(session: Session) -> None: + logger.info("Truncating tables") + session.exec(delete(Task)) # type: ignore + session.exec(delete(User)) # type: ignore + session.exec(delete(Status)) # type: ignore + + +def add_demo_users(session: Session, user_names: list[str]) -> None: + for user_name in user_names: + user = session.exec(select(User).where(User.name == user_name)).first() + if not user: + session.add(User(name=user_name, avatar=None)) + session.commit() + + +def add_demo_tasks(session: Session, status: Status, number_of_tasks: int) -> None: + for index in range(1, number_of_tasks + 1): + title = "Example task " + str(index) + task = session.exec(select(Task).where(Task.title == title)).first() + if not task: + session.add( + Task( + title=title, + description="Example description " + str(index), + sort_order=index, + status_id=status.id, + ) + ) + session.commit() + + +def add_demo_statuses(session: Session, statuses: list[str]) -> None: + for index, status_name in enumerate(statuses): + status = session.exec(select(Status).where(Status.name == status_name)).first() + if not status: + session.add(Status(name=status_name, sort_order=index + 1)) + session.commit() diff --git a/tad/migrations/versions/b62dbd9468e4_create_status_user_and_task_table.py b/tad/migrations/versions/b62dbd9468e4_create_status_user_and_task_table.py index d5912d550..ab57151a7 100644 --- a/tad/migrations/versions/b62dbd9468e4_create_status_user_and_task_table.py +++ b/tad/migrations/versions/b62dbd9468e4_create_status_user_and_task_table.py @@ -21,7 +21,7 @@ def upgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### - status = op.create_table( + op.create_table( "status", sa.Column("id", sa.Integer(), nullable=False), sa.Column("name", sqlmodel.sql.sqltypes.AutoString(), nullable=False), @@ -56,17 +56,6 @@ def upgrade() -> None: ) # ### end Alembic commands ### - # ### custom commands ### - op.bulk_insert( - status, - [ - {"name": "Todo", "sort_order": 1}, - {"name": "In Progress", "sort_order": 2}, - {"name": "Review", "sort_order": 3}, - {"name": "Done", "sort_order": 4}, - ], - ) - def downgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### diff --git a/tests/constants.py b/tests/constants.py index 69ef56d1b..6e6c786c1 100644 --- a/tests/constants.py +++ b/tests/constants.py @@ -1,3 +1,5 @@ +from sqlmodel import select +from sqlmodel.sql._expression_select_cls import SelectOfScalar from tad.models import Status, Task, User @@ -37,3 +39,17 @@ def default_task( user_id: int | None = None, ) -> Task: return Task(title=title, description=description, sort_order=sort_order, status_id=status_id, user_id=user_id) + + +def expected_selects_demo_suite() -> list[SelectOfScalar[User] | SelectOfScalar[Status] | SelectOfScalar[Task]]: + return [ + select(User).where(User.name == "Robbert"), + select(Status).where(Status.name == "todo"), + select(Status).where(Status.name == "in_progress"), + select(Status).where(Status.name == "review"), + select(Status).where(Status.name == "done"), + select(Status).where(Status.name == "done"), + select(Task).where(Task.title == "Test task 1"), + select(Task).where(Task.title == "Test task 2"), + select(Task).where(Task.title == "Test task 3"), + ] diff --git a/tests/core/test_db.py b/tests/core/test_db.py index dbc7a8a99..5dc82bde6 100644 --- a/tests/core/test_db.py +++ b/tests/core/test_db.py @@ -2,11 +2,14 @@ from unittest.mock import MagicMock import pytest -from sqlmodel import Session, select +from sqlmodel import Session, delete, select +from sqlmodel.sql._expression_select_cls import SelectOfScalar from tad.core.config import Settings from tad.core.db import check_db, init_db from tad.models import Status, Task, User +from tests.constants import expected_selects_demo_suite + logger = logging.getLogger(__name__) @@ -21,47 +24,90 @@ def test_check_database(): @pytest.mark.parametrize( - "patch_settings", - [{"ENVIRONMENT": "demo", "AUTO_CREATE_SCHEMA": True}], - indirect=True, + ("patch_settings", "expected_selects"), + [({"ENVIRONMENT": "demo", "AUTO_CREATE_SCHEMA": True}, expected_selects_demo_suite())], + indirect=["patch_settings"], ) -def test_init_database_none(patch_settings: Settings): +def test_init_database_none( + patch_settings: Settings, + expected_selects: list[SelectOfScalar[User] | SelectOfScalar[Status] | SelectOfScalar[Task]], +): org_exec = Session.exec Session.exec = MagicMock() Session.exec.return_value.first.return_value = None init_db() - expected = [ - (select(User).where(User.name == "Robbert"),), - (select(Status).where(Status.name == "Todo"),), - (select(Task).where(Task.title == "First task"),), - ] + for i, call_args in enumerate(Session.exec.call_args_list): + assert str(expected_selects[i]) == str(call_args.args[0]) + + Session.exec = org_exec + + +@pytest.mark.parametrize( + ("patch_settings", "expected_selects"), + [({"ENVIRONMENT": "demo", "AUTO_CREATE_SCHEMA": True}, expected_selects_demo_suite())], + indirect=["patch_settings"], +) +def test_init_database_none_with_todo_status( + patch_settings: Settings, + expected_selects: list[SelectOfScalar[User] | SelectOfScalar[Status] | SelectOfScalar[Task]], +): + org_exec = Session.exec + Session.exec = MagicMock() + todo_status = Status(id=1, name="todo", sort_order=1) + Session.exec.return_value.first.side_effect = [None, None, None, None, None, todo_status, None, None, None] + + init_db() + + for i, call_args in enumerate(Session.exec.call_args_list): + assert str(expected_selects[i]) == str(call_args.args[0]) + + Session.exec = org_exec + + +@pytest.mark.parametrize( + ("patch_settings", "expected_selects"), + [({"ENVIRONMENT": "demo", "AUTO_CREATE_SCHEMA": True}, expected_selects_demo_suite())], + indirect=["patch_settings"], +) +def test_init_database( + patch_settings: Settings, + expected_selects: list[SelectOfScalar[User] | SelectOfScalar[Status] | SelectOfScalar[Task]], +): + org_exec = Session.exec + Session.exec = MagicMock() + + init_db() for i, call_args in enumerate(Session.exec.call_args_list): - assert str(expected[i][0]) == str(call_args.args[0]) + assert str(expected_selects[i]) == str(call_args.args[0]) Session.exec = org_exec @pytest.mark.parametrize( - "patch_settings", - [{"ENVIRONMENT": "demo", "AUTO_CREATE_SCHEMA": True}], - indirect=True, + ("patch_settings", "expected_selects"), + [({"ENVIRONMENT": "demo", "AUTO_CREATE_SCHEMA": True, "TRUNCATE_TABLES": True}, expected_selects_demo_suite())], + indirect=["patch_settings"], ) -def test_init_database(patch_settings: Settings): +def test_truncate_database( + patch_settings: Settings, + expected_selects: list[SelectOfScalar[User] | SelectOfScalar[Status] | SelectOfScalar[Task]], +): org_exec = Session.exec Session.exec = MagicMock() init_db() expected = [ - (select(User).where(User.name == "Robbert"),), - (select(Status).where(Status.name == "Todo"),), - (select(Task).where(Task.title == "First task"),), + delete(Task), + delete(User), + delete(Status), + *expected_selects, ] for i, call_args in enumerate(Session.exec.call_args_list): - assert str(expected[i][0]) == str(call_args.args[0]) + assert str(expected[i]) == str(call_args.args[0]) Session.exec = org_exec