Skip to content

Commit

Permalink
Bdaniels/rollback caseinsensitive search (#956)
Browse files Browse the repository at this point in the history
* update not operator sql query

* rollback case insensitive search

* fixed test
  • Loading branch information
barrydaniels-nl authored Feb 24, 2025
1 parent 0cb97bd commit a7915b1
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 99 deletions.
87 changes: 7 additions & 80 deletions src/dso_api/dynamic_api/filters/lookups.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Additional ORM lookups to implement the various DSO filter operators."""

from django.contrib.postgres.fields import ArrayField
from django.db import models
from django.db.models import expressions, lookups

Expand Down Expand Up @@ -46,31 +45,18 @@ def as_sql(self, compiler, connection):
lhs, lhs_params = self.process_lhs(compiler, connection, lhs=lhs) # (field, [])
rhs, rhs_params = self.process_rhs(compiler, connection) # ("%s", [value])

field_type = self.lhs.output_field.get_internal_type()
if lhs_nullable and rhs is not None:
# Allow field__not=value to return NULL fields too.
if field_type in ["CharField", "TextField"] and not self.lhs.field.primary_key:
return (
f"({lhs} IS NULL OR UPPER({lhs}) != UPPER({rhs}))",
list(lhs_params + lhs_params)
+ [rhs.upper() if isinstance(rhs, str) else rhs for rhs in rhs_params],
)
else:
return (
f"({lhs} IS NULL OR {lhs} != {rhs})",
list(lhs_params + lhs_params) + rhs_params,
)
return (
f"({lhs} IS NULL OR {lhs} != {rhs})",
list(lhs_params + lhs_params) + rhs_params,
)

elif rhs_params and rhs_params[0] is None:
# Allow field__not=None to work.
return f"{lhs} IS NOT NULL", lhs_params
else:
if field_type in ["CharField", "TextField"] and not self.lhs.field.primary_key:
return f"UPPER({lhs}) != UPPER({rhs})", list(lhs_params) + [
rhs.upper() if isinstance(rhs, str) else rhs for rhs in rhs_params
]
else:
return f"{lhs} != {rhs}", list(lhs_params) + rhs_params
return f"{lhs} != {rhs}", list(lhs_params) + rhs_params


@models.CharField.register_lookup
Expand All @@ -89,10 +75,8 @@ def as_sql(self, compiler, connection):

lhs, lhs_params = self.process_lhs(compiler, connection)
rhs, rhs_params = self.process_rhs(compiler, connection)
if self.lhs.field.primary_key:
return f"{lhs} LIKE {rhs}", lhs_params + rhs_params
else:
return f"UPPER({lhs}) LIKE {rhs}", lhs_params + [rhs.upper() for rhs in rhs_params]

return f"{lhs} LIKE {rhs}", lhs_params + rhs_params

def get_db_prep_lookup(self, value, connection):
"""Apply the wildcard logic to the right-hand-side value"""
Expand All @@ -111,60 +95,3 @@ def _sql_wildcards(value: str) -> str:
.replace("*", "%")
.replace("?", "_")
)


@models.ForeignKey.register_lookup
class CaseInsensitiveExact(lookups.Lookup):
lookup_name = "iexact"

def as_sql(self, compiler, connection):
"""Generate the required SQL."""
lhs, lhs_params = self.process_lhs(compiler, connection)
rhs, rhs_params = self.process_rhs(compiler, connection)
if not self.lhs.field.primary_key:
return f"UPPER({lhs}) = UPPER({rhs})", lhs_params + rhs_params
else:
return f"{lhs} = {rhs}", lhs_params + rhs_params


@ArrayField.register_lookup
class ArrayContainsCaseInsensitive(lookups.Lookup):
"""
Override the default "contains" lookup for ArrayFields such that it does a case-insensitive
search. It also supports passing a comma separated string of values (or an iterable)
and returns matches only when the array field contains ALL of these values.
"""

lookup_name = "contains"

def as_sql(self, compiler, connection):
# Process the left-hand side expression (the array column)
lhs, lhs_params = self.process_lhs(compiler, connection)

# If the lookup value is a comma-separated string, split it;
# otherwise assume it is an iterable of values or a single value.
if isinstance(self.rhs, str):
# Split value on commas and filter out any empty strings.
values = [val.strip() for val in self.rhs.split(",") if val.strip()]
else:
try:
iter(self.rhs)
except TypeError:
values = [self.rhs]
else:
values = list(self.rhs)

# Convert each search value to uppercase for case-insensitive matching.
values = [v.upper() for v in values]

# Transform the values in the array column to uppercase; this is done by unnesting the
# array, applying UPPER() to each element, and reconstructing an array.
lhs_sql = f"(ARRAY(SELECT UPPER(x) FROM unnest({lhs}) AS x))" # noqa: S608

# Build a comma-separated set of placeholders for each search value.
placeholders = ", ".join(["%s"] * len(values))

# The resulting SQL uses the array "contains" operator @> to ensure that all provided
# values are present (case-insensitively) in the array field.
sql = f"{lhs_sql} @> ARRAY[{placeholders}]"
return sql, lhs_params + values
14 changes: 1 addition & 13 deletions src/dso_api/dynamic_api/filters/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,23 +321,11 @@ def _translate_lookup(
}
) from None

# Handle case-insensitive exact matches for string fields only,
# but not for relations or formatted fields
if filter_part.field.format == "date-time" and not isinstance(value, datetime):
# When something different then a full datetime is given, only compare dates.
# Otherwise, the "lte" comparison happens against 00:00:00.000 of that date,
# instead of anything that includes that day itself.
return f"date__{lookup or 'exact'}"

# Only apply iexact for direct string field lookups (not through relations)
if (
not lookup
and filter_part.field.type == "string"
and filter_part.field.format not in ["date-time", "time", "date"]
and not filter_part.field.is_relation
and not filter_part.field.is_primary
):
return "iexact"
lookup = f"date__{lookup or 'exact'}"

return lookup or "exact"

Expand Down
8 changes: 2 additions & 6 deletions src/tests/test_dynamic_api/test_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@ def test_like_filter_sql(self, django_assert_num_queries):
# using str(qs.query) doesn't apply database-level escaping,
# so running the query instead to get the actual executed query.
list(Dataset.objects.filter(name__like="foo*bar?"))

sql = context.captured_queries[0]["sql"]
assert r"""."name") LIKE 'FOO%BAR_'""" in sql
assert r"""."name" LIKE 'foo%bar_'""" in sql


def test_sql_wildcards():
Expand Down Expand Up @@ -95,11 +96,6 @@ def movie2(self, movies_model, movies_category):
("url[in]=foobar,http://example.com/someurl", {"movie2"}),
("url[like]=http:*", {"movie2"}),
("url[isnull]=true", {"movie1"}),
# Case insensitive match
("name=movie1", {"movie1"}),
("name=Movie1", {"movie1"}),
("name[like]=movie1", {"movie1"}),
("name[like]=Movie1", {"movie1"}),
],
)
def test_filter_logic(self, movies_model, movie1, movie2, query, expect):
Expand Down

0 comments on commit a7915b1

Please sign in to comment.