From 2f9403d2931d238bed8cf01aef2f1e1983b7ee2a Mon Sep 17 00:00:00 2001 From: Alexander Date: Thu, 17 Oct 2024 03:36:43 +0200 Subject: [PATCH] caching queryset & cleanup cloning and initialization & special characters (#206) * Changes: - cache select_related expression and select expression - make QuerySet keyword only for the second plus argument - keyword arguments match now the function names, old names are deprecated - copy by calling init * Changes: - fix typings - make select_related variadic and deprecate the former call interface * Changes: - use new variadic select_related call in tests - update release notes * Changed: - copy cached select_related from temporary subqueries * fix get_raw * bump version * Changes: - allow special characters in model names - honor column_name in multi column fields * fix typings by switching type * update tests for special chars * Changes: - remove rendundant tests - fks can now point to columns with special chars --- docs/fields.md | 4 +- docs/release-notes.md | 11 +- edgy/__init__.py | 2 +- edgy/core/connection/registry.py | 4 +- edgy/core/db/fields/file_field.py | 20 +- edgy/core/db/fields/foreign_keys.py | 19 +- edgy/core/db/fields/many_to_many.py | 4 +- edgy/core/db/querysets/base.py | 522 ++++++++++-------- edgy/core/db/querysets/prefetch.py | 3 +- edgy/core/db/querysets/types.py | 10 +- edgy/core/db/relationships/relation.py | 2 +- tests/exclude_secrets/test_exclude_nested.py | 6 +- tests/fields/test_file_fields.py | 27 + tests/foreign_keys/test_foreignkey_special.py | 173 ++++++ .../test_many_to_many_field_special.py | 63 +++ .../test_many_to_many_field_special_auto.py | 61 ++ tests/models/test_select_related_mul.py | 4 +- tests/models/test_select_related_nested.py | 9 + tests/models/test_select_related_single.py | 2 +- tests/tenancy/test_select_related_multiple.py | 2 +- 20 files changed, 687 insertions(+), 261 deletions(-) create mode 100644 tests/foreign_keys/test_foreignkey_special.py create mode 100644 tests/foreign_keys/test_many_to_many_field_special.py create mode 100644 tests/foreign_keys/test_many_to_many_field_special_auto.py diff --git a/docs/fields.md b/docs/fields.md index 13121564..449c5b4a 100644 --- a/docs/fields.md +++ b/docs/fields.md @@ -20,7 +20,7 @@ Check the [primary_key](./models.md#restrictions-with-primary-keys) restrictions - `skip_reflection_type_check` - A boolean. Default False. Skip reflection column type check. - `unique` - A boolean. Determine if a unique constraint should be created for the field. Check the [unique_together](./models.md#unique-together) for more details. -- `column_name` - A string. Database name of the column (by default the same as the name) +- `column_name` - A string. Database name of the column (by default the same as the name). - `comment` - A comment to be added with the field in the SQL database. - `secret` - A special attribute that allows to call the [exclude_secrets](./queries/secrets.md#exclude-secrets) and avoid accidental leakage of sensitive data. @@ -502,6 +502,7 @@ from `edgy`. ``` * `relation_fn` - Optionally drop a function which returns a Relation for the reverse side. This will be used by the RelatedField (if it is created). Used by the ManyToMany field. * `reverse_path_fn` - Optionally drop a function which handles the traversal from the reverse side. Used by the ManyToMany field. +- `column_name` - A string. Base database name of the column (by default the same as the name). Useful for models with special characters in their name. !!! Note @@ -568,6 +569,7 @@ class MyModel(edgy.Model): * `related_name` - The name to use for the relation from the related object back to this one. * `through` - The model to be used for the relationship. Edgy generates the model by default if None is provided or `through` is an abstract model. +* `through_tablename` - Custom tablename for `through`. E.g. when special characters are used in model names. * `embed_through` - When traversing, embed the through object in this attribute. Otherwise it is not accessable from the result. if an empty string was provided, the old behaviour is used to query from the through model as base (default). if False, the base is transformed to the target and source model (full proxying). You cannot select the through model via path traversal anymore (except from the through model itself). diff --git a/docs/release-notes.md b/docs/release-notes.md index 9883ab31..ce9ca9c2 100644 --- a/docs/release-notes.md +++ b/docs/release-notes.md @@ -6,11 +6,12 @@ hide: # Release Notes -## 0.18.2 +## 0.19.0 ### Added - New `SET_DEFAULT`, and `PROTECT` to `on_delete` in the ForeignKey. +- New `through_tablename` parameter for ManyToMany. ### Removed @@ -21,6 +22,14 @@ hide: - Allow setting registry = False, for disabling retrieving the registry from parents. - Removed unecessary warning for ManyToMany. - Add warnings for problematic combinations in ForeignKey. +- Make QuerySet nearly keyword only and deprecate keywords not matching function names. +- Clone QuerySet via `__init__`. +- Make select_related variadic and deprecate former call taking a Sequence. +- Improved QuerySet caching. + +### Fixed + +- Multi-column fields honor now `column_name`. This allows special characters in model names. ## 0.18.1 diff --git a/edgy/__init__.py b/edgy/__init__.py index 1437c843..807345f5 100644 --- a/edgy/__init__.py +++ b/edgy/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.18.1" +__version__ = "0.19.0" from .cli.base import Migrate from .conf import settings diff --git a/edgy/core/connection/registry.py b/edgy/core/connection/registry.py index af56dc22..0269f080 100644 --- a/edgy/core/connection/registry.py +++ b/edgy/core/connection/registry.py @@ -419,7 +419,9 @@ async def create_all( async def drop_all(self, databases: Sequence[Union[str, None]] = (None,)) -> None: if self.db_schema: - await self.schema.drop_schema(self.db_schema, True, True, databases=databases) + await self.schema.drop_schema( + self.db_schema, cascade=True, if_exists=True, databases=databases + ) else: for database in databases: db = self.database if database is None else self.extra[database] diff --git a/edgy/core/db/fields/file_field.py b/edgy/core/db/fields/file_field.py index 04619f4c..68a5383a 100644 --- a/edgy/core/db/fields/file_field.py +++ b/edgy/core/db/fields/file_field.py @@ -38,6 +38,7 @@ class ConcreteFileField(BaseCompositeField): + column_name: str = "" multi_process_safe: bool = True field_file_class: type[FieldFile] _generate_name_fn: Optional[ @@ -164,15 +165,18 @@ def to_model( def get_columns(self, field_name: str) -> Sequence[sqlalchemy.Column]: model = ColumnDefinitionModel.model_validate(self, from_attributes=True) + column_name = self.column_name or field_name return [ sqlalchemy.Column( - field_name, - model.column_type, - **model.model_dump(by_alias=True, exclude_none=True), + key=field_name, + type_=model.column_type, + name=column_name, + **model.model_dump(by_alias=True, exclude_none=True, exclude={"column_name"}), ), sqlalchemy.Column( - f"{field_name}_storage", - sqlalchemy.String(length=20, collation=self.column_type.collation), + key=f"{field_name}_storage", + name=f"{column_name}_storage", + type_=sqlalchemy.String(length=20, collation=self.column_type.collation), default=self.storage.name, ), ] @@ -181,6 +185,7 @@ def get_embedded_fields( self, name: str, fields: dict[str, "BaseFieldType"] ) -> dict[str, "BaseFieldType"]: retdict: dict[str, Any] = {} + column_name = self.column_name or name # TODO: use embed_field if self.with_size: size_name = f"{name}_size" @@ -191,6 +196,7 @@ def get_embedded_fields( exclude=True, read_only=True, name=size_name, + column_name=f"{column_name}_size", owner=self.owner, ) if self.with_approval: @@ -200,7 +206,7 @@ def get_embedded_fields( null=False, default=False, exclude=True, - column_name=f"{name}_ok", + column_name=f"{column_name}_ok", name=approval_name, owner=self.owner, ) @@ -209,7 +215,7 @@ def get_embedded_fields( if metadata_name not in fields: retdict[metadata_name] = JSONField( null=False, - column_name=f"{name}_mname", + column_name=f"{column_name}_mname", name=metadata_name, owner=self.owner, default=dict, diff --git a/edgy/core/db/fields/foreign_keys.py b/edgy/core/db/fields/foreign_keys.py index 9a3130a4..491ddda7 100644 --- a/edgy/core/db/fields/foreign_keys.py +++ b/edgy/core/db/fields/foreign_keys.py @@ -36,6 +36,8 @@ class BaseForeignKeyField(BaseForeignKey): force_cascade_deletion_relation: bool = False relation_has_post_delete_callback: bool = False + # overwrite for sondercharacters + column_name: Optional[str] = None def __init__( self, @@ -277,6 +279,12 @@ def get_fk_field_name(self, name: str, fieldname: str) -> str: return name return f"{name}_{fieldname}" + def get_fk_column_name(self, name: str, fieldname: str) -> str: + name = self.column_name or name + if len(self.related_columns) == 1: + return name + return f"{name}_{fieldname}" + def from_fk_field_name(self, name: str, fieldname: str) -> str: if len(self.related_columns) == 1: return next(iter(self.related_columns.keys())) @@ -285,14 +293,15 @@ def from_fk_field_name(self, name: str, fieldname: str) -> str: def get_columns(self, name: str) -> Sequence[sqlalchemy.Column]: target = self.target columns = [] - for column_name, related_column in self.related_columns.items(): + for column_key, related_column in self.related_columns.items(): if related_column is None: - related_column = target.table.columns[column_name] - fkcolumn_name = self.get_fk_field_name(name, column_name) + related_column = target.table.columns[column_key] + fkcolumn_name = self.get_fk_field_name(name, column_key) # use the related column as reference fkcolumn = sqlalchemy.Column( - fkcolumn_name, - related_column.type, + key=fkcolumn_name, + type_=related_column.type, + name=self.get_fk_column_name(name, related_column.name), primary_key=self.primary_key, autoincrement=False, nullable=related_column.nullable or self.null, diff --git a/edgy/core/db/fields/many_to_many.py b/edgy/core/db/fields/many_to_many.py index d3ae1b6a..73635ff9 100644 --- a/edgy/core/db/fields/many_to_many.py +++ b/edgy/core/db/fields/many_to_many.py @@ -28,6 +28,7 @@ def __init__( from_fields: Sequence[str] = (), from_foreign_key: str = "", through: Union[str, type["BaseModelType"]] = "", + through_tablename: str = "", embed_through: Union[str, Literal[False]] = "", **kwargs: Any, ) -> None: @@ -37,6 +38,7 @@ def __init__( self.from_fields = from_fields self.from_foreign_key = from_foreign_key self.through = through + self.through_tablename = through_tablename self.embed_through = embed_through @cached_property @@ -185,7 +187,7 @@ def create_through_model(self) -> None: if not self.to_foreign_key: self.to_foreign_key = to_name.lower() - tablename = f"{owner_name.lower()}s_{to_name}s".lower() + tablename = self.through_tablename or f"{self.from_foreign_key}s_{self.to_foreign_key}s" meta_args = { "tablename": tablename, "registry": self.owner.meta.registry, diff --git a/edgy/core/db/querysets/base.py b/edgy/core/db/querysets/base.py index c1a1638b..aaf4b8e5 100644 --- a/edgy/core/db/querysets/base.py +++ b/edgy/core/db/querysets/base.py @@ -1,5 +1,4 @@ import asyncio -import copy import warnings from collections import defaultdict from collections.abc import AsyncIterator, Awaitable, Generator, Iterable, Sequence @@ -10,6 +9,7 @@ TYPE_CHECKING, Any, Callable, + Literal, Optional, Union, cast, @@ -43,6 +43,7 @@ generic_field = BaseField() +_empty_set = cast(Sequence[Any], frozenset()) def clean_query_kwargs( @@ -90,46 +91,89 @@ class BaseQuerySet( def __init__( self, model_class: Union[type[BaseModelType], None] = None, + *, database: Union["Database", None] = None, - filter_clauses: Any = None, - select_related: Any = None, - prefetch_related: Any = None, - limit_count: Any = None, - limit_offset: Any = None, + filter_clauses: Iterable[Any] = _empty_set, + select_related: Iterable[str] = _empty_set, + prefetch_related: Iterable["Prefetch"] = _empty_set, + limit_count: Optional[int] = None, + limit: Optional[int] = None, + limit_offset: Optional[int] = None, + offset: Optional[int] = None, batch_size: Optional[int] = None, - order_by: Any = None, - group_by: Any = None, - distinct_on: Optional[Sequence[str]] = None, - only_fields: Any = None, - defer_fields: Any = None, - embed_parent: Any = None, - embed_parent_filters: Any = None, - embed_sqla_row: str = "", - using_schema: Any = Undefined, - table: Any = None, + order_by: Iterable[str] = _empty_set, + group_by: Iterable[str] = _empty_set, + distinct_on: Union[None, Literal[True], Iterable[str]] = None, + distinct: Union[None, Literal[True], Iterable[str]] = None, + only_fields: Optional[Iterable[str]] = None, + only: Iterable[str] = _empty_set, + defer_fields: Optional[Sequence[str]] = None, + defer: Iterable[str] = _empty_set, + embed_parent: Optional[tuple[str, Union[str, str]]] = None, + embed_parent_filters: Optional[tuple[str, str]] = None, + using_schema: Union[str, None, Any] = Undefined, + table: Optional[sqlalchemy.Table] = None, exclude_secrets: bool = False, ) -> None: super().__init__(model_class=model_class) - self.filter_clauses = [] if filter_clauses is None else filter_clauses - self.or_clauses: Any = [] - self.limit_count = limit_count - self._select_related = set([] if select_related is None else select_related) - self._prefetch_related = [] if prefetch_related is None else prefetch_related - self._offset = limit_offset + self.filter_clauses: list[Any] = list(filter_clauses) + self.or_clauses: list[Any] = [] + if limit_count is not None: + warnings.warn( + "`limit_count` is deprecated use `limit`", DeprecationWarning, stacklevel=2 + ) + limit = limit_count + self.limit_count = limit + if limit_offset is not None: + warnings.warn( + "`limit_offset` is deprecated use `limit`", DeprecationWarning, stacklevel=2 + ) + offset = limit_offset + self._offset = offset + + self._select_related = set(select_related) + self._prefetch_related = list(prefetch_related) self._batch_size = batch_size - self._order_by = [] if order_by is None else order_by - self._group_by = [] if group_by is None else group_by - self.distinct_on = distinct_on - self._only = set([] if only_fields is None else only_fields) - self._defer = set([] if defer_fields is None else defer_fields) + self._order_by: tuple[str, ...] = tuple(order_by) + self._group_by: tuple[str, ...] = tuple(group_by) + if distinct_on is not None: + warnings.warn( + "`distinct_on` is deprecated use `distinct`", DeprecationWarning, stacklevel=2 + ) + distinct = distinct_on + + if distinct is True: + distinct = _empty_set + self.distinct_on = list(distinct) if distinct is not None else None + if only_fields is not None: + warnings.warn( + "`only_fields` is deprecated use `only`", DeprecationWarning, stacklevel=2 + ) + only = only_fields + self._only = set(only) + if defer_fields is not None: + warnings.warn( + "`defer_fields` is deprecated use `defer`", DeprecationWarning, stacklevel=2 + ) + defer = defer_fields + self._defer = set(defer) self.embed_parent = embed_parent - self.using_schema = using_schema self.embed_parent_filters = embed_parent_filters + self.using_schema = using_schema self._exclude_secrets = exclude_secrets # cache should not be cloned self._cache = QueryModelResultCache(attrs=self.model_class.pkcolumns) # is empty self._clear_cache(False) + # this is not cleared, because the expression is immutable + self._cached_select_related_expression: Optional[ + tuple[ + str, + dict[str, tuple[sqlalchemy.Table, type[BaseModelType]]], + dict[str, set[str]], + Any, + ] + ] = None # initialize self.active_schema = self.get_schema() @@ -140,16 +184,58 @@ def __init__( if database is not None: self.database = database + def _clone(self) -> "QuerySet": + """ + Return a copy of the current QuerySet that's ready for another + operation. + """ + queryset = self.__class__( + self.model_class, + database=self.database, + filter_clauses=self.filter_clauses, + select_related=self._select_related, + prefetch_related=self._prefetch_related, + limit=self.limit_count, + offset=self._offset, + batch_size=self._batch_size, + order_by=self._order_by, + group_by=self._group_by, + distinct=self.distinct_on, + only=self._only, + defer=self._defer, + embed_parent=self.embed_parent, + embed_parent_filters=self.embed_parent_filters, + using_schema=self.using_schema, + table=getattr(self, "_table", None), + exclude_secrets=self._exclude_secrets, + ) + queryset.or_clauses = list(self.or_clauses) + queryset._cached_select_related_expression = self._cached_select_related_expression + return cast("QuerySet", queryset) + + def _clear_cache(self, keep_result_cache: bool = False) -> None: + if not keep_result_cache: + self._cache.clear() + self._cached_select_with_tables: Optional[ + tuple[Any, dict[str, tuple[sqlalchemy.Table, type[BaseModelType]]]] + ] = None + self._cache_count: Optional[int] = None + self._cache_first: Optional[tuple[BaseModelType, BaseModelType]] = None + self._cache_last: Optional[tuple[BaseModelType, BaseModelType]] = None + # fetch all is in cache + self._cache_fetch_all: bool = False + # get current row during iteration. Used for prefetching. + # Bad style but no other way currently possible + self._cache_current_row: Optional[sqlalchemy.Row] = None + def _build_order_by_expression(self, order_by: Any, expression: Any) -> Any: """Builds the order by expression""" - order_by = list(map(self._prepare_order_by, order_by)) - expression = expression.order_by(*order_by) + expression = expression.order_by(*(self._prepare_order_by(entry) for entry in order_by)) return expression def _build_group_by_expression(self, group_by: Any, expression: Any) -> Any: """Builds the group by expression""" - group_by = list(map(self._prepare_group_by, group_by)) - expression = expression.group_by(*group_by) + expression = expression.group_by(*(self._prepare_order_by(entry) for entry in group_by)) return expression async def _resolve_clause_args(self, args: Any) -> Any: @@ -236,128 +322,135 @@ def _build_tables_select_from_relationship( # We pop out the transitions so a path is not taken 2 times # Why left outer join? It is possible and legal for a relation to not exist we check that already in filtering. - queryset: BaseQuerySet = self - select_from = queryset.table - maintablekey = select_from.key - tables_and_models: dict[str, tuple[sqlalchemy.Table, type[BaseModelType]]] = { - select_from.key: (select_from, self.model_class) - } - prefixes: dict[str, set[str]] = defaultdict(set) - transitions: dict[tuple[str, str], tuple[Any, set[tuple[str, str]]]] = {} - transitions_is_full_outer: dict[tuple[str, str], bool] = {} - - # Select related - for select_path in queryset._select_related: - # For m2m relationships - model_class = queryset.model_class - former_table = queryset.table - former_transition = None - prefix: str = "" - model_database: Optional[Database] = queryset.database - while select_path: - field_name = select_path.split("__", 1)[0] - try: + if self._cached_select_related_expression is None: + maintable = self.table + select_from = maintable + maintablekey = maintable.key + tables_and_models: dict[str, tuple[sqlalchemy.Table, type[BaseModelType]]] = { + select_from.key: (select_from, self.model_class) + } + prefixes: dict[str, set[str]] = defaultdict(set) + transitions: dict[tuple[str, str], tuple[Any, set[tuple[str, str]]]] = {} + transitions_is_full_outer: dict[tuple[str, str], bool] = {} + + # Select related + for select_path in self._select_related: + # For m2m relationships + model_class = self.model_class + former_table = maintable + former_transition = None + prefix: str = "" + model_database: Optional[Database] = self.database + while select_path: + field_name = select_path.split("__", 1)[0] + try: + field = model_class.meta.fields[field_name] + except KeyError: + raise QuerySetError( + detail=f'Selected field "{field_name}" does not exist on {model_class}.' + ) from None field = model_class.meta.fields[field_name] - except KeyError: - raise QuerySetError( - detail=f'Selected field "{field_name}" does not exist on {model_class}.' - ) from None - field = model_class.meta.fields[field_name] - if isinstance(field, RelationshipField): - model_class, reverse_part, select_path = field.traverse_field(select_path) - else: - raise QuerySetError( - detail=f'Selected field "{field_name}" is not a RelationshipField on {model_class}.' - ) - if isinstance(field, BaseForeignKey): - foreign_key = field - reverse = False - else: - foreign_key = model_class.meta.fields[reverse_part] - reverse = True - if foreign_key.is_cross_db(model_database): - raise QuerySetError( - detail=f'Selected model "{field_name}" is on another database.' - ) - # now use the one of the model_class itself - model_database = None - table = model_class.table_schema(self.active_schema) - # use table from tables_and_models - if table.key in tables_and_models: - table = tables_and_models[table.key][0] - - if foreign_key.is_m2m and foreign_key.embed_through != "": # type: ignore - # we need to inject the through model for the select - model_class = foreign_key.through - table = model_class.table_schema(self.active_schema) - if reverse: - select_path = f"{foreign_key.from_foreign_key}__{select_path}" + if isinstance(field, RelationshipField): + model_class, reverse_part, select_path = field.traverse_field(select_path) else: - select_path = f"{foreign_key.to_foreign_key}__{select_path}" - # if select_path is empty - select_path = select_path.removesuffix("__") - if reverse: - foreign_key = model_class.meta.fields[foreign_key.to_foreign_key] + raise QuerySetError( + detail=f'Selected field "{field_name}" is not a RelationshipField on {model_class}.' + ) + if isinstance(field, BaseForeignKey): + foreign_key = field + reverse = False else: - foreign_key = model_class.meta.fields[foreign_key.from_foreign_key] + foreign_key = model_class.meta.fields[reverse_part] reverse = True - prefix = f"{prefix}__{field_name}" if prefix else f"{prefix}" - prefixes[table.key].add(prefix) - transition_key = (former_table.key, table.key) - if transition_key in transitions: - # can not provide new informations - former_table = table - former_transition = transition_key - continue - and_clause = clauses_mod.and_( - *self._select_from_relationship_clause_generator( - foreign_key, table, reverse, former_table - ) - ) - if (table.key, former_table.key) in transitions: - _transition_key = (table.key, former_table.key) - # inverted - # only make full outer when not the main query - if former_table.key != maintablekey: - transitions_is_full_outer[_transition_key] = True - transitions[_transition_key] = ( - clauses_mod.or_(transitions[_transition_key][0], and_clause), - {*transitions[_transition_key][1], former_transition} - if former_transition - else transitions[_transition_key][1], + if foreign_key.is_cross_db(model_database): + raise QuerySetError( + detail=f'Selected model "{field_name}" is on another database.' + ) + # now use the one of the model_class itself + model_database = None + table = model_class.table_schema(self.active_schema) + # use table from tables_and_models + if table.key in tables_and_models: + table = tables_and_models[table.key][0] + + if foreign_key.is_m2m and foreign_key.embed_through != "": # type: ignore + # we need to inject the through model for the select + model_class = foreign_key.through + table = model_class.table_schema(self.active_schema) + if reverse: + select_path = f"{foreign_key.from_foreign_key}__{select_path}" + else: + select_path = f"{foreign_key.to_foreign_key}__{select_path}" + # if select_path is empty + select_path = select_path.removesuffix("__") + if reverse: + foreign_key = model_class.meta.fields[foreign_key.to_foreign_key] + else: + foreign_key = model_class.meta.fields[foreign_key.from_foreign_key] + reverse = True + prefix = f"{prefix}__{field_name}" if prefix else f"{prefix}" + prefixes[table.key].add(prefix) + transition_key = (former_table.key, table.key) + if transition_key in transitions: + # can not provide new informations + former_table = table + former_transition = transition_key + continue + and_clause = clauses_mod.and_( + *self._select_from_relationship_clause_generator( + foreign_key, table, reverse, former_table + ) ) - elif table.key in tables_and_models: - for _transition_key in transitions: - if _transition_key[1] == table.key: - break + if (table.key, former_table.key) in transitions: + _transition_key = (table.key, former_table.key) + # inverted + # only make full outer when not the main query + if former_table.key != maintablekey: + transitions_is_full_outer[_transition_key] = True + transitions[_transition_key] = ( + clauses_mod.or_(transitions[_transition_key][0], and_clause), + {*transitions[_transition_key][1], former_transition} + if former_transition + else transitions[_transition_key][1], + ) + elif table.key in tables_and_models: + for _transition_key in transitions: + if _transition_key[1] == table.key: + break + else: + # this should never happen + raise Exception("transition not found despite in tables_and_models") + transitions[_transition_key] = ( + clauses_mod.or_(and_clause, transitions[_transition_key][0]), + {*transitions[_transition_key][1], former_transition} + if former_transition + else transitions[_transition_key][0], + ) else: - # this should never happen - raise Exception("transition not found despite in tables_and_models") - transitions[_transition_key] = ( - clauses_mod.or_(and_clause, transitions[_transition_key][0]), - {*transitions[_transition_key][1], former_transition} - if former_transition - else transitions[_transition_key][0], - ) - else: - transitions[(former_table.key, table.key)] = ( - and_clause, - {former_transition} if former_transition else set(), - ) - tables_and_models[table.key] = table, model_class - former_table = table - former_transition = transition_key + transitions[(former_table.key, table.key)] = ( + and_clause, + {former_transition} if former_transition else set(), + ) + tables_and_models[table.key] = table, model_class + former_table = table + former_transition = transition_key - while transitions: - select_from = self._join_table_helper( + while transitions: + select_from = self._join_table_helper( + select_from, + next(iter(transitions.keys())), + transitions=transitions, + tables_and_models=tables_and_models, + transitions_is_full_outer=transitions_is_full_outer, + ) + self._cached_select_related_expression = ( + maintablekey, + tables_and_models, + prefixes, select_from, - next(iter(transitions.keys())), - transitions=transitions, - tables_and_models=tables_and_models, - transitions_is_full_outer=transitions_is_full_outer, ) - return maintablekey, tables_and_models, prefixes, select_from + return self._cached_select_related_expression @staticmethod def _select_from_relationship_clause_generator( @@ -381,7 +474,7 @@ def _validate_only_and_defer(self) -> None: if self._only and self._defer: raise QuerySetError("You cannot use .only() and .defer() at the same time.") - async def as_select_with_tables( + async def _as_select_with_tables( self, ) -> tuple[Any, dict[str, tuple["sqlalchemy.Table", type["BaseModelType"]]]]: """ @@ -468,6 +561,16 @@ async def as_select_with_tables( ) return expression, tables_and_models + async def as_select_with_tables( + self, + ) -> tuple[Any, dict[str, tuple["sqlalchemy.Table", type["BaseModelType"]]]]: + """ + Builds the query select based on the given parameters and filters. + """ + if self._cached_select_with_tables is None: + self._cached_select_with_tables = await self._as_select_with_tables() + return self._cached_select_with_tables + async def as_select( self, ) -> Any: @@ -478,7 +581,7 @@ def _kwargs_to_clauses( kwargs: Any, ) -> tuple[list[Any], set[str]]: clauses = [] - select_related = set(self._select_related) + select_related: set[str] = set() # Making sure for queries we use the main class and not the proxy # And enable the parent @@ -493,7 +596,7 @@ def _kwargs_to_clauses( model_class, field_name, op, related_str, _, cross_db_remainder = crawl_relationship( self.model_class, key ) - if related_str and related_str: + if related_str: select_related.add(related_str) field = model_class.meta.fields.get(field_name, generic_field) if cross_db_remainder: @@ -558,11 +661,6 @@ def _prepare_order_by(self, order_by: str) -> Any: order_col = self.table.columns[order_by] return order_col.desc() if reverse else order_col - def _prepare_group_by(self, group_by: str) -> Any: - group_by = group_by.lstrip("-") - group_col = self.table.columns[group_by] - return group_col - def _prepare_fields_for_distinct(self, distinct_on: str) -> sqlalchemy.Column: return self.table.columns[distinct_on] @@ -624,52 +722,7 @@ def get_schema(self) -> Optional[str]: schema = get_schema() if schema is None: schema = self.model_class.get_db_schema() - return schema # type: ignore - - def _clone(self) -> "QuerySet": - """ - Return a copy of the current QuerySet that's ready for another - operation. - """ - queryset = self.__class__.__new__(self.__class__) - queryset.model_class = self.model_class - queryset._cache = QueryModelResultCache(attrs=queryset.model_class.pkcolumns) - queryset._clear_cache() - queryset.using_schema = self.using_schema - - # initialize - queryset.active_schema = self.get_schema() - - queryset._table = getattr(self, "_table", None) - queryset.filter_clauses = list(self.filter_clauses) - queryset.or_clauses = list(self.or_clauses) - queryset.limit_count = copy.copy(self.limit_count) - queryset._select_related = set(self._select_related) - queryset._prefetch_related = copy.copy(self._prefetch_related) - queryset._offset = copy.copy(self._offset) - queryset._order_by = copy.copy(self._order_by) - queryset._group_by = copy.copy(self._group_by) - queryset.distinct_on = copy.copy(self.distinct_on) - queryset.embed_parent = self.embed_parent - queryset.embed_parent_filters = self.embed_parent_filters - queryset._batch_size = self._batch_size - queryset._only = set(self._only) - queryset._defer = set(self._defer) - queryset._database = self.database - queryset._exclude_secrets = self._exclude_secrets - return cast("QuerySet", queryset) - - def _clear_cache(self, keep_result_cache: bool = False) -> None: - if not keep_result_cache: - self._cache.clear() - self._cache_count: Optional[int] = None - self._cache_first: Optional[tuple[BaseModelType, BaseModelType]] = None - self._cache_last: Optional[tuple[BaseModelType, BaseModelType]] = None - # fetch all is in cache - self._cache_fetch_all: bool = False - # get current row during iteration. Used for prefetching. - # Bad style but no other way currently possible - self._cache_current_row: Optional[sqlalchemy.Row] = None + return schema async def _handle_batch( self, @@ -868,9 +921,10 @@ def _filter_or_exclude( ] = [] for raw_clause in clauses: if isinstance(raw_clause, dict): - extracted_clauses, queryset._select_related = queryset._kwargs_to_clauses( - kwargs=raw_clause - ) + extracted_clauses, related = queryset._kwargs_to_clauses(kwargs=raw_clause) + if not queryset._select_related.issuperset(related): + queryset._select_related.update(related) + queryset._cached_select_related_expression = None if or_ and extracted_clauses: async def wrapper_and( @@ -905,8 +959,9 @@ async def wrapper_and( raw_clause.model_class is queryset.model_class ), f"QuerySet arg has wrong model_class {raw_clause.model_class}" converted_clauses.append(raw_clause.build_where_clause) - for related in raw_clause._select_related: - queryset._select_related.add(related) + if not queryset._select_related.issuperset(raw_clause._select_related): + queryset._select_related.update(raw_clause._select_related) + queryset._cached_select_related_expression = None else: converted_clauses.append(raw_clause) @@ -964,22 +1019,20 @@ async def _get_raw(self, **kwargs: Any) -> tuple[BaseModelType, Any]: filter_query._cache = self._cache return await filter_query._get_raw() - queryset: BaseQuerySet = self - expression, tables_and_models = await queryset.limit(2).as_select_with_tables() - check_db_connection(queryset.database) - async with queryset.database as database: - rows = await database.fetch_all(expression) + expression, tables_and_models = await self.as_select_with_tables() + check_db_connection(self.database) + async with self.database as database: + # we want no queryset copy, so use sqlalchemy limit(2) + rows = await database.fetch_all(expression.limit(2)) if not rows: - queryset._cache_count = 0 + self._cache_count = 0 raise ObjectNotFound() if len(rows) > 1: raise MultipleObjectsReturned() - queryset._cache_count = 1 + self._cache_count = 1 - return await queryset._get_or_cache_row( - rows[0], tables_and_models, "_cache_first,_cache_last" - ) + return await self._get_or_cache_row(rows[0], tables_and_models, "_cache_first,_cache_last") class QuerySet(BaseQuerySet): @@ -1162,9 +1215,9 @@ def order_by(self, *order_by: str) -> "QuerySet": def reverse(self) -> "QuerySet": queryset: QuerySet = self._clone() - queryset._order_by = [ + queryset._order_by = tuple( el[1:] if el.startswith("-") else f"-{el}" for el in queryset._order_by - ] + ) return queryset def limit(self, limit_count: int) -> "QuerySet": @@ -1183,7 +1236,7 @@ def offset(self, offset: int) -> "QuerySet": queryset._offset = offset return queryset - def group_by(self, *group_by: Sequence[str]) -> "QuerySet": + def group_by(self, *group_by: str) -> "QuerySet": """ Returns the values grouped by the given fields. """ @@ -1191,12 +1244,17 @@ def group_by(self, *group_by: Sequence[str]) -> "QuerySet": queryset._group_by = group_by return queryset - def distinct(self, *distinct_on: str) -> "QuerySet": + def distinct(self, first: Union[bool, str] = True, *distinct_on: str) -> "QuerySet": """ Returns a queryset with distinct results. """ queryset: QuerySet = self._clone() - queryset.distinct_on = distinct_on + if first is False: + queryset.distinct_on = None + elif first is True: + queryset.distinct_on = [] + else: + queryset.distinct_on = [first, *distinct_on] return queryset def only(self, *fields: str) -> "QuerySet": @@ -1217,31 +1275,36 @@ def only(self, *fields: str) -> "QuerySet": queryset._only = only_fields return queryset - def defer(self, *fields: Sequence[str]) -> "QuerySet": + def defer(self, *fields: str) -> "QuerySet": """ Returns a list of models with the selected only fields and always the primary key. """ queryset: QuerySet = self._clone() - defer_fields = set(fields) - queryset._defer = defer_fields + queryset._defer = set(fields) return queryset - def select_related(self, related: Any) -> "QuerySet": + def select_related(self, *related: str) -> "QuerySet": """ Returns a QuerySet that will “follow” foreign-key relationships, selecting additional related-object data when it executes its query. This is a performance booster which results in a single more complex query but means - later use of foreign-key relationships won’t require database queries. + later use of foreign-key relationships won't require database queries. """ queryset: QuerySet = self._clone() - if not isinstance(related, (list, tuple)): - related = [related] - - queryset._select_related.update(related) + if len(related) >= 1 and not isinstance(cast(Any, related[0]), str): + warnings.warn( + "use `select_related` with variadic str arguments instead of a Sequence", + DeprecationWarning, + stacklevel=2, + ) + related = cast(tuple[str, ...], related[0]) + if not self._select_related.issuperset(related): + queryset._cached_select_related_expression = None + queryset._select_related.update(related) return queryset async def values( @@ -1354,6 +1417,7 @@ async def first(self) -> Union[EdgyEmbedTarget, None]: if not queryset._order_by: queryset = queryset.order_by(*self.model_class.pkcolumns) expression, tables_and_models = await queryset.as_select_with_tables() + self._cached_select_related_expression = queryset._cached_select_related_expression check_db_connection(queryset.database) async with queryset.database as database: row = await database.fetch_one(expression, pos=0) @@ -1374,7 +1438,9 @@ async def last(self) -> Union[EdgyEmbedTarget, None]: queryset = self if not queryset._order_by: queryset = queryset.order_by(*self.model_class.pkcolumns) - expression, tables_and_models = await queryset.reverse().as_select_with_tables() + queryset = queryset.reverse() + expression, tables_and_models = await queryset.as_select_with_tables() + self._cached_select_related_expression = queryset._cached_select_related_expression check_db_connection(queryset.database) async with queryset.database as database: row = await database.fetch_one(expression, pos=0) diff --git a/edgy/core/db/querysets/prefetch.py b/edgy/core/db/querysets/prefetch.py index abff64fa..70cadc7f 100644 --- a/edgy/core/db/querysets/prefetch.py +++ b/edgy/core/db/querysets/prefetch.py @@ -69,6 +69,5 @@ def prefetch_related(self, *prefetch: Prefetch) -> "QuerySet": if any(not isinstance(value, Prefetch) for value in prefetch): raise QuerySetError("The prefetch_related must have Prefetch type objects only.") - prefetch = [*self._prefetch_related, *prefetch] # type: ignore - queryset._prefetch_related = prefetch + queryset._prefetch_related = [*self._prefetch_related, *prefetch] return queryset diff --git a/edgy/core/db/querysets/types.py b/edgy/core/db/querysets/types.py index 16509961..dfd90be1 100644 --- a/edgy/core/db/querysets/types.py +++ b/edgy/core/db/querysets/types.py @@ -140,7 +140,7 @@ def exclude( def lookup(self, term: Any) -> "QueryType": ... @abstractmethod - def order_by(self, *columns: Union[list, str]) -> "QueryType": ... + def order_by(self, *columns: str) -> "QueryType": ... @abstractmethod def reverse(self) -> "QueryType": ... @@ -152,19 +152,19 @@ def limit(self, limit_count: int) -> "QueryType": ... def offset(self, offset: int) -> "QueryType": ... @abstractmethod - def group_by(self, group_by: Union[list, str]) -> "QueryType": ... + def group_by(self, *group_by: str) -> "QueryType": ... @abstractmethod def distinct(self, *distinct_on: Sequence[str]) -> "QueryType": ... @abstractmethod - def select_related(self, related: Union[list, str]) -> "QueryType": ... + def select_related(self, *related: str) -> "QueryType": ... @abstractmethod - def only(self, *fields: Sequence[str]) -> "QueryType": ... + def only(self, *fields: str) -> "QueryType": ... @abstractmethod - def defer(self, *fields: Sequence[str]) -> "QueryType": ... + def defer(self, *fields: str) -> "QueryType": ... @abstractmethod async def exists(self) -> bool: ... diff --git a/edgy/core/db/relationships/relation.py b/edgy/core/db/relationships/relation.py index 6d0a47af..e6e9aa39 100644 --- a/edgy/core/db/relationships/relation.py +++ b/edgy/core/db/relationships/relation.py @@ -58,7 +58,7 @@ def get_queryset(self) -> "QuerySet": query[related_name] = getattr(self.instance, related_name) queryset = queryset.filter(**{self.from_foreign_key: query}) # now set embed_parent - queryset.embed_parent = (self.to_foreign_key, self.embed_through) + queryset.embed_parent = (self.to_foreign_key, self.embed_through or "") if self.embed_through: queryset.embed_parent_filters = queryset.embed_parent return queryset diff --git a/tests/exclude_secrets/test_exclude_nested.py b/tests/exclude_secrets/test_exclude_nested.py index f3933b7c..164e840f 100644 --- a/tests/exclude_secrets/test_exclude_nested.py +++ b/tests/exclude_secrets/test_exclude_nested.py @@ -51,7 +51,7 @@ async def test_exclude_secrets_excludes_top_name_equals_to_name_in_foreignkey_no await Organisation.query.create(user=user) org_query = await ( - Organisation.query.select_related(["user__profile"]).exclude_secrets().order_by("id") + Organisation.query.select_related("user__profile").exclude_secrets().order_by("id") ).as_select() org_query_text = str(org_query) assert "profiles.name" in org_query_text @@ -65,9 +65,7 @@ async def test_exclude_secrets_excludes_top_name_equals_to_name_in_foreignkey_no ) await Organisation.query.create(user=user) - org_query = ( - Organisation.query.select_related(["user__profile"]).exclude_secrets().order_by("id") - ) + org_query = Organisation.query.select_related("user__profile").exclude_secrets().order_by("id") org = await org_query.last() assert org.model_dump() == { diff --git a/tests/fields/test_file_fields.py b/tests/fields/test_file_fields.py index 92693297..2de191c3 100644 --- a/tests/fields/test_file_fields.py +++ b/tests/fields/test_file_fields.py @@ -15,6 +15,13 @@ models = edgy.Registry(database=database) +class MyModäl(edgy.Model): + ä: edgy.files.FieldFile = edgy.fields.FileField(null=True, column_name="a") + + class Meta: + registry = models + + class MyModel(edgy.Model): file_field: edgy.files.FieldFile = edgy.fields.FileField(null=True) file_field_size: int = edgy.fields.IntegerField(null=True) @@ -118,6 +125,26 @@ async def test_save_file_create(create_test_database): assert not os.path.exists(path) +async def test_save_file_create_specal(create_test_database): + model = await MyModäl.query.create(ä=edgy.files.ContentFile(b"!# /bin/sh", name="foo.sh")) + # get cached + assert model.__dict__["ä"].__dict__["size"] == 10 + assert model.ä.size == 10 + # distro specific + assert model.ä.metadata["mime"].endswith("x-sh") + + assert model.ä.approved + with model.ä.open() as rob: + assert rob.read() == b"!# /bin/sh" + path = model.ä.path + assert os.path.exists(path) + assert model.ä.storage.exists(model.ä.name) + model.ä.delete() + assert os.path.exists(path) + await model.save() + assert not os.path.exists(path) + + async def test_save_file_available_overwrite(create_test_database): model1 = await MyModel.query.create( file_field=edgy.files.ContentFile(b"foo", name="foo.bytes") diff --git a/tests/foreign_keys/test_foreignkey_special.py b/tests/foreign_keys/test_foreignkey_special.py new file mode 100644 index 00000000..93926562 --- /dev/null +++ b/tests/foreign_keys/test_foreignkey_special.py @@ -0,0 +1,173 @@ +import pytest + +import edgy +from edgy.testclient import DatabaseTestClient +from tests.settings import DATABASE_URL + +pytestmark = pytest.mark.anyio + +database = DatabaseTestClient(DATABASE_URL, full_isolation=False) +models = edgy.Registry(database=database) + + +class Älbum(edgy.Model): + äd = edgy.IntegerField(primary_key=True, column_name="id") + name = edgy.CharField(max_length=100) + + class Meta: + registry = models + tablename = "albums" + + +class Track(edgy.Model): + id = edgy.IntegerField(primary_key=True) + album = edgy.ForeignKey("Älbum", on_delete=edgy.CASCADE, null=True, column_name="album") + title = edgy.CharField(max_length=100) + position = edgy.IntegerField() + + class Meta: + registry = models + + +@pytest.fixture(autouse=True, scope="module") +async def create_test_database(): + await models.create_all() + yield + await models.drop_all() + + +@pytest.fixture(autouse=True) +async def rollback_connections(): + with database.force_rollback(): + async with database: + yield + + +async def test_new_create(): + track1 = await Track.query.create(title="The Bird", position=1) + track2 = await Track.query.create(title="Heart don't stand a chance", position=2) + await Track.query.create(title="The Waters", position=3) + + album = await Älbum.query.create(name="Malibu") + await album.tracks_set.add(track1) + await album.tracks_set.add(track2) + tracks = await album.tracks_set.all() + assert len(tracks) == 2 + + await album.tracks_set.remove(track2) + tracks = await album.tracks_set.all() + assert len(tracks) == 1 + + +async def test_new_create2(): + track1 = await Track.query.create(title="The Bird", position=1) + track2 = await Track.query.create(title="Heart don't stand a chance", position=2) + await Track.query.create(title="The Waters", position=3) + + album = await Älbum.query.create(name="Malibu", tracks_set=[track1, track2]) + tracks = await album.tracks_set.all() + + assert len(tracks) == 2 + + +async def test_select_related(): + album = await Älbum.query.create(name="Malibu") + await Track.query.create(album=album, title="The Bird", position=1) + await Track.query.create(album=album, title="Heart don't stand a chance", position=2) + await Track.query.create(album=album, title="The Waters", position=3) + + fantasies = await Älbum.query.create(name="Fantasies") + await Track.query.create(album=fantasies, title="Help I'm Alive", position=1) + await Track.query.create(album=fantasies, title="Sick Muse", position=2) + await Track.query.create(album=fantasies, title="Satellite Mind", position=3) + + track = await Track.query.select_related("album").get(title="The Bird") + assert track.album.name == "Malibu" + + tracks = await Track.query.select_related("album").all() + assert len(tracks) == 6 + + +async def test_select_related_no_all(): + album = await Älbum.query.create(name="Malibu") + await Track.query.create(album=album, title="The Bird", position=1) + await Track.query.create(album=album, title="Heart don't stand a chance", position=2) + await Track.query.create(album=album, title="The Waters", position=3) + + fantasies = await Älbum.query.create(name="Fantasies") + await Track.query.create(album=fantasies, title="Help I'm Alive", position=1) + await Track.query.create(album=fantasies, title="Sick Muse", position=2) + await Track.query.create(album=fantasies, title="Satellite Mind", position=3) + + track = await Track.query.select_related("album").get(title="The Bird") + assert track.album.name == "Malibu" + + tracks = await Track.query.select_related("album") + assert len(tracks) == 6 + + +async def test_fk_filter(): + malibu = await Älbum.query.create(name="Malibu") + await Track.query.create(album=malibu, title="The Bird", position=1) + await Track.query.create(album=malibu, title="Heart don't stand a chance", position=2) + await Track.query.create(album=malibu, title="The Waters", position=3) + + fantasies = await Älbum.query.create(name="Fantasies") + await Track.query.create(album=fantasies, title="Help I'm Alive", position=1) + await Track.query.create(album=fantasies, title="Sick Muse", position=2) + await Track.query.create(album=fantasies, title="Satellite Mind", position=3) + + tracks = await Track.query.select_related("album").filter(album__name="Fantasies").all() + assert len(tracks) == 3 + for track in tracks: + assert track.album.name == "Fantasies" + + tracks = await Track.query.select_related("album").filter(album__name__icontains="fan").all() + assert len(tracks) == 3 + for track in tracks: + assert track.album.name == "Fantasies" + + tracks = await Track.query.filter(album__name__icontains="fan").all() + assert len(tracks) == 3 + for track in tracks: + assert track.album.name == "Fantasies" + + tracks = await Track.query.filter(album=malibu).select_related("album").all() + assert len(tracks) == 3 + for track in tracks: + assert track.album.name == "Malibu" + + +async def test_queryset_delete_with_fk(): + malibu = await Älbum.query.create(name="Malibu") + await Track.query.create(album=malibu, title="The Bird", position=1) + + wall = await Älbum.query.create(name="The Wall") + await Track.query.create(album=wall, title="The Wall", position=1) + + await Track.query.filter(album=malibu).delete() + assert await Track.query.filter(album=malibu).count() == 0 + assert await Track.query.filter(album=wall).count() == 1 + + +async def test_queryset_update_with_fk(): + malibu = await Älbum.query.create(name="Malibu") + wall = await Älbum.query.create(name="The Wall") + await Track.query.create(album=malibu, title="The Bird", position=1) + + await Track.query.filter(album=malibu).update(album=wall) + assert await Track.query.filter(album=malibu).count() == 0 + assert await Track.query.filter(album=wall).count() == 1 + + +@pytest.mark.skipif(database.url.dialect == "sqlite", reason="Not supported on SQLite") +async def test_on_delete_cascade(): + album = await Älbum.query.create(name="The Wall") + await Track.query.create(album=album, title="Hey You", position=1) + await Track.query.create(album=album, title="Breathe", position=2) + + assert await Track.query.count() == 2 + + await album.delete() + + assert await Track.query.count() == 0 diff --git a/tests/foreign_keys/test_many_to_many_field_special.py b/tests/foreign_keys/test_many_to_many_field_special.py new file mode 100644 index 00000000..e0bcced4 --- /dev/null +++ b/tests/foreign_keys/test_many_to_many_field_special.py @@ -0,0 +1,63 @@ +import pytest + +import edgy +from edgy.testclient import DatabaseTestClient +from tests.settings import DATABASE_URL + +pytestmark = pytest.mark.anyio + +database = DatabaseTestClient(DATABASE_URL, full_isolation=False) +models = edgy.Registry(database=database) + + +class Üser(edgy.Model): + name = edgy.CharField(max_length=100) + + class Meta: + registry = models + tablename = "u" + + +class Studio(edgy.Model): + name = edgy.CharField(max_length=255) + users = edgy.ManyToMany( + Üser, through_tablename="foo", to_foreign_key="usr", from_foreign_key="fromage" + ) + + class Meta: + registry = models + + +def test_check_tablename(): + assert Studio.meta.fields["users"].through.meta.tablename == "foo" + + +@pytest.fixture(autouse=True, scope="function") +async def create_test_database(): + async with database: + await models.create_all() + yield + if not database.drop: + await models.drop_all() + + +async def test_many_to_many_many_fields(): + user1 = await Üser.query.create(name="Charlie") + user2 = await Üser.query.create(name="Monica") + user3 = await Üser.query.create(name="Snoopy") + + studio = await Studio.query.create(name="Downtown Records") + + # Add users and albums to studio + await studio.users.add(user1) + await studio.users.add(user2) + await studio.users.add(user3) + + # Start querying + + total_users = await studio.users.all() + + assert len(total_users) == 3 + assert total_users[0].pk == user1.pk + assert total_users[1].pk == user2.pk + assert total_users[2].pk == user3.pk diff --git a/tests/foreign_keys/test_many_to_many_field_special_auto.py b/tests/foreign_keys/test_many_to_many_field_special_auto.py new file mode 100644 index 00000000..3ba47f54 --- /dev/null +++ b/tests/foreign_keys/test_many_to_many_field_special_auto.py @@ -0,0 +1,61 @@ +import pytest + +import edgy +from edgy.testclient import DatabaseTestClient +from tests.settings import DATABASE_URL + +pytestmark = pytest.mark.anyio + +database = DatabaseTestClient(DATABASE_URL, full_isolation=False) +models = edgy.Registry(database=database) + + +class Üser(edgy.Model): + name = edgy.CharField(max_length=100) + + class Meta: + registry = models + tablename = "u" + + +class Studio(edgy.Model): + name = edgy.CharField(max_length=255) + users = edgy.ManyToMany(Üser, to_foreign_key="usr", from_foreign_key="fromage") + + class Meta: + registry = models + + +def test_check_tablename(): + assert Studio.meta.fields["users"].through.meta.tablename == "fromages_usrs" + + +@pytest.fixture(autouse=True, scope="function") +async def create_test_database(): + async with database: + await models.create_all() + yield + if not database.drop: + await models.drop_all() + + +async def test_many_to_many_many_fields(): + user1 = await Üser.query.create(name="Charlie") + user2 = await Üser.query.create(name="Monica") + user3 = await Üser.query.create(name="Snoopy") + + studio = await Studio.query.create(name="Downtown Records") + + # Add users and albums to studio + await studio.users.add(user1) + await studio.users.add(user2) + await studio.users.add(user3) + + # Start querying + + total_users = await studio.users.all() + + assert len(total_users) == 3 + assert total_users[0].pk == user1.pk + assert total_users[1].pk == user2.pk + assert total_users[2].pk == user3.pk diff --git a/tests/models/test_select_related_mul.py b/tests/models/test_select_related_mul.py index aab8ee04..02f239aa 100644 --- a/tests/models/test_select_related_mul.py +++ b/tests/models/test_select_related_mul.py @@ -75,7 +75,7 @@ async def test_select_related(): assert len(query) == 1 - query = await Permission.query.select_related(["designation", "module"]).all() + query = await Permission.query.select_related("designation", "module").all() assert len(query) == 1 assert query[0].pk == permission.pk @@ -92,7 +92,7 @@ async def test_select_related_without_relation(): assert len(query) == 2 - query = await Permission.query.select_related(["designation", "module"]).all() + query = await Permission.query.select_related("designation", "module").all() assert len(query) == 2 assert query[0].pk == permission.pk diff --git a/tests/models/test_select_related_nested.py b/tests/models/test_select_related_nested.py index 934ec281..8ed5b370 100644 --- a/tests/models/test_select_related_nested.py +++ b/tests/models/test_select_related_nested.py @@ -49,8 +49,11 @@ async def test_nested_with_not_optimal_select_related_exclude_secrets(): await Organisation.query.create(user=user) org_query = Organisation.query.exclude_secrets(True) + # by default _select_related is a set; for having an arbitary order provide a list org_query._select_related = ["user", "user", "user__profile"] + assert org_query._cached_select_related_expression is None org = await org_query.last() + assert org_query._cached_select_related_expression is not None assert org.model_dump() == { "user": {"id": 1, "profile": {"id": 1, "name": "edgy"}, "email": "user@dev.com"}, @@ -66,8 +69,11 @@ async def test_nested_with_not_optimal_select_related_all(): await Organisation.query.create(user=user) org_query = Organisation.query.all() + # by default _select_related is a set; for having an arbitary order provide a list org_query._select_related = ["user", "user", "user__profile"] + assert org_query._cached_select_related_expression is None org = await org_query.get() + assert org_query._cached_select_related_expression is not None assert org.model_dump() == { "user": {"id": 1, "profile": {"id": 1, "name": "edgy"}, "email": "user@dev.com"}, @@ -83,8 +89,11 @@ async def test_nested_with_not_optimal_select_related_all2(): await Organisation.query.create(user=user) org_query = Organisation.query.all() + # by default _select_related is a set; for having an arbitary order provide a list org_query._select_related = ["user__profile", "user", "user"] + assert org_query._cached_select_related_expression is None org = await org_query.get() + assert org_query._cached_select_related_expression is not None assert org.model_dump() == { "user": {"id": 1, "profile": {"id": 1, "name": "edgy"}, "email": "user@dev.com"}, diff --git a/tests/models/test_select_related_single.py b/tests/models/test_select_related_single.py index 6cbc549a..4242dbdd 100644 --- a/tests/models/test_select_related_single.py +++ b/tests/models/test_select_related_single.py @@ -96,7 +96,7 @@ async def test_select_related_without_relation(): assert len(query) == 2 - query = await Permission.query.select_related(["designation", "module"]).all() + query = await Permission.query.select_related("designation", "module").all() assert len(query) == 2 assert query[0].pk == permission.pk diff --git a/tests/tenancy/test_select_related_multiple.py b/tests/tenancy/test_select_related_multiple.py index 04c817d1..455fb8f5 100644 --- a/tests/tenancy/test_select_related_multiple.py +++ b/tests/tenancy/test_select_related_multiple.py @@ -88,7 +88,7 @@ async def test_select_related_tenant(): query = ( await Permission.query.using(schema=tenant.schema_name) - .select_related(["designation", "module"]) + .select_related("designation", "module") .all() )