From bee577f5a37400a12c69562879d6159202861ed1 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Tue, 21 Jan 2025 15:16:29 +0000 Subject: [PATCH] feat: wip --- advanced_alchemy/filters.py | 95 +++++++++++++++++++++++++++---- tests/integration/test_filters.py | 13 +++-- 2 files changed, 92 insertions(+), 16 deletions(-) diff --git a/advanced_alchemy/filters.py b/advanced_alchemy/filters.py index 5f9187ff..6c086412 100644 --- a/advanced_alchemy/filters.py +++ b/advanced_alchemy/filters.py @@ -41,7 +41,20 @@ from operator import attrgetter from typing import TYPE_CHECKING, Any, Generic, Literal, cast -from sqlalchemy import BinaryExpression, Delete, Select, Update, and_, any_, exists, or_, select, text +from sqlalchemy import ( + BinaryExpression, + Delete, + Select, + Update, + and_, + any_, + exists, + false, + not_, + or_, + select, + text, +) from typing_extensions import TypeVar if TYPE_CHECKING: @@ -615,6 +628,8 @@ class ExistsFilter(StatementFilter): ) """ + field_name: str + """Name of model attribute to search on.""" values: list[ColumnElement[bool]] """List of SQLAlchemy column expressions to use in the EXISTS clause.""" operator: Literal["and", "or"] = "and" @@ -644,6 +659,35 @@ def _or(self) -> Callable[..., ColumnElement[bool]]: """ return or_ + def get_exists_clause(self, model: type[ModelT]) -> ColumnElement[bool]: + """Generate the EXISTS clause for the statement. + + Args: + model: The SQLAlchemy model class + + Returns: + ColumnElement[bool]: EXISTS clause + """ + field = self._get_instrumented_attr(model, self.field_name) + + # Get the underlying column name of the field + field_column = getattr(field, "comparator", None) + if not field_column: + return false() # Handle cases where the field might not be directly comparable, ie. relations + field_column_name = field_column.key + + # Construct a subquery using select() + subquery = select(field).where( + *( + [getattr(model, field_column_name) == getattr(model, field_column_name), self._and(*self.values)] + if self.operator == "and" + else [getattr(model, field_column_name) == getattr(model, field_column_name), self._or(*self.values)] + ) + ) + + # Use the subquery in the exists() clause + return exists(subquery) + def append_to_statement(self, statement: StatementTypeT, model: type[ModelT]) -> StatementTypeT: """Apply EXISTS condition to the statement. @@ -659,10 +703,7 @@ def append_to_statement(self, statement: StatementTypeT, model: type[ModelT]) -> """ if not self.values: return statement - - if self.operator == "and": - exists_clause = select(model).where(self._and(*self.values)).exists() - exists_clause = select(model).where(self._or(*self.values)).exists() + exists_clause = self.get_exists_clause(model) return cast("StatementTypeT", statement.where(exists_clause)) @@ -685,6 +726,7 @@ class NotExistsFilter(StatementFilter): from advanced_alchemy.filters import NotExistsFilter filter = NotExistsFilter( + field_name="User.is_active", values=[User.email.like("%@example.com%")], ) statement = filter.append_to_statement( @@ -694,13 +736,16 @@ class NotExistsFilter(StatementFilter): Using OR conditions:: filter = NotExistsFilter( + field_name="User.role", values=[User.role == "admin", User.role == "owner"], operator="or", ) """ + field_name: str + """Name of model attribute to search on.""" values: list[ColumnElement[bool]] - """List of SQLAlchemy column expressions to use in the EXISTS clause.""" + """List of SQLAlchemy column expressions to use in the NOT EXISTS clause.""" operator: Literal["and", "or"] = "and" """If "and", combines conditions with AND, otherwise uses OR.""" @@ -728,6 +773,37 @@ def _or(self) -> Callable[..., ColumnElement[bool]]: """ return or_ + def get_exists_clause(self, model: type[ModelT]) -> ColumnElement[bool]: + """Generate the NOT EXISTS clause for the statement. + + Args: + model: The SQLAlchemy model class + + Returns: + ColumnElement[bool]: NOT EXISTS clause + """ + field = self._get_instrumented_attr(model, self.field_name) + + # Get the underlying column name of the field + field_column = getattr(field, "comparator", None) + if not field_column: + return false() # Handle cases where the field might not be directly comparable, ie. relations + field_column_name = field_column.key + + # Construct a subquery using select() + subquery = select(field).where( + *( + [getattr(model, field_column_name) == getattr(model, field_column_name), self._and(*self.values)] + if self.operator == "and" + else [ + getattr(model, field_column_name) == getattr(model, field_column_name), + self._or(*self.values), + ] + ) + ) + # Use the subquery in the exists() clause and negate it with not_() + return not_(exists(subquery)) + def append_to_statement(self, statement: StatementTypeT, model: type[ModelT]) -> StatementTypeT: """Apply NOT EXISTS condition to the statement. @@ -743,8 +819,5 @@ def append_to_statement(self, statement: StatementTypeT, model: type[ModelT]) -> """ if not self.values: return statement - - if self.operator == "and": - exists_clause = select(model).where(self._and(*self.values)).exists() - exists_clause = select(model).where(self._or(*self.values)).exists() - return cast("StatementTypeT", statement.where(~exists_clause)) + exists_clause = self.get_exists_clause(model) + return cast("StatementTypeT", statement.where(exists_clause)) diff --git a/tests/integration/test_filters.py b/tests/integration/test_filters.py index 83de3640..a8944c24 100644 --- a/tests/integration/test_filters.py +++ b/tests/integration/test_filters.py @@ -83,32 +83,35 @@ def test_not_in_collection_filter(db_session: Session) -> None: def test_exists_filter_basic(db_session: Session) -> None: - exists_filter_1 = ExistsFilter(values=[Movie.genre == "Action"]) + exists_filter_1 = ExistsFilter(field_name="genre", values=[Movie.genre == "Action"]) statement = exists_filter_1.append_to_statement(select(Movie), Movie) results = db_session.execute(statement).scalars().all() assert len(results) == 1 - exists_filter_2 = ExistsFilter(values=[Movie.genre.startswith("Action"), Movie.genre.startswith("Drama")]) + exists_filter_2 = ExistsFilter( + field_name="genre", values=[Movie.genre.startswith("Action"), Movie.genre.startswith("Drama")] + ) statement = exists_filter_2.append_to_statement(select(Movie), Movie) results = db_session.execute(statement).scalars().all() assert len(results) == 2 def test_exists_filter(db_session: Session) -> None: - exists_filter_1 = ExistsFilter(values=[Movie.title.startswith("The")]) + exists_filter_1 = ExistsFilter(field_name="title", values=[Movie.title.startswith("The")]) statement = exists_filter_1.append_to_statement(select(Movie), Movie) results = db_session.execute(statement).scalars().all() assert len(results) == 3 exists_filter_2 = ExistsFilter( + field_name="title", values=[Movie.title.startswith("Shawshank Redemption"), Movie.title.startswith("The")], - operator="and", ) statement = exists_filter_2.append_to_statement(select(Movie), Movie) results = db_session.execute(statement).scalars().all() assert len(results) == 0 exists_filter_3 = ExistsFilter( + field_name="title", values=[Movie.title.startswith("The"), Movie.title.startswith("Shawshank")], operator="or", ) @@ -118,7 +121,7 @@ def test_exists_filter(db_session: Session) -> None: def test_not_exists_filter(db_session: Session) -> None: - not_exists_filter = NotExistsFilter(values=[Movie.title.like("%Hangover%")]) + not_exists_filter = NotExistsFilter(field_name="title", values=[Movie.title.like("%Hangover%")]) statement = not_exists_filter.append_to_statement(select(Movie), Movie) results = db_session.execute(statement).scalars().all() assert len(results) == 2