Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

case insensitive search #925

Merged
merged 13 commits into from
Feb 20, 2025
89 changes: 83 additions & 6 deletions src/dso_api/dynamic_api/filters/lookups.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Additional ORM lookups to implement the various DSO filter operators."""

from django.contrib.postgres.fields import ArrayField
from django.db import models
from django.db.models import expressions, lookups

Expand Down Expand Up @@ -45,17 +46,32 @@ def as_sql(self, compiler, connection):
lhs, lhs_params = self.process_lhs(compiler, connection, lhs=lhs) # (field, [])
rhs, rhs_params = self.process_rhs(compiler, connection) # ("%s", [value])

field_type = self.lhs.output_field.get_internal_type()
if lhs_nullable and rhs is not None:
# Allow field__not=value to return NULL fields too.
return (
f"({lhs} IS NULL OR {lhs} != {rhs})",
list(lhs_params + lhs_params) + rhs_params,
)

if field_type in ["CharField", "TextField"] and not self.lhs.field.primary_key:
return (
f"({lhs}) IS NULL OR UPPER({lhs}) != UPPER({rhs}))",
list(lhs_params + lhs_params)
+ [rhs.upper() if isinstance(rhs, str) else rhs for rhs in rhs_params],
)
else:
return (
f"({lhs} IS NULL OR {lhs} != {rhs})",
list(lhs_params + lhs_params) + rhs_params,
)

elif rhs_params and rhs_params[0] is None:
# Allow field__not=None to work.
return f"{lhs} IS NOT NULL", lhs_params
else:
return f"{lhs} != {rhs}", list(lhs_params) + rhs_params
if field_type in ["CharField", "TextField"] and not self.lhs.field.primary_key:
return f"UPPER({lhs}) != UPPER({rhs})", list(lhs_params) + [
rhs.upper() if isinstance(rhs, str) else rhs for rhs in rhs_params
]
else:
return f"{lhs} != {rhs}", list(lhs_params) + rhs_params


@models.CharField.register_lookup
Expand All @@ -71,9 +87,13 @@ def as_sql(self, compiler, connection):
# rhs = %s
# lhs_params = []
# lhs_params = ["prep-value"]

lhs, lhs_params = self.process_lhs(compiler, connection)
rhs, rhs_params = self.process_rhs(compiler, connection)
return f"{lhs} LIKE {rhs}", lhs_params + rhs_params
if self.lhs.field.primary_key:
return f"{lhs} LIKE {rhs}", lhs_params + rhs_params
else:
return f"UPPER({lhs}) LIKE {rhs}", lhs_params + [rhs.upper() for rhs in rhs_params]

def get_db_prep_lookup(self, value, connection):
"""Apply the wildcard logic to the right-hand-side value"""
Expand All @@ -92,3 +112,60 @@ def _sql_wildcards(value: str) -> str:
.replace("*", "%")
.replace("?", "_")
)


@models.ForeignKey.register_lookup
class CaseInsensitiveExact(lookups.Lookup):
lookup_name = "iexact"

def as_sql(self, compiler, connection):
"""Generate the required SQL."""
lhs, lhs_params = self.process_lhs(compiler, connection)
rhs, rhs_params = self.process_rhs(compiler, connection)
if not self.lhs.field.primary_key:
return f"UPPER({lhs}) = UPPER({rhs})", lhs_params + rhs_params
else:
return f"{lhs} = {rhs}", lhs_params + rhs_params


@ArrayField.register_lookup
class ArrayContainsCaseInsensitive(lookups.Lookup):
"""
Override the default "contains" lookup for ArrayFields such that it does a case-insensitive
search. It also supports passing a comma separated string of values (or an iterable)
and returns matches only when the array field contains ALL of these values.
"""

lookup_name = "contains"

def as_sql(self, compiler, connection):
# Process the left-hand side expression (the array column)
lhs, lhs_params = self.process_lhs(compiler, connection)

# If the lookup value is a comma-separated string, split it;
# otherwise assume it is an iterable of values or a single value.
if isinstance(self.rhs, str):
# Split value on commas and filter out any empty strings.
values = [val.strip() for val in self.rhs.split(",") if val.strip()]
else:
try:
iter(self.rhs)
except TypeError:
values = [self.rhs]
else:
values = list(self.rhs)

# Convert each search value to uppercase for case-insensitive matching.
values = [v.upper() for v in values]

# Transform the values in the array column to uppercase; this is done by unnesting the
# array, applying UPPER() to each element, and reconstructing an array.
lhs_sql = f"(ARRAY(SELECT UPPER(x) FROM unnest({lhs}) AS x))" # noqa: S608

# Build a comma-separated set of placeholders for each search value.
placeholders = ", ".join(["%s"] * len(values))

# The resulting SQL uses the array "contains" operator @> to ensure that all provided
# values are present (case-insensitively) in the array field.
sql = f"{lhs_sql} @> ARRAY[{placeholders}]"
return sql, lhs_params + values
16 changes: 13 additions & 3 deletions src/dso_api/dynamic_api/filters/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,11 +251,9 @@ def _compile_filter(
"""Build the Q() object for a single filter"""
parts = _parse_filter_path(filter_input.path, table_schema, self.user_scopes)
orm_path = _to_orm_path(parts)

value = self._translate_raw_value(filter_input, parts[-1])
lookup = self._translate_lookup(filter_input, parts[-1], value)
q_path = f"{orm_path}__{lookup}"

if filter_input.lookup == "not":
# for [not] lookup: field != 1 AND field != 2
q_object = reduce(operator.and_, (Q(**{q_path: v}) for v in value))
Expand Down Expand Up @@ -323,11 +321,23 @@ def _translate_lookup(
}
) from None

# Handle case-insensitive exact matches for string fields only,
# but not for relations or formatted fields
if filter_part.field.format == "date-time" and not isinstance(value, datetime):
# When something different then a full datetime is given, only compare dates.
# Otherwise, the "lte" comparison happens against 00:00:00.000 of that date,
# instead of anything that includes that day itself.
lookup = f"date__{lookup or 'exact'}"
return f"date__{lookup or 'exact'}"

# Only apply iexact for direct string field lookups (not through relations)
if (
not lookup
and filter_part.field.type == "string"
and filter_part.field.format not in ["date-time", "time", "date"]
and not filter_part.field.is_relation
and not filter_part.field.is_primary
):
return "iexact"

return lookup or "exact"

Expand Down
8 changes: 6 additions & 2 deletions src/tests/test_dynamic_api/test_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,8 @@ def test_like_filter_sql(self, django_assert_num_queries):
# using str(qs.query) doesn't apply database-level escaping,
# so running the query instead to get the actual executed query.
list(Dataset.objects.filter(name__like="foo*bar?"))

sql = context.captured_queries[0]["sql"]
assert r"""."name" LIKE 'foo%bar_'""" in sql
assert r"""."name") LIKE 'FOO%BAR_'""" in sql


def test_sql_wildcards():
Expand Down Expand Up @@ -96,6 +95,11 @@ def movie2(self, movies_model, movies_category):
("url[in]=foobar,http://example.com/someurl", {"movie2"}),
("url[like]=http:*", {"movie2"}),
("url[isnull]=true", {"movie1"}),
# Case insensitive match
("name=movie1", {"movie1"}),
("name=Movie1", {"movie1"}),
("name[like]=movie1", {"movie1"}),
("name[like]=Movie1", {"movie1"}),
],
)
def test_filter_logic(self, movies_model, movie1, movie2, query, expect):
Expand Down
Loading