Skip to content

Commit

Permalink
feat: wip
Browse files Browse the repository at this point in the history
  • Loading branch information
cofin committed Jan 21, 2025
1 parent 3e98e43 commit bee577f
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 16 deletions.
95 changes: 84 additions & 11 deletions advanced_alchemy/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,20 @@
from operator import attrgetter
from typing import TYPE_CHECKING, Any, Generic, Literal, cast

from sqlalchemy import BinaryExpression, Delete, Select, Update, and_, any_, exists, or_, select, text
from sqlalchemy import (
BinaryExpression,
Delete,
Select,
Update,
and_,
any_,
exists,
false,
not_,
or_,
select,
text,
)
from typing_extensions import TypeVar

if TYPE_CHECKING:
Expand Down Expand Up @@ -615,6 +628,8 @@ class ExistsFilter(StatementFilter):
)
"""

field_name: str
"""Name of model attribute to search on."""
values: list[ColumnElement[bool]]
"""List of SQLAlchemy column expressions to use in the EXISTS clause."""
operator: Literal["and", "or"] = "and"
Expand Down Expand Up @@ -644,6 +659,35 @@ def _or(self) -> Callable[..., ColumnElement[bool]]:
"""
return or_

def get_exists_clause(self, model: type[ModelT]) -> ColumnElement[bool]:
"""Generate the EXISTS clause for the statement.
Args:
model: The SQLAlchemy model class
Returns:
ColumnElement[bool]: EXISTS clause
"""
field = self._get_instrumented_attr(model, self.field_name)

# Get the underlying column name of the field
field_column = getattr(field, "comparator", None)
if not field_column:
return false() # Handle cases where the field might not be directly comparable, ie. relations
field_column_name = field_column.key

# Construct a subquery using select()
subquery = select(field).where(
*(
[getattr(model, field_column_name) == getattr(model, field_column_name), self._and(*self.values)]
if self.operator == "and"
else [getattr(model, field_column_name) == getattr(model, field_column_name), self._or(*self.values)]
)
)

# Use the subquery in the exists() clause
return exists(subquery)

def append_to_statement(self, statement: StatementTypeT, model: type[ModelT]) -> StatementTypeT:
"""Apply EXISTS condition to the statement.
Expand All @@ -659,10 +703,7 @@ def append_to_statement(self, statement: StatementTypeT, model: type[ModelT]) ->
"""
if not self.values:
return statement

if self.operator == "and":
exists_clause = select(model).where(self._and(*self.values)).exists()
exists_clause = select(model).where(self._or(*self.values)).exists()
exists_clause = self.get_exists_clause(model)
return cast("StatementTypeT", statement.where(exists_clause))


Expand All @@ -685,6 +726,7 @@ class NotExistsFilter(StatementFilter):
from advanced_alchemy.filters import NotExistsFilter
filter = NotExistsFilter(
field_name="User.is_active",
values=[User.email.like("%@example.com%")],
)
statement = filter.append_to_statement(
Expand All @@ -694,13 +736,16 @@ class NotExistsFilter(StatementFilter):
Using OR conditions::
filter = NotExistsFilter(
field_name="User.role",
values=[User.role == "admin", User.role == "owner"],
operator="or",
)
"""

field_name: str
"""Name of model attribute to search on."""
values: list[ColumnElement[bool]]
"""List of SQLAlchemy column expressions to use in the EXISTS clause."""
"""List of SQLAlchemy column expressions to use in the NOT EXISTS clause."""
operator: Literal["and", "or"] = "and"
"""If "and", combines conditions with AND, otherwise uses OR."""

Expand Down Expand Up @@ -728,6 +773,37 @@ def _or(self) -> Callable[..., ColumnElement[bool]]:
"""
return or_

def get_exists_clause(self, model: type[ModelT]) -> ColumnElement[bool]:
"""Generate the NOT EXISTS clause for the statement.
Args:
model: The SQLAlchemy model class
Returns:
ColumnElement[bool]: NOT EXISTS clause
"""
field = self._get_instrumented_attr(model, self.field_name)

# Get the underlying column name of the field
field_column = getattr(field, "comparator", None)
if not field_column:
return false() # Handle cases where the field might not be directly comparable, ie. relations
field_column_name = field_column.key

# Construct a subquery using select()
subquery = select(field).where(
*(
[getattr(model, field_column_name) == getattr(model, field_column_name), self._and(*self.values)]
if self.operator == "and"
else [
getattr(model, field_column_name) == getattr(model, field_column_name),
self._or(*self.values),
]
)
)
# Use the subquery in the exists() clause and negate it with not_()
return not_(exists(subquery))

def append_to_statement(self, statement: StatementTypeT, model: type[ModelT]) -> StatementTypeT:
"""Apply NOT EXISTS condition to the statement.
Expand All @@ -743,8 +819,5 @@ def append_to_statement(self, statement: StatementTypeT, model: type[ModelT]) ->
"""
if not self.values:
return statement

if self.operator == "and":
exists_clause = select(model).where(self._and(*self.values)).exists()
exists_clause = select(model).where(self._or(*self.values)).exists()
return cast("StatementTypeT", statement.where(~exists_clause))
exists_clause = self.get_exists_clause(model)
return cast("StatementTypeT", statement.where(exists_clause))
13 changes: 8 additions & 5 deletions tests/integration/test_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,32 +83,35 @@ def test_not_in_collection_filter(db_session: Session) -> None:


def test_exists_filter_basic(db_session: Session) -> None:
exists_filter_1 = ExistsFilter(values=[Movie.genre == "Action"])
exists_filter_1 = ExistsFilter(field_name="genre", values=[Movie.genre == "Action"])
statement = exists_filter_1.append_to_statement(select(Movie), Movie)
results = db_session.execute(statement).scalars().all()
assert len(results) == 1

exists_filter_2 = ExistsFilter(values=[Movie.genre.startswith("Action"), Movie.genre.startswith("Drama")])
exists_filter_2 = ExistsFilter(
field_name="genre", values=[Movie.genre.startswith("Action"), Movie.genre.startswith("Drama")]
)
statement = exists_filter_2.append_to_statement(select(Movie), Movie)
results = db_session.execute(statement).scalars().all()
assert len(results) == 2


def test_exists_filter(db_session: Session) -> None:
exists_filter_1 = ExistsFilter(values=[Movie.title.startswith("The")])
exists_filter_1 = ExistsFilter(field_name="title", values=[Movie.title.startswith("The")])
statement = exists_filter_1.append_to_statement(select(Movie), Movie)
results = db_session.execute(statement).scalars().all()
assert len(results) == 3

exists_filter_2 = ExistsFilter(
field_name="title",
values=[Movie.title.startswith("Shawshank Redemption"), Movie.title.startswith("The")],
operator="and",
)
statement = exists_filter_2.append_to_statement(select(Movie), Movie)
results = db_session.execute(statement).scalars().all()
assert len(results) == 0

exists_filter_3 = ExistsFilter(
field_name="title",
values=[Movie.title.startswith("The"), Movie.title.startswith("Shawshank")],
operator="or",
)
Expand All @@ -118,7 +121,7 @@ def test_exists_filter(db_session: Session) -> None:


def test_not_exists_filter(db_session: Session) -> None:
not_exists_filter = NotExistsFilter(values=[Movie.title.like("%Hangover%")])
not_exists_filter = NotExistsFilter(field_name="title", values=[Movie.title.like("%Hangover%")])
statement = not_exists_filter.append_to_statement(select(Movie), Movie)
results = db_session.execute(statement).scalars().all()
assert len(results) == 2
Expand Down

0 comments on commit bee577f

Please sign in to comment.