diff --git a/docs/examples/library_factories/sqlalchemy_factory/test_example_2.py b/docs/examples/library_factories/sqlalchemy_factory/test_example_2.py index 0218bb0d..b3e00d5b 100644 --- a/docs/examples/library_factories/sqlalchemy_factory/test_example_2.py +++ b/docs/examples/library_factories/sqlalchemy_factory/test_example_2.py @@ -32,12 +32,12 @@ class AuthorFactoryWithRelationship(SQLAlchemyFactory[Author]): __set_relationships__ = True -def test_sqla_factory_without_relationship() -> None: +def test_sqla_factory() -> None: author = AuthorFactory.build() assert author.books == [] -def test_sqla_factory() -> None: +def test_sqla_factory_with_relationship() -> None: author = AuthorFactoryWithRelationship.build() assert isinstance(author, Author) assert isinstance(author.books[0], Book) diff --git a/docs/examples/library_factories/sqlalchemy_factory/test_example_association_proxy.py b/docs/examples/library_factories/sqlalchemy_factory/test_example_association_proxy.py new file mode 100644 index 00000000..6064581f --- /dev/null +++ b/docs/examples/library_factories/sqlalchemy_factory/test_example_association_proxy.py @@ -0,0 +1,60 @@ +from __future__ import annotations + +from sqlalchemy import ForeignKey +from sqlalchemy.ext.associationproxy import AssociationProxy, association_proxy +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship + +from polyfactory.factories.sqlalchemy_factory import SQLAlchemyFactory + + +class Base(DeclarativeBase): ... + + +class User(Base): + __tablename__ = "users" + + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] + + user_keyword_associations: Mapped[list["UserKeywordAssociation"]] = relationship( + back_populates="user", + ) + keywords: AssociationProxy[list["Keyword"]] = association_proxy( + "user_keyword_associations", + "keyword", + creator=lambda keyword_obj: UserKeywordAssociation(keyword=keyword_obj), + ) + + +class UserKeywordAssociation(Base): + __tablename__ = "user_keyword" + user_id: Mapped[int] = mapped_column(ForeignKey("users.id"), primary_key=True) + keyword_id: Mapped[int] = mapped_column(ForeignKey("keywords.id"), primary_key=True) + + user: Mapped[User] = relationship(back_populates="user_keyword_associations") + keyword: Mapped["Keyword"] = relationship() + + +class Keyword(Base): + __tablename__ = "keywords" + id: Mapped[int] = mapped_column(primary_key=True) + keyword: Mapped[str] + + +class UserFactory(SQLAlchemyFactory[User]): ... + + +class UserFactoryWithAssociation(SQLAlchemyFactory[User]): + __set_association_proxy__ = True + + +def test_sqla_factory() -> None: + user = UserFactory.build() + assert not user.user_keyword_associations + assert not user.keywords + + +def test_sqla_factory_with_association() -> None: + user = UserFactoryWithAssociation.build() + assert isinstance(user.user_keyword_associations[0], UserKeywordAssociation) + assert isinstance(user.keywords[0], Keyword) diff --git a/docs/usage/library_factories/sqlalchemy_factory.rst b/docs/usage/library_factories/sqlalchemy_factory.rst index ee72c9e9..b7ca70eb 100644 --- a/docs/usage/library_factories/sqlalchemy_factory.rst +++ b/docs/usage/library_factories/sqlalchemy_factory.rst @@ -1,5 +1,5 @@ SQLAlchemyFactory -=================== +================= Basic usage is like other factories @@ -10,21 +10,39 @@ Basic usage is like other factories .. note:: The examples here require SQLAlchemy 2 to be installed. The factory itself supports both 1.4 and 2. + Configuration ------------------------------- +------------- + +SQLAlchemyFactory allows to override some configuration attributes so that a described factory can use a behavior from SQLAlchemy ORM such as `relationship() `_ or `Association Proxy `_. + +Relationship +++++++++++++ -By default, relationships will not be set. This can be overridden via ``__set_relationships__``. +By default, ``__set_relationships__`` is set to ``False``. If it is ``True``, all fields with the SQLAlchemy `relationship() `_ will be included in the result created by ``build`` method. .. literalinclude:: /examples/library_factories/sqlalchemy_factory/test_example_2.py :caption: Setting relationships :language: python .. note:: - In general, foreign keys are not automatically generated by ``.build``. This can be resolved by setting the fields yourself and/or using ``create_sync``/ ``create_async`` so models can be added to a SQLA session so these are set. + If ``__set_relationships__ = True``, ForeignKey fields associated with relationship() will be automatically generated by ``build`` method because :class:`__set_foreign_keys__ ` is set to ``True`` by default. But their values will be overwritten by using ``create_sync``/ ``create_async`` methods, so SQLAlchemy ORM creates them. + +Association Proxy ++++++++++++++++++ + +By default, ``__set_association_proxy__`` is set to ``False``. If it is ``True``, all SQLAlchemy fields mapped to ORM `Association Proxy `_ class will be included in the result created by ``build`` method. + +.. literalinclude:: /examples/library_factories/sqlalchemy_factory/test_example_association_proxy.py + :caption: Setting association_proxy + :language: python + +.. note:: + If ``__set_relationships__ = True``, the Polyfactory will create both fields from a particular SQLAlchemy model (association_proxy and its relationship), but eventually a relationship field will be overwritten by using ``create_sync``/ ``create_async`` methods via SQLAlchemy ORM with a proper instance from an Association Proxy relation. Persistence ------------------------------- +----------- A handler is provided to allow persistence. This can be used by setting ``__session__`` attribute on a factory. @@ -38,7 +56,7 @@ Similarly for ``__async_session__`` and ``create_async``. Adding global overrides ------------------------------- +----------------------- By combining the above and using other settings, a global base factory can be set up for other factories. @@ -48,5 +66,5 @@ By combining the above and using other settings, a global base factory can be se API reference ------------------------------- +------------- Full API docs are available :class:`here `. diff --git a/polyfactory/factories/sqlalchemy_factory.py b/polyfactory/factories/sqlalchemy_factory.py index f874d4d3..7ad75409 100644 --- a/polyfactory/factories/sqlalchemy_factory.py +++ b/polyfactory/factories/sqlalchemy_factory.py @@ -15,6 +15,7 @@ from sqlalchemy import ARRAY, Column, Numeric, String, inspect, types from sqlalchemy.dialects import mssql, mysql, postgresql, sqlite from sqlalchemy.exc import NoInspectionAvailable + from sqlalchemy.ext.associationproxy import AssociationProxy from sqlalchemy.orm import InstanceState, Mapper except ImportError as e: msg = "sqlalchemy is not installed" @@ -78,6 +79,8 @@ class SQLAlchemyFactory(Generic[T], BaseFactory[T]): """Configuration to consider columns with foreign keys as a field or not.""" __set_relationships__: ClassVar[bool] = False """Configuration to consider relationships property as a model field or not.""" + __set_association_proxy__: ClassVar[bool] = False + """Configuration to consider AssociationProxy property as a model field or not.""" __session__: ClassVar[Session | Callable[[], Session] | None] = None __async_session__: ClassVar[AsyncSession | Callable[[], AsyncSession] | None] = None @@ -87,6 +90,7 @@ class SQLAlchemyFactory(Generic[T], BaseFactory[T]): "__set_primary_key__", "__set_foreign_keys__", "__set_relationships__", + "__set_association_proxy__", ) @classmethod @@ -215,6 +219,23 @@ def get_model_fields(cls) -> list[FieldMeta]: random=cls.__random__, ), ) + if cls.__set_association_proxy__: + for name, attr in table.all_orm_descriptors.items(): + if isinstance(attr, AssociationProxy): + target_collection = table.relationships.get(attr.target_collection) + if target_collection: + target_class = target_collection.entity.class_ + target_attr = getattr(target_class, attr.value_attr) + if target_attr: + class_ = target_attr.entity.class_ + annotation = class_ if not target_collection.uselist else List[class_] # type: ignore[valid-type] + fields_meta.append( + FieldMeta.from_type( + name=name, + annotation=annotation, + random=cls.__random__, + ) + ) return fields_meta diff --git a/tests/sqlalchemy_factory/test_association_proxy.py b/tests/sqlalchemy_factory/test_association_proxy.py new file mode 100644 index 00000000..b3badba8 --- /dev/null +++ b/tests/sqlalchemy_factory/test_association_proxy.py @@ -0,0 +1,79 @@ +from typing import Optional + +from sqlalchemy import Column, ForeignKey, Integer, String +from sqlalchemy.ext.associationproxy import association_proxy +from sqlalchemy.orm import relationship +from sqlalchemy.orm.decl_api import DeclarativeMeta, registry + +from polyfactory.factories.sqlalchemy_factory import SQLAlchemyFactory + +_registry = registry() + + +class Base(metaclass=DeclarativeMeta): + __abstract__ = True + __allow_unmapped__ = True + + registry = _registry + metadata = _registry.metadata + + +class User(Base): + __tablename__ = "users" + + id = Column(Integer, primary_key=True) + name = Column(String) + + user_keyword_associations = relationship( + "UserKeywordAssociation", + back_populates="user", + ) + keywords = association_proxy( + "user_keyword_associations", "keyword", creator=lambda keyword_obj: UserKeywordAssociation(keyword=keyword_obj) + ) + + +class UserKeywordAssociation(Base): + __tablename__ = "user_keyword" + + user_id = Column(Integer, ForeignKey("users.id"), primary_key=True) + keyword_id = Column(Integer, ForeignKey("keywords.id"), primary_key=True) + + user = relationship(User, back_populates="user_keyword_associations") + keyword = relationship("Keyword") + + # for prevent mypy error: Unexpected keyword argument "keyword" for "UserKeywordAssociation" [call-arg] + def __init__(self, keyword: Optional["Keyword"] = None): + self.keyword = keyword + + +class Keyword(Base): + __tablename__ = "keywords" + + id = Column(Integer, primary_key=True) + keyword = Column(String) + + +def test_association_proxy() -> None: + class UserFactory(SQLAlchemyFactory[User]): + __set_association_proxy__ = True + + user = UserFactory.build() + assert isinstance(user.keywords[0], Keyword) + assert isinstance(user.user_keyword_associations[0], UserKeywordAssociation) + + +def test_complex_association_proxy() -> None: + class KeywordFactory(SQLAlchemyFactory[Keyword]): ... + + class ComplexUserFactory(SQLAlchemyFactory[User]): + __set_association_proxy__ = True + + keywords = KeywordFactory.batch(3) + + user = ComplexUserFactory.build() + assert isinstance(user, User) + assert isinstance(user.keywords[0], Keyword) + assert len(user.keywords) == 3 + assert isinstance(user.user_keyword_associations[0], UserKeywordAssociation) + assert len(user.user_keyword_associations) == 3