diff --git a/src/dso_api/dynamic_api/filters/lookups.py b/src/dso_api/dynamic_api/filters/lookups.py
index a37487d9f..e73d1f5cc 100644
--- a/src/dso_api/dynamic_api/filters/lookups.py
+++ b/src/dso_api/dynamic_api/filters/lookups.py
@@ -1,5 +1,6 @@
 """Additional ORM lookups to implement the various DSO filter operators."""
 
+from django.contrib.postgres.fields import ArrayField
 from django.db import models
 from django.db.models import expressions, lookups
 
@@ -45,17 +46,32 @@ def as_sql(self, compiler, connection):
         lhs, lhs_params = self.process_lhs(compiler, connection, lhs=lhs)  # (field, [])
         rhs, rhs_params = self.process_rhs(compiler, connection)  # ("%s", [value])
 
+        field_type = self.lhs.output_field.get_internal_type()
         if lhs_nullable and rhs is not None:
             # Allow field__not=value to return NULL fields too.
-            return (
-                f"({lhs} IS NULL OR {lhs} != {rhs})",
-                list(lhs_params + lhs_params) + rhs_params,
-            )
+
+            if field_type in ["CharField", "TextField"] and not self.lhs.field.primary_key:
+                return (
+                    f"({lhs} IS NULL OR UPPER({lhs}) != UPPER({rhs}))",
+                    list(lhs_params + lhs_params)
+                    + [v.upper() if isinstance(v, str) else v for v in rhs_params],
+                )
+            else:
+                return (
+                    f"({lhs} IS NULL OR {lhs} != {rhs})",
+                    list(lhs_params + lhs_params) + rhs_params,
+                )
+
         elif rhs_params and rhs_params[0] is None:
             # Allow field__not=None to work.
             return f"{lhs} IS NOT NULL", lhs_params
         else:
-            return f"{lhs} != {rhs}", list(lhs_params) + rhs_params
+            if field_type in ["CharField", "TextField"] and not self.lhs.field.primary_key:
+                return f"UPPER({lhs}) != UPPER({rhs})", list(lhs_params) + [
+                    v.upper() if isinstance(v, str) else v for v in rhs_params
+                ]
+            else:
+                return f"{lhs} != {rhs}", list(lhs_params) + rhs_params
 
 
 @models.CharField.register_lookup
@@ -71,9 +87,13 @@ def as_sql(self, compiler, connection):
         # rhs = %s
         # lhs_params = []
         # lhs_params = ["prep-value"]
+
         lhs, lhs_params = self.process_lhs(compiler, connection)
         rhs, rhs_params = self.process_rhs(compiler, connection)
-        return f"{lhs} LIKE {rhs}", lhs_params + rhs_params
+        if self.lhs.field.primary_key:
+            return f"{lhs} LIKE {rhs}", lhs_params + rhs_params
+        else:
+            return f"UPPER({lhs}) LIKE {rhs}", list(lhs_params) + [v.upper() for v in rhs_params]
 
     def get_db_prep_lookup(self, value, connection):
         """Apply the wildcard logic to the right-hand-side value"""
@@ -92,3 +112,60 @@ def _sql_wildcards(value: str) -> str:
         .replace("*", "%")
         .replace("?", "_")
     )
+
+
+@models.ForeignKey.register_lookup
+class CaseInsensitiveExact(lookups.Lookup):
+    lookup_name = "iexact"
+
+    def as_sql(self, compiler, connection):
+        """Generate the required SQL."""
+        lhs, lhs_params = self.process_lhs(compiler, connection)
+        rhs, rhs_params = self.process_rhs(compiler, connection)
+        if not self.lhs.field.primary_key:
+            return f"UPPER({lhs}) = UPPER({rhs})", list(lhs_params) + list(rhs_params)
+        else:
+            return f"{lhs} = {rhs}", list(lhs_params) + list(rhs_params)
+
+
+@ArrayField.register_lookup
+class ArrayContainsCaseInsensitive(lookups.Lookup):
+    """
+    Override the default "contains" lookup for ArrayFields such that it does a case-insensitive
+    search. It also supports passing a comma separated string of values (or an iterable)
+    and returns matches only when the array field contains ALL of these values.
+    """
+
+    lookup_name = "contains"
+
+    def as_sql(self, compiler, connection):
+        # Process the left-hand side expression (the array column)
+        lhs, lhs_params = self.process_lhs(compiler, connection)
+
+        # If the lookup value is a comma-separated string, split it;
+        # otherwise assume it is an iterable of values or a single value.
+        if isinstance(self.rhs, str):
+            # Split value on commas and filter out any empty strings.
+            values = [val.strip() for val in self.rhs.split(",") if val.strip()]
+        else:
+            try:
+                iter(self.rhs)
+            except TypeError:
+                values = [self.rhs]
+            else:
+                values = list(self.rhs)
+
+        # Convert each search value to uppercase text for case-insensitive matching.
+        values = [str(v).upper() for v in values]
+
+        # Transform the values in the array column to uppercase; this is done by unnesting the
+        # array, applying UPPER() to each element, and reconstructing an array.
+        lhs_sql = f"(ARRAY(SELECT UPPER(x) FROM unnest({lhs}) AS x))"  # noqa: S608
+
+        # Build a comma-separated set of placeholders for each search value.
+        placeholders = ", ".join(["%s"] * len(values))
+
+        # The array "contains" operator @> requires ALL given values to be present.
+        # The ::text[] cast keeps the SQL valid for an empty value list (bare ARRAY[] fails).
+        sql = f"{lhs_sql} @> ARRAY[{placeholders}]::text[]"
+        return sql, list(lhs_params) + values
diff --git a/src/dso_api/dynamic_api/filters/parser.py b/src/dso_api/dynamic_api/filters/parser.py
index 8a801512e..1a1df2d31 100644
--- a/src/dso_api/dynamic_api/filters/parser.py
+++ b/src/dso_api/dynamic_api/filters/parser.py
@@ -251,11 +251,9 @@ def _compile_filter(
         """Build the Q() object for a single filter"""
         parts = _parse_filter_path(filter_input.path, table_schema, self.user_scopes)
         orm_path = _to_orm_path(parts)
-
         value = self._translate_raw_value(filter_input, parts[-1])
         lookup = self._translate_lookup(filter_input, parts[-1], value)
         q_path = f"{orm_path}__{lookup}"
-
         if filter_input.lookup == "not":
             # for [not] lookup: field != 1 AND field != 2
             q_object = reduce(operator.and_, (Q(**{q_path: v}) for v in value))
@@ -323,11 +321,22 @@ def _translate_lookup(
                 }
             ) from None
 
         if filter_part.field.format == "date-time" and not isinstance(value, datetime):
             # When something different then a full datetime is given, only compare dates.
             # Otherwise, the "lte" comparison happens against 00:00:00.000 of that date,
             # instead of anything that includes that day itself.
-            lookup = f"date__{lookup or 'exact'}"
+            return f"date__{lookup or 'exact'}"
+
+        # Default to a case-insensitive exact match for plain string fields,
+        # but not for relations, primary keys or date/time formatted fields.
+        if (
+            not lookup
+            and filter_part.field.type == "string"
+            and filter_part.field.format not in ["date-time", "time", "date"]
+            and not filter_part.field.is_relation
+            and not filter_part.field.is_primary
+        ):
+            return "iexact"
 
         return lookup or "exact"
 
diff --git a/src/tests/test_dynamic_api/test_filters.py b/src/tests/test_dynamic_api/test_filters.py
index 5a2ef57f6..987c96518 100644
--- a/src/tests/test_dynamic_api/test_filters.py
+++ b/src/tests/test_dynamic_api/test_filters.py
@@ -24,9 +24,8 @@ def test_like_filter_sql(self, django_assert_num_queries):
             # using str(qs.query) doesn't apply database-level escaping,
             # so running the query instead to get the actual executed query.
             list(Dataset.objects.filter(name__like="foo*bar?"))
-
         sql = context.captured_queries[0]["sql"]
-        assert r"""."name" LIKE 'foo%bar_'""" in sql
+        assert r"""."name") LIKE 'FOO%BAR_'""" in sql
 
 
 def test_sql_wildcards():
@@ -96,6 +95,11 @@ def movie2(self, movies_model, movies_category):
             ("url[in]=foobar,http://example.com/someurl", {"movie2"}),
             ("url[like]=http:*", {"movie2"}),
             ("url[isnull]=true", {"movie1"}),
+            # Case insensitive match
+            ("name=movie1", {"movie1"}),
+            ("name=Movie1", {"movie1"}),
+            ("name[like]=movie1", {"movie1"}),
+            ("name[like]=Movie1", {"movie1"}),
         ],
     )
     def test_filter_logic(self, movies_model, movie1, movie2, query, expect):