diff --git a/docs/models.md b/docs/models.md index 4e6290d5..4bcb45cd 100644 --- a/docs/models.md +++ b/docs/models.md @@ -164,7 +164,7 @@ to copy a model class and optionally add it to an other registry. You can add it to a registry later by using: -`model_class.add_to_registry(registry, name="")` +`model_class.add_to_registry(registry, name="", database=None, replace_related_field=False)` In fact the last method is called when the registry parameter of `copy_edgy_model` is not `None`. diff --git a/docs/release-notes.md b/docs/release-notes.md index c41eb556..74c08814 100644 --- a/docs/release-notes.md +++ b/docs/release-notes.md @@ -14,16 +14,23 @@ hide: - Global constraints via meta. - Allow functional indexes. - Expose further parameters for UniqueConstraints. +- `no_copy` attribute for fields. ### Changes - Breaking: Factories pass now the kwargs as dict to get_pydantic_type, get_column_type, get_constraints. This allows now modifying the arguments passed down to the field. +- Breaking: init_fields_mapping doesn't initializes the field stats anymore. +- Breaking: model rebuilds are executed lazily when calling init_fields_mapping not when assigning fields manually anymore. ### Fixed - Indexes and unique_together worked only for fields with columns of the same name. - MigrateConfig has no get_registry_copy. +- Migrations have duplicate fks and crash. +- ContentTypes were not copyable. +- VirtualCascade was not automatically enabled. +- Improve lazyness by splitting in two variable sets. ## 0.21.2 diff --git a/edgy/contrib/contenttypes/fields.py b/edgy/contrib/contenttypes/fields.py index 77c1482b..e2d379ea 100644 --- a/edgy/contrib/contenttypes/fields.py +++ b/edgy/contrib/contenttypes/fields.py @@ -63,10 +63,8 @@ def __new__( # type: ignore to: Union[type["BaseModelType"], str] = "ContentType", on_delete: str = CASCADE, no_constraint: bool = False, - delete_orphan: bool = True, - default: Any = lambda owner: owner.meta.registry.get_model("ContentType")( - name=owner.__name__ - ), + remove_referenced: bool = True, + default: Any = lambda owner: owner.meta.registry.content_type(name=owner.__name__), **kwargs: Any, ) -> "BaseFieldType": return super().__new__( @@ -75,7 +73,7 @@ def __new__( # type: ignore default=default, on_delete=on_delete, no_constraint=no_constraint, - delete_orphan=delete_orphan, + remove_referenced=remove_referenced, **kwargs, ) diff --git a/edgy/contrib/contenttypes/models.py b/edgy/contrib/contenttypes/models.py index 673530bb..c6959651 100644 --- a/edgy/contrib/contenttypes/models.py +++ b/edgy/contrib/contenttypes/models.py @@ -19,7 +19,7 @@ class Meta: name: str = edgy.fields.CharField(max_length=100, default="", index=True) # set also the schema for tenancy support schema_name: str = edgy.CharField(max_length=63, null=True, index=True) - # can be a hash or similar. For checking collisions cross domain + # can be a hash or similar. Usefull for checking collisions cross domain collision_key: str = edgy.fields.CharField(max_length=255, null=True, unique=True) async def get_instance(self) -> edgy.Model: @@ -34,7 +34,10 @@ async def delete( self, skip_post_delete_hooks: bool = False, remove_referenced_call: bool = False ) -> None: reverse_name = f"reverse_{self.name.lower()}" - query = cast("QuerySet", getattr(self, reverse_name)) + referenced_obs = cast("QuerySet", getattr(self, reverse_name)) await super().delete(skip_post_delete_hooks=skip_post_delete_hooks) - if not remove_referenced_call and self.no_constraint: - await query.using(schema=self.schema_name).delete() + if ( + not remove_referenced_call + and self.meta.fields[reverse_name].foreign_key.force_cascade_deletion_relation + ): + await referenced_obs.using(schema=self.schema_name).delete() diff --git a/edgy/core/connection/registry.py b/edgy/core/connection/registry.py index bfd6b33d..fd5b58fb 100644 --- a/edgy/core/connection/registry.py +++ b/edgy/core/connection/registry.py @@ -3,7 +3,7 @@ import re import warnings from collections import defaultdict -from collections.abc import Mapping, Sequence +from collections.abc import Sequence from copy import copy as shallow_copy from functools import cached_property, partial from types import TracebackType @@ -132,7 +132,7 @@ def __init__( defaultdict(list) ) - self.extra: Mapping[str, Database] = { + self.extra: dict[str, Database] = { k: v if isinstance(v, Database) else Database(v) for k, v in extra.items() } self.metadata_by_url = MetaDataByUrlDict(registry=self) @@ -141,28 +141,37 @@ def __init__( self._set_content_type(with_content_type) def __copy__(self) -> "Registry": - _copy = Registry(self.database) - _copy.extra = self.extra - _copy.models = {key: val.copy_edgy_model(_copy) for key, val in self.models.items()} - _copy.reflected = {key: val.copy_edgy_model(_copy) for key, val in self.reflected.items()} - _copy.tenant_models = { - key: val.copy_edgy_model(_copy) for key, val in self.tenant_models.items() - } - _copy.pattern_models = { - key: val.copy_edgy_model(_copy) for key, val in self.pattern_models.items() - } - _copy.dbs_reflected = set(self.dbs_reflected) + content_type: Union[bool, type[BaseModelType]] = False if self.content_type is not None: try: - _copy.content_type = self.get_model("ContentType") + content_type2 = content_type = self.get_model( + "ContentType", include_content_type_attr=False + ).copy_edgy_model() + # cleanup content_type copy + for field_name in list(content_type2.meta.fields.keys()): + if field_name.startswith("reverse_"): + del content_type2.meta.fields[field_name] except LookupError: - _copy.content_type = self.content_type - # init callbacks - _copy._set_content_type(_copy.content_type) + content_type = self.content_type + _copy = Registry( + self.database, with_content_type=content_type, schema=self.db_schema, extra=self.extra + ) + for i in ["models", "reflected", "tenant_models", "pattern_models"]: + dict_models = getattr(_copy, i) + dict_models.update( + ( + (key, val.copy_edgy_model(_copy)) + for key, val in getattr(self, i).items() + if key not in dict_models + ) + ) + _copy.dbs_reflected = set(self.dbs_reflected) return _copy def _set_content_type( - self, with_content_type: Union[Literal[True], type["BaseModelType"]] + self, + with_content_type: Union[Literal[True], type["BaseModelType"]], + old_content_type_to_replace: Optional[type["BaseModelType"]] = None, ) -> None: from edgy.contrib.contenttypes.fields import BaseContentTypeFieldField, ContentTypeField from edgy.contrib.contenttypes.models import ContentType @@ -197,31 +206,52 @@ def callback(model_class: type["BaseModelType"]) -> None: # they are not updated, despite this shouldn't happen anyway if issubclass(model_class, ContentType): return - # skip if is explicit set + # skip if is explicit set or remove when copying for field in model_class.meta.fields.values(): if isinstance(field, BaseContentTypeFieldField): + if ( + old_content_type_to_replace is not None + and field.target is old_content_type_to_replace + ): + field.target_registry = self + field.target = real_content_type + # simply overwrite + real_content_type.meta.fields[field.related_name] = RelatedField( + name=field.related_name, + foreign_key_name=field.name, + related_from=model_class, + owner=real_content_type, + ) return + # e.g. exclude field - if "content_type" not in model_class.meta.fields: - related_name = f"reverse_{model_class.__name__.lower()}" - assert ( - related_name not in real_content_type.meta.fields - ), f"duplicate model name: {model_class.__name__}" - model_class.meta.fields["content_type"] = cast( - "BaseFieldType", - ContentTypeField( - name="content_type", - owner=model_class, - to=real_content_type, - no_constraint=real_content_type.no_constraint, - ), - ) - real_content_type.meta.fields[related_name] = RelatedField( - name=related_name, - foreign_key_name="content_type", - related_from=model_class, - owner=real_content_type, - ) + if "content_type" in model_class.meta.fields: + return + related_name = f"reverse_{model_class.__name__.lower()}" + assert ( + related_name not in real_content_type.meta.fields + ), f"duplicate model name: {model_class.__name__}" + + field_args: dict[str, Any] = { + "name": "content_type", + "owner": model_class, + "to": real_content_type, + "no_constraint": real_content_type.no_constraint, + "no_copy": True, + } + if model_class.meta.registry is not real_content_type.meta.registry: + field_args["relation_has_post_delete_callback"] = True + field_args["force_cascade_deletion_relation"] = True + model_class.meta.fields["content_type"] = cast( + "BaseFieldType", + ContentTypeField(**field_args), + ) + real_content_type.meta.fields[related_name] = RelatedField( + name=related_name, + foreign_key_name="content_type", + related_from=model_class, + owner=real_content_type, + ) self.register_callback(None, callback, one_time=False) @@ -248,7 +278,15 @@ def metadata(self) -> sqlalchemy.MetaData: ) return self.metadata_by_name[None] - def get_model(self, model_name: str) -> type["BaseModelType"]: + def get_model( + self, model_name: str, *, include_content_type_attr: bool = True + ) -> type["BaseModelType"]: + if ( + include_content_type_attr + and model_name == "ContentType" + and self.content_type is not None + ): + return self.content_type if model_name in self.models: return self.models[model_name] elif model_name in self.reflected: diff --git a/edgy/core/db/fields/base.py b/edgy/core/db/fields/base.py index 91edbe8b..d98880c7 100644 --- a/edgy/core/db/fields/base.py +++ b/edgy/core/db/fields/base.py @@ -9,6 +9,7 @@ Literal, Optional, Union, + cast, ) import sqlalchemy @@ -418,11 +419,18 @@ class BaseForeignKey(RelationshipField): # only useful if related_name = False because otherwise it gets overwritten reverse_name: str = "" - @cached_property + @property def target_registry(self) -> "Registry": """Registry searched in case to is a string""" - assert self.owner.meta.registry, "no registry found neither 'target_registry' set" - return self.owner.meta.registry + + if not hasattr(self, "_target_registry"): + assert self.owner.meta.registry, "no registry found neither 'target_registry' set" + return self.owner.meta.registry + return cast("Registry", self._target_registry) + + @target_registry.setter + def target_registry(self, value: Any) -> None: + self._target_registry = value @property def target(self) -> Any: diff --git a/edgy/core/db/fields/foreign_keys.py b/edgy/core/db/fields/foreign_keys.py index e0cd5b23..924f42e2 100644 --- a/edgy/core/db/fields/foreign_keys.py +++ b/edgy/core/db/fields/foreign_keys.py @@ -12,7 +12,7 @@ import sqlalchemy from pydantic import BaseModel -from edgy.core.db.constants import CASCADE, SET_DEFAULT, SET_NULL +from edgy.core.db.constants import SET_DEFAULT, SET_NULL from edgy.core.db.context_vars import CURRENT_PHASE from edgy.core.db.fields.base import BaseForeignKey from edgy.core.db.fields.factories import ForeignKeyFieldFactory @@ -70,10 +70,6 @@ def __init__( ) if self.on_delete == SET_NULL and not self.null: terminal.write_warning("Declaring on_delete `SET NULL` but null is False.") - if self.force_cascade_deletion_relation or ( - self.on_delete == CASCADE and self.no_constraint - ): - self.relation_has_post_delete_callback = True async def _notset_post_delete_callback(self, value: Any, instance: "BaseModelType") -> None: value = self.expand_relationship(value) @@ -104,9 +100,7 @@ async def pre_save_callback( def get_relation(self, **kwargs: Any) -> ManyRelationProtocol: if self.relation_fn is not None: return self.relation_fn(**kwargs) - if self.force_cascade_deletion_relation or ( - self.on_delete == CASCADE and self.no_constraint - ): + if self.force_cascade_deletion_relation: relation: Any = VirtualCascadeDeletionSingleRelation else: relation = SingleRelation diff --git a/edgy/core/db/fields/types.py b/edgy/core/db/fields/types.py index 44f48492..d05bca45 100644 --- a/edgy/core/db/fields/types.py +++ b/edgy/core/db/fields/types.py @@ -47,6 +47,7 @@ class ColumnDefinitionModel( class BaseFieldDefinitions: + no_copy: bool = False read_only: bool = False inject_default_on_partial_update: bool = False inherit: bool = True diff --git a/edgy/core/db/models/base.py b/edgy/core/db/models/base.py index 26be5f9e..b51a00bc 100644 --- a/edgy/core/db/models/base.py +++ b/edgy/core/db/models/base.py @@ -163,7 +163,7 @@ def identifying_db_fields(self) -> Any: """The columns used for loading, can be set per instance defaults to pknames""" return self.pkcolumns - @cached_property + @property def proxy_model(self) -> type[Model]: return self.__class__.proxy_model # type: ignore diff --git a/edgy/core/db/models/metaclasses.py b/edgy/core/db/models/metaclasses.py index e735a611..d99d3eef 100644 --- a/edgy/core/db/models/metaclasses.py +++ b/edgy/core/db/models/metaclasses.py @@ -25,7 +25,7 @@ from edgy.core.connection.registry import Registry from edgy.core.db import fields as edgy_fields from edgy.core.db.datastructures import Index, UniqueConstraint -from edgy.core.db.fields.base import PKField, RelationshipField +from edgy.core.db.fields.base import BaseForeignKey, PKField, RelationshipField from edgy.core.db.fields.foreign_keys import BaseForeignKeyField from edgy.core.db.fields.types import BaseFieldType from edgy.core.db.models.managers import BaseManager @@ -49,7 +49,7 @@ def __init__(self, meta: MetaInfo, data: Optional[dict[str, BaseFieldType]] = No super().__init__(data) def add_field_to_meta(self, name: str, field: BaseFieldType) -> None: - if not self.meta._fields_are_initialized: + if not self.meta._field_stats_are_initialized: return if hasattr(field, "__get__"): self.meta.special_getter_fields.add(name) @@ -71,7 +71,7 @@ def add_field_to_meta(self, name: str, field: BaseFieldType) -> None: self.meta.relationship_fields.add(name) def discard_field_from_meta(self, name: str) -> None: - if self.meta._fields_are_initialized: + if self.meta._field_stats_are_initialized: for field_attr in _field_sets_to_clear: getattr(self.meta, field_attr).discard(name) @@ -94,13 +94,14 @@ def __setitem__(self, name: str, value: BaseFieldType) -> None: self.add_field_to_meta(name, value) if self.meta.model is not None: self.meta.model.model_fields[name] = value # type: ignore - self.meta.model.model_rebuild(force=True) - self.meta.invalidate(clear_class_attrs=True) + self.meta.invalidate(invalidate_stats=False) def __delitem__(self, name: str) -> None: if self.data.pop(name, None) is not None: self.discard_field_from_meta(name) - self.meta.invalidate(clear_class_attrs=True) + if self.meta.model is not None: + self.meta.model.model_fields.pop(name, None) # type: ignore + self.meta.invalidate(invalidate_stats=False) class FieldToColumns(UserDict, dict[str, Sequence["sqlalchemy.Column"]]): @@ -182,11 +183,14 @@ def __iter__(self) -> Any: return super().__iter__() -_trigger_attributes_MetaInfo = { +_trigger_attributes_fields_MetaInfo = { "field_to_columns", "field_to_column_names", - "foreign_key_fields", "columns_to_field", +} + +_trigger_attributes_field_stats_MetaInfo = { + "foreign_key_fields", "special_getter_fields", "input_modifying_fields", "post_save_fields", @@ -197,9 +201,7 @@ def __iter__(self) -> Any: "relationship_fields", } -_field_sets_to_clear: set[str] = { - attr for attr in _trigger_attributes_MetaInfo if attr.endswith("_fields") -} +_field_sets_to_clear: set[str] = _trigger_attributes_field_stats_MetaInfo class MetaInfo: @@ -229,6 +231,7 @@ class MetaInfo: "secret_fields", "relationship_fields", "_fields_are_initialized", + "_field_stats_are_initialized", ) _include_dump = ( *filter( @@ -238,6 +241,7 @@ class MetaInfo: "field_to_column_names", "columns_to_field", "_fields_are_initialized", + "_field_stats_are_initialized", }, __slots__, ), @@ -255,6 +259,7 @@ class MetaInfo: def __init__(self, meta: Any = None, **kwargs: Any) -> None: self._fields_are_initialized = False + self._field_stats_are_initialized = False self.model: Optional[type[BaseModelType]] = None # Difference between meta extraction and kwargs: meta attributes are copied self.abstract: bool = getattr(meta, "abstract", False) @@ -329,11 +334,25 @@ def __setattr__(self, name: str, value: Any) -> None: def __getattribute__(self, name: str) -> Any: # lazy execute - if name in _trigger_attributes_MetaInfo and not self._fields_are_initialized: + if name in _trigger_attributes_fields_MetaInfo and not self._fields_are_initialized: self.init_fields_mapping() + if ( + name in _trigger_attributes_field_stats_MetaInfo + and not self._field_stats_are_initialized + ): + self.init_field_stats() return super().__getattribute__(name) def init_fields_mapping(self) -> None: + # when accessing the secret_fields in model_dump, this is triggered + self.field_to_columns = FieldToColumns(self) + self.field_to_column_names = FieldToColumnNames(self) + self.columns_to_field = ColumnsToField(self) + if self.model is not None: + self.model.model_rebuild(force=True) + self._fields_are_initialized = True + + def init_field_stats(self) -> None: self.special_getter_fields: set[str] = set() self.excluded_fields: set[str] = set() self.secret_fields: set[str] = set() @@ -343,14 +362,16 @@ def init_fields_mapping(self) -> None: self.post_delete_fields: set[str] = set() self.foreign_key_fields: set[str] = set() self.relationship_fields: set[str] = set() - self.field_to_columns = FieldToColumns(self) - self.field_to_column_names = FieldToColumnNames(self) - self.columns_to_field = ColumnsToField(self) - self._fields_are_initialized = True + self._field_stats_are_initialized = True for key, field in self.fields.items(): self.fields.add_field_to_meta(key, field) - def invalidate(self, clear_class_attrs: bool = True, invalidate_fields: bool = True) -> None: + def invalidate( + self, + clear_class_attrs: bool = True, + invalidate_fields: bool = True, + invalidate_stats: bool = True, + ) -> None: if invalidate_fields and self._fields_are_initialized: # prevent cycles and mem-leaks for attr in ( @@ -361,6 +382,8 @@ def invalidate(self, clear_class_attrs: bool = True, invalidate_fields: bool = T with contextlib.suppress(AttributeError): delattr(self, attr) self._fields_are_initialized = False + if invalidate_stats: + self._field_stats_are_initialized = False if self.model is None: return if clear_class_attrs: @@ -371,6 +394,8 @@ def invalidate(self, clear_class_attrs: bool = True, invalidate_fields: bool = T def full_init(self, init_column_mappers: bool = True, init_class_attrs: bool = True) -> None: if not self._fields_are_initialized: self.init_fields_mapping() + if not self._field_stats_are_initialized: + self.init_field_stats() if init_column_mappers: self.columns_to_field.init() if init_class_attrs: @@ -644,6 +669,7 @@ def __new__( primary_key=True, autoincrement=True, inherit=False, + no_copy=True, name="id", ) if not isinstance(fields["id"], BaseFieldType) or not fields["id"].primary_key: @@ -651,8 +677,11 @@ def __new__( f"Cannot create model {name} without explicit primary key if field 'id' is already present." ) - for field_name in fields: + for field_name, field_value in fields.items(): attrs.pop(field_name, None) + # clear cached target + if isinstance(field_value, BaseForeignKey): + field_value.__dict__.pop("_target", None) for manager_name in managers: attrs.pop(manager_name, None) @@ -669,7 +698,7 @@ def __new__( if not meta.abstract: # don't add to model_fields, it leads to crashes for unknown reasons - meta.fields["pk"] = PKField(exclude=True, name="pk", inherit=False) + meta.fields["pk"] = PKField(exclude=True, name="pk", inherit=False, no_copy=True) # Handle annotations annotations: dict[str, Any] = handle_annotations(bases, base_annotations, attrs) @@ -882,7 +911,7 @@ def proxy_model(cls: type[Model]) -> type[Model]: """ if cls.__is_proxy_model__: return cls - if cls.__proxy_model__ is None: + if getattr(cls, "__proxy_model__", None) is None: proxy_model = cls.generate_proxy_model() proxy_model.__parent__ = cls proxy_model.model_rebuild(force=True) diff --git a/edgy/core/db/models/mixins/db.py b/edgy/core/db/models/mixins/db.py index 0ce93953..a978ca43 100644 --- a/edgy/core/db/models/mixins/db.py +++ b/edgy/core/db/models/mixins/db.py @@ -5,11 +5,12 @@ from collections.abc import Sequence from functools import partial from itertools import chain -from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast +from typing import TYPE_CHECKING, Any, ClassVar, Literal, Optional, Union, cast import sqlalchemy from pydantic import BaseModel +from edgy.core.db.constants import CASCADE from edgy.core.db.context_vars import ( CURRENT_INSTANCE, EXPLICIT_SPECIFIED_VALUES, @@ -38,15 +39,32 @@ _empty = cast(set[str], frozenset()) +class _EmptyClass: ... + + +_removed_copy_keys = { + *BaseModel.__dict__.keys(), + "_loaded_or_deleted", + "_pkcolumns", + "_pknames", + "_table", + "_db_schemas", + "meta", +} +_removed_copy_keys.difference_update( + {*_EmptyClass.__dict__.keys(), "__annotations__", "__module__"} +) + + def _set_related_field( target: type["BaseModelType"], *, foreign_key_name: str, related_name: str, source: type["BaseModelType"], - replace_related_field: bool, + replace_related_field: Union[bool, type["BaseModelType"]], ) -> None: - if not replace_related_field and related_name in target.meta.fields: + if replace_related_field is not True and related_name in target.meta.fields: # is already correctly set, required for migrate of model_apps with registry set related_field = target.meta.fields[related_name] if ( @@ -54,9 +72,22 @@ def _set_related_field( and related_field.foreign_key_name == foreign_key_name ): return - raise ForeignKeyBadConfigured( - f"Multiple related_name with the same value '{related_name}' found to the same target. Related names must be different." - ) + # required for copying + if ( + related_field.related_from is not replace_related_field + or related_field.foreign_key_name != foreign_key_name + ): + raise ForeignKeyBadConfigured( + f"Multiple related_name with the same value '{related_name}' found to the same target. Related names must be different." + ) + # now we have enough data + fk = source.meta.fields[foreign_key_name] + if fk.force_cascade_deletion_relation or ( + fk.on_delete == CASCADE + and (source.meta.registry is not target.meta.registry or fk.no_constraint) + ): + fk.relation_has_post_delete_callback = True + fk.force_cascade_deletion_relation = True related_field = RelatedField( foreign_key_name=foreign_key_name, @@ -70,7 +101,9 @@ def _set_related_field( def _set_related_name_for_foreign_keys( - meta: "MetaInfo", model_class: type["BaseModelType"], replace_related_field: bool = False + meta: "MetaInfo", + model_class: type["BaseModelType"], + replace_related_field: Union[bool, type["BaseModelType"]] = False, ) -> None: """ Sets the related name for the foreign keys. @@ -103,6 +136,8 @@ def _set_related_name_for_foreign_keys( related_name=related_name, replace_related_field=replace_related_field, ) + # foreign_key.__dict__.pop("target", None) + # foreign_key.__dict__.pop("target_registry", None) registry: Registry = foreign_key.target_registry with contextlib.suppress(Exception): registry = cast("Registry", foreign_key.target.registry) @@ -110,12 +145,15 @@ def _set_related_name_for_foreign_keys( class DatabaseMixin: + _removed_copy_keys: ClassVar[set[str]] = _removed_copy_keys + @classmethod def add_to_registry( cls: type["BaseModelType"], registry: "Registry", name: str = "", database: Union[bool, "Database", Literal["keep"]] = "keep", + replace_related_field: Union[bool, type["BaseModelType"]] = False, ) -> None: # when called if registry is not set cls.meta.registry = registry @@ -152,7 +190,9 @@ def create_through_model(x: Any, field: "BaseFieldType" = value) -> None: m2m_registry.register_callback(value.to, create_through_model, one_time=True) # Sets the foreign key fields if meta.foreign_key_fields: - _set_related_name_for_foreign_keys(meta, cls) + _set_related_name_for_foreign_keys( + meta, cls, replace_related_field=replace_related_field + ) registry.execute_model_callbacks(cls) # finalize @@ -186,23 +226,33 @@ def copy_edgy_model( """Copy the model class and optionally add it to another registry.""" # removes private pydantic stuff, except the prefixed ones attrs = { - key: val - for key, val in cls.__dict__.items() - if key not in BaseModel.__dict__ or key.startswith("__") + key: val for key, val in cls.__dict__.items() if key not in cls._removed_copy_keys } - attrs.pop("meta", None) # managers and fields are gone, we have to readd them with the correct data - attrs.update(cls.meta.fields) + attrs.update( + ( + (field_name, field) + for field_name, field in cls.meta.fields.items() + if not field.no_copy + ) + ) attrs.update(cls.meta.managers) _copy = cast( type["Model"], type(cls.__name__, cls.__bases__, attrs, skip_registry=True, **kwargs), ) - _copy.meta.model = _copy + for field_name in _copy.meta.foreign_key_fields: + # we need to unreference and check if both models are in the same registry + if cls.meta.fields[field_name].target.meta.registry is cls.meta.registry: + _copy.meta.fields[field_name].target = cls.meta.fields[field_name].target.__name__ + else: + # otherwise we need to disable backrefs + _copy.meta.fields[field_name].target.related_name = False if name: _copy.__name__ = name if registry is not None: - _copy.add_to_registry(registry) + # replace when old class otherwise old references can lead to issues + _copy.add_to_registry(registry, replace_related_field=cls) return _copy @property diff --git a/tests/contrib/contenttypes/test_contenttypes.py b/tests/contrib/contenttypes/test_contenttypes.py index a148a5e6..25faf3f3 100644 --- a/tests/contrib/contenttypes/test_contenttypes.py +++ b/tests/contrib/contenttypes/test_contenttypes.py @@ -43,7 +43,7 @@ class Meta: class Person(edgy.StrictModel): first_name = edgy.fields.CharField(max_length=100) last_name = edgy.fields.CharField(max_length=100) - # to defaults to ContentType + # to defaults to registry.content_type c = ContentTypeField() class Meta: @@ -66,6 +66,22 @@ async def rollback_transactions(): yield +async def test_registry_sanity(): + assert models.content_type is models.get_model("ContentType", include_content_type_attr=False) + assert Company.meta.fields["content_type"].on_delete == "CASCADE" + assert models.get_model("Company") is Company + assert models.content_type.meta.fields["reverse_company"].related_from is Company + _copy = models.__copy__() + assert models.get_model("Company") is Company + assert models.content_type.meta.fields["reverse_company"].related_from is Company + assert _copy.get_model("Company") is not Company + assert _copy.get_model("Company").meta is not Company.meta + assert models.content_type is not _copy.content_type + assert models.content_type is not _copy.get_model( + "ContentType", include_content_type_attr=False + ) + + async def test_default_contenttypes(): model1 = await Company.query.create(name="edgy inc") model2 = await Organisation.query.create(name="edgy inc") @@ -100,6 +116,7 @@ async def test_different_named_contenttypes(): with pytest.raises(AttributeError): model1.content_type # noqa model_after_load = await Person.query.get(id=model1.id) + assert model_after_load == model1 assert model_after_load.c.id is not None # defer assert model_after_load.c.name == "Person" @@ -122,7 +139,8 @@ async def test_explicit_contenttypes(): assert model_after_load.content_type.id is not None # defer assert model_after_load.content_type.name == "Company" - assert await model_after_load.content_type.get_instance() == model1 + loaded = await model_after_load.content_type.get_instance() + assert loaded == model1 # count assert await models.content_type.query.count() == 2 await models.content_type.query.delete() diff --git a/tests/contrib/contenttypes/test_contenttypes_custom.py b/tests/contrib/contenttypes/test_contenttypes_custom.py index 75588b39..4d6b24e3 100644 --- a/tests/contrib/contenttypes/test_contenttypes_custom.py +++ b/tests/contrib/contenttypes/test_contenttypes_custom.py @@ -53,7 +53,7 @@ class Meta: class Person(edgy.StrictModel): first_name = edgy.fields.CharField(max_length=100) last_name = edgy.fields.CharField(max_length=100) - # to defaults to ContentType + # to defaults to registry.content_type c = ContentTypeField() class Meta: @@ -77,7 +77,7 @@ async def rollback_transactions(): async def test_registry_sanity(): - assert models.content_type is models.get_model("ContentType") + assert models.content_type is models.get_model("ContentType", include_content_type_attr=False) assert "custom_field" in models.content_type.meta.fields diff --git a/tests/contrib/contenttypes/test_contenttypes_custom_non_abstract.py b/tests/contrib/contenttypes/test_contenttypes_custom_non_abstract.py index 0ed1c53c..6bccd780 100644 --- a/tests/contrib/contenttypes/test_contenttypes_custom_non_abstract.py +++ b/tests/contrib/contenttypes/test_contenttypes_custom_non_abstract.py @@ -49,7 +49,7 @@ class Meta: class Person(edgy.StrictModel): first_name = edgy.fields.CharField(max_length=100) last_name = edgy.fields.CharField(max_length=100) - # to defaults to ContentType + # to defaults to registry.content_type c = ContentTypeField() class Meta: @@ -73,7 +73,7 @@ async def rollback_transactions(): async def test_registry_sanity(): - assert models.content_type is models.get_model("ContentType") + assert models.content_type is models.get_model("ContentType", include_content_type_attr=False) assert "custom_field" in models.content_type.meta.fields diff --git a/tests/contrib/contenttypes/test_contenttypes_different_registry.py b/tests/contrib/contenttypes/test_contenttypes_different_registry.py new file mode 100644 index 00000000..58145445 --- /dev/null +++ b/tests/contrib/contenttypes/test_contenttypes_different_registry.py @@ -0,0 +1,136 @@ +import asyncio + +import pytest +from sqlalchemy.exc import IntegrityError + +import edgy +from edgy.contrib.contenttypes.models import ContentType as _ContentType +from edgy.testclient import DatabaseTestClient +from tests.settings import DATABASE_ALTERNATIVE_URL, DATABASE_URL + +pytestmark = pytest.mark.anyio + +database = DatabaseTestClient(DATABASE_URL, use_existing=False) +database2 = DatabaseTestClient(DATABASE_ALTERNATIVE_URL, use_existing=False) + + +class ExplicitContentType(_ContentType): + class Meta: + abstract = True + + +nother = edgy.Registry( + database=edgy.Database(database2, force_rollback=True), with_content_type=ExplicitContentType +) + + +class ContentTypeTag(edgy.StrictModel): + ctype = edgy.fields.ForeignKey(to="ContentType", related_name="tags", on_delete=edgy.CASCADE) + tag = edgy.fields.CharField(max_length=50) + + content_type = edgy.fields.ExcludeField() + + class Meta: + registry = nother + + +models = edgy.Registry( + database=edgy.Database(database, force_rollback=True), with_content_type=nother.content_type +) + + +class Organisation(edgy.StrictModel): + name = edgy.fields.CharField(max_length=100, unique=True) + + class Meta: + registry = models + + +class Company(edgy.StrictModel): + name = edgy.fields.CharField(max_length=100, unique=True) + + class Meta: + registry = models + + +@pytest.fixture(autouse=True, scope="module") +async def create_test_database(): + async with database, database2: + await nother.create_all() + await models.create_all() + yield + if not database.drop: + await models.drop_all() + if not database2.drop: + await nother.drop_all() + + +@pytest.fixture(autouse=True, scope="function") +async def rollback_transactions(): + async with models, nother: + yield + + +async def test_registry_sanity(): + assert models.content_type is nother.get_model("ContentType", include_content_type_attr=False) + + +async def test_default_contenttypes(): + model1 = await Company.query.create(name="edgy inc") + model2 = await Organisation.query.create(name="edgy inc") + assert model1.content_type.id is not None + assert model1.content_type.name == "Company" + assert model2.content_type.id is not None + assert model2.content_type.name == "Organisation" + tag = await model2.content_type.tags.add({"tag": "foo"}) + with pytest.raises(ValueError): + tag.content_type # noqa + model_after_load = await Company.query.get(id=model1.id) + assert model_after_load.content_type.id is not None + # defer + assert model_after_load.content_type.name == "Company" + assert await model_after_load.content_type.get_instance() == model1 + # fetch_all + [await content_type.get_instance() for content_type in await models.content_type.query.all()] + # iterate + with pytest.warns(UserWarning): + ops = [ + content_type.get_instance() async for content_type in models.content_type.query.all() + ] + await asyncio.gather(*ops) + await models.content_type.query.delete() + assert await Company.query.get_or_none(name="edgy inc") is None + + +async def test_explicit_contenttypes(): + # no name but still should work + model1 = await Company.query.create(name="edgy inc", content_type={}) + # wrong name type, but should be autocorrected + model2 = await Organisation.query.create(name="edgy inc", content_type={"name": "Company"}) + assert model1.content_type.id is not None + assert model1.content_type.name == "Company" + assert model2.content_type.id is not None + assert model2.content_type.name == "Organisation" + tag = await model2.content_type.tags.add({"tag": "foo"}) + with pytest.raises(ValueError): + tag.content_type # noqa + model_after_load = await Company.query.get(id=model1.id) + assert model_after_load.content_type.id is not None + # defer + assert model_after_load.content_type.name == "Company" + assert await model_after_load.content_type.get_instance() == model1 + await models.content_type.query.delete() + assert await Company.query.get_or_none(name="edgy inc") is None + + +async def test_collision(): + assert await Company.query.count() == 0 + model1 = await Company.query.create( + name="edgy inc", content_type={"collision_key": "edgy inc"} + ) + assert model1.content_type.collision_key == "edgy inc" + with pytest.raises(IntegrityError): + await Organisation.query.create( + name="edgy inc", content_type={"collision_key": "edgy inc"} + ) + assert await Organisation.query.count() == 0 diff --git a/tests/contrib/contenttypes/test_contenttypes_iterate.py b/tests/contrib/contenttypes/test_contenttypes_iterate.py index dc7cdd32..3268010b 100644 --- a/tests/contrib/contenttypes/test_contenttypes_iterate.py +++ b/tests/contrib/contenttypes/test_contenttypes_iterate.py @@ -40,7 +40,7 @@ class Meta: class Person(edgy.StrictModel): first_name = edgy.fields.CharField(max_length=100) last_name = edgy.fields.CharField(max_length=100) - # to defaults to ContentType + # to defaults to registry.content_type c = ContentTypeField() class Meta: diff --git a/tests/models/test_lazyness.py b/tests/metaclass/test_lazyness.py similarity index 90% rename from tests/models/test_lazyness.py rename to tests/metaclass/test_lazyness.py index 42f34213..e35808e7 100644 --- a/tests/models/test_lazyness.py +++ b/tests/metaclass/test_lazyness.py @@ -36,10 +36,16 @@ def test_control_lazyness(): assert User.meta is models.get_model("User").meta # initial assert not BaseUser.meta._fields_are_initialized - assert User.meta._fields_are_initialized + assert not BaseUser.meta._field_stats_are_initialized + assert not User.meta._fields_are_initialized + assert User.meta._field_stats_are_initialized assert "name" not in User.meta.columns_to_field.data - assert Product.meta._fields_are_initialized + # lazy init + assert User.meta._fields_are_initialized + assert not Product.meta._fields_are_initialized assert "rating" not in Product.meta.columns_to_field.data + # lazy init + assert Product.meta._fields_are_initialized # init pk stuff assert "id" not in Product.meta.columns_to_field.data diff --git a/tests/test_migrate.py b/tests/test_migrate.py index 5c5d0511..c6bfba0a 100644 --- a/tests/test_migrate.py +++ b/tests/test_migrate.py @@ -45,6 +45,8 @@ def test_migrate_without_model_apps(): assert len(models.models) == 3 assert len(migrate.registry.models) == 3 + registry = migrate.get_registry_copy() + assert len(registry.models) == 3 @pytest.mark.parametrize( @@ -59,9 +61,10 @@ def test_migrate_with_fake_model_apps(model_apps): assert len(nother.models) == 0 migrate = Migrate(app=app, registry=nother, model_apps=model_apps) + registry = migrate.get_registry_copy() assert len(nother.models) == 2 - assert len(migrate.registry.models) == 2 + assert len(registry.models) == 2 @pytest.mark.parametrize(