Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implemented case insensitive search #958

Merged
merged 1 commit into from
Feb 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 80 additions & 7 deletions src/dso_api/dynamic_api/filters/lookups.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Additional ORM lookups to implement the various DSO filter operators."""

from django.contrib.postgres.fields import ArrayField
from django.db import models
from django.db.models import expressions, lookups

Expand Down Expand Up @@ -45,18 +46,30 @@ def as_sql(self, compiler, connection):
lhs, lhs_params = self.process_lhs(compiler, connection, lhs=lhs) # (field, [])
rhs, rhs_params = self.process_rhs(compiler, connection) # ("%s", [value])

field_type = self.lhs.output_field.get_internal_type()
if lhs_nullable and rhs is not None:
# Allow field__not=value to return NULL fields too.
return (
f"({lhs} IS NULL OR {lhs} != {rhs})",
list(lhs_params + lhs_params) + rhs_params,
)

if field_type in ["CharField", "TextField"] and not self.lhs.field.primary_key:
return (
f"({lhs}) IS NULL OR UPPER({lhs}) != UPPER({rhs}))",
list(lhs_params + lhs_params)
+ [rhs.upper() if isinstance(rhs, str) else rhs for rhs in rhs_params],
)
else:
return (
f"({lhs} IS NULL OR {lhs} != {rhs})",
list(lhs_params + lhs_params) + rhs_params,
)
elif rhs_params and rhs_params[0] is None:
# Allow field__not=None to work.
return f"{lhs} IS NOT NULL", lhs_params
else:
return f"{lhs} != {rhs}", list(lhs_params) + rhs_params
if field_type in ["CharField", "TextField"] and not self.lhs.field.primary_key:
return f"UPPER({lhs}) != UPPER({rhs})", list(lhs_params) + [
rhs.upper() if isinstance(rhs, str) else rhs for rhs in rhs_params
]
else:
return f"{lhs} != {rhs}", list(lhs_params) + rhs_params


@models.CharField.register_lookup
Expand All @@ -76,7 +89,10 @@ def as_sql(self, compiler, connection):
lhs, lhs_params = self.process_lhs(compiler, connection)
rhs, rhs_params = self.process_rhs(compiler, connection)

return f"{lhs} LIKE {rhs}", lhs_params + rhs_params
if self.lhs.field.primary_key:
return f"{lhs} LIKE {rhs}", lhs_params + rhs_params
else:
return f"UPPER({lhs}) LIKE {rhs}", lhs_params + [rhs.upper() for rhs in rhs_params]

def get_db_prep_lookup(self, value, connection):
"""Apply the wildcard logic to the right-hand-side value"""
Expand All @@ -95,3 +111,60 @@ def _sql_wildcards(value: str) -> str:
.replace("*", "%")
.replace("?", "_")
)


@models.ForeignKey.register_lookup
class CaseInsensitiveExact(lookups.Lookup):
    """Case-insensitive ``iexact`` lookup registered on ``ForeignKey`` fields.

    Both sides of the comparison are wrapped in ``UPPER()``. When the
    left-hand field is a primary key, a plain ``=`` comparison is emitted
    instead.
    """

    lookup_name = "iexact"

    def as_sql(self, compiler, connection):
        """Generate the required SQL.

        :param compiler: The SQL compiler, passed on to ``process_lhs()``/``process_rhs()``.
        :param connection: The database connection, passed on likewise.
        :returns: A ``(sql, params)`` pair; ``params`` is always a new list.
        """
        lhs, lhs_params = self.process_lhs(compiler, connection)
        rhs, rhs_params = self.process_rhs(compiler, connection)

        # Unpack into a fresh list rather than using ``+``: concatenating a
        # tuple with a list raises TypeError, and either side may be one or
        # the other depending on the expression that produced it.
        params = [*lhs_params, *rhs_params]

        if not self.lhs.field.primary_key:
            return f"UPPER({lhs}) = UPPER({rhs})", params
        else:
            # NOTE(review): primary keys are compared as-is, presumably so the
            # primary key index stays usable - confirm.
            return f"{lhs} = {rhs}", params


@ArrayField.register_lookup
class ArrayContainsCaseInsensitive(lookups.Lookup):
    """
    Override the default "contains" lookup for ArrayFields such that it does a case-insensitive
    search. It also supports passing a comma separated string of values (or an iterable)
    and returns matches only when the array field contains ALL of these values.

    When no search values remain (for example an empty string, or a string that
    only consists of commas), the right-hand side becomes an explicitly typed
    empty array, so the condition holds for every row instead of producing
    invalid SQL.
    """

    lookup_name = "contains"

    def as_sql(self, compiler, connection):
        """Generate the ``<uppercased array> @> ARRAY[...]`` SQL.

        :param compiler: The SQL compiler, passed on to ``process_lhs()``.
        :param connection: The database connection, passed on likewise.
        :returns: A ``(sql, params)`` pair; ``params`` is always a new list.
        """
        # Process the left-hand side expression (the array column)
        lhs, lhs_params = self.process_lhs(compiler, connection)

        # If the lookup value is a comma-separated string, split it;
        # otherwise assume it is an iterable of values or a single value.
        if isinstance(self.rhs, str):
            # Split value on commas and filter out any empty strings.
            values = [val.strip() for val in self.rhs.split(",") if val.strip()]
        else:
            try:
                values = list(self.rhs)
            except TypeError:
                # Not iterable: treat it as a single search value.
                values = [self.rhs]

        # Convert each search value to uppercase for case-insensitive matching.
        # Non-string values are passed through unchanged instead of raising
        # AttributeError on ``.upper()``.
        values = [v.upper() if isinstance(v, str) else v for v in values]

        # Transform the values in the array column to uppercase; this is done by unnesting the
        # array, applying UPPER() to each element, and reconstructing an array.
        lhs_sql = f"(ARRAY(SELECT UPPER(x) FROM unnest({lhs}) AS x))"  # noqa: S608

        if values:
            # Build a comma-separated set of placeholders for each search value.
            placeholders = ", ".join(["%s"] * len(values))
            rhs_sql = f"ARRAY[{placeholders}]"
        else:
            # PostgreSQL cannot determine the type of a bare ``ARRAY[]``, so the
            # empty case needs an explicit cast. UPPER() yields text, hence text[].
            rhs_sql = "ARRAY[]::text[]"

        # The resulting SQL uses the array "contains" operator @> to ensure that all provided
        # values are present (case-insensitively) in the array field.
        sql = f"{lhs_sql} @> {rhs_sql}"

        # Unpack into a fresh list: ``lhs_params`` may be a tuple, and
        # ``tuple + list`` raises TypeError.
        return sql, [*lhs_params, *values]
15 changes: 13 additions & 2 deletions src/dso_api/dynamic_api/filters/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,12 +320,23 @@ def _translate_lookup(
)
}
) from None

# Handle case-insensitive exact matches for string fields only,
# but not for relations or formatted fields
if filter_part.field.format == "date-time" and not isinstance(value, datetime):
# When something different than a full datetime is given, only compare dates.
# Otherwise, the "lte" comparison happens against 00:00:00.000 of that date,
# instead of anything that includes that day itself.
lookup = f"date__{lookup or 'exact'}"
return f"date__{lookup or 'exact'}"

# Only apply iexact for direct string field lookups (not through relations)
if (
not lookup
and filter_part.field.type == "string"
and filter_part.field.format not in ["date-time", "time", "date"]
and not filter_part.field.is_relation
and not filter_part.field.is_primary
):
return "iexact"

return lookup or "exact"

Expand Down
7 changes: 6 additions & 1 deletion src/tests/test_dynamic_api/test_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def test_like_filter_sql(self, django_assert_num_queries):
list(Dataset.objects.filter(name__like="foo*bar?"))

sql = context.captured_queries[0]["sql"]
assert r"""."name" LIKE 'foo%bar_'""" in sql
assert r"""."name") LIKE 'FOO%BAR_'""" in sql


def test_sql_wildcards():
Expand Down Expand Up @@ -96,6 +96,11 @@ def movie2(self, movies_model, movies_category):
("url[in]=foobar,http://example.com/someurl", {"movie2"}),
("url[like]=http:*", {"movie2"}),
("url[isnull]=true", {"movie1"}),
# Case insensitive match
("name=movie1", {"movie1"}),
("name=Movie1", {"movie1"}),
("name[like]=movie1", {"movie1"}),
("name[like]=Movie1", {"movie1"}),
],
)
def test_filter_logic(self, movies_model, movie1, movie2, query, expect):
Expand Down