Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(sqla_factory): added __set_association_proxy__ attribute #629

Merged
merged 7 commits into from
Jan 24, 2025
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,12 @@ class AuthorFactoryWithRelationship(SQLAlchemyFactory[Author]):
__set_relationships__ = True


def test_sqla_factory_without_relationship() -> None:
def test_sqla_factory() -> None:
author = AuthorFactory.build()
assert author.books == []


def test_sqla_factory() -> None:
def test_sqla_factory_with_relationship() -> None:
author = AuthorFactoryWithRelationship.build()
assert isinstance(author, Author)
assert isinstance(author.books[0], Book)
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from __future__ import annotations

from sqlalchemy import ForeignKey
from sqlalchemy.ext.associationproxy import AssociationProxy, association_proxy
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship

from polyfactory.factories.sqlalchemy_factory import SQLAlchemyFactory


class Base(DeclarativeBase): ...


class User(Base):
__tablename__ = "users"

id: Mapped[int] = mapped_column(primary_key=True)
name: Mapped[str]

user_keyword_associations: Mapped[list["UserKeywordAssociation"]] = relationship(
back_populates="user",
)
keywords: AssociationProxy[list["Keyword"]] = association_proxy(
"user_keyword_associations",
"keyword",
creator=lambda keyword_obj: UserKeywordAssociation(keyword=keyword_obj),
)


class UserKeywordAssociation(Base):
__tablename__ = "user_keyword"
user_id: Mapped[int] = mapped_column(ForeignKey("users.id"), primary_key=True)
keyword_id: Mapped[int] = mapped_column(ForeignKey("keywords.id"), primary_key=True)

user: Mapped[User] = relationship(back_populates="user_keyword_associations")
keyword: Mapped["Keyword"] = relationship()


class Keyword(Base):
__tablename__ = "keywords"
id: Mapped[int] = mapped_column(primary_key=True)
keyword: Mapped[str]


class UserFactory(SQLAlchemyFactory[User]): ...


class UserFactoryWithAssociation(SQLAlchemyFactory[User]):
__set_association_proxy__ = True


def test_sqla_factory() -> None:
user = UserFactory.build()
assert not user.user_keyword_associations
assert not user.keywords


def test_sqla_factory_with_association() -> None:
user = UserFactoryWithAssociation.build()
assert isinstance(user.user_keyword_associations[0], UserKeywordAssociation)
assert isinstance(user.keywords[0], Keyword)
32 changes: 25 additions & 7 deletions docs/usage/library_factories/sqlalchemy_factory.rst
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
SQLAlchemyFactory
===================
=================

Basic usage is like other factories

Expand All @@ -10,21 +10,39 @@ Basic usage is like other factories
.. note::
The examples here require SQLAlchemy 2 to be installed. The factory itself supports both 1.4 and 2.


Configuration
------------------------------
-------------

SQLAlchemyFactory allows to override some configuration attributes so that a described factory can use a behavior from SQLAlchemy ORM such as `relationship() <https://docs.sqlalchemy.org/en/20/orm/relationship_api.html#sqlalchemy.orm.relationship>`_ or `Association Proxy <https://docs.sqlalchemy.org/en/20/orm/extensions/associationproxy.html#module-sqlalchemy.ext.associationproxy>`_.

Relationship
++++++++++++

By default, relationships will not be set. This can be overridden via ``__set_relationships__``.
By default, ``__set_relationships__`` is set to ``False``. If it is ``True``, all fields with the SQLAlchemy `relationship() <relationship()_>`_ will be included in the result created by ``build`` method.

.. literalinclude:: /examples/library_factories/sqlalchemy_factory/test_example_2.py
:caption: Setting relationships
:language: python

.. note::
In general, foreign keys are not automatically generated by ``.build``. This can be resolved by setting the fields yourself and/or using ``create_sync``/ ``create_async`` so models can be added to a SQLA session so these are set.
If ``__set_relationships__ = True``, ForeignKey fields associated with relationship() will be automatically generated by ``build`` method because :class:`__set_foreign_keys__ <polyfactory.factories.sqlalchemy_factory.SQLAlchemyFactory.__set_foreign_keys__>` is set to ``True`` by default. But their values will be overwritten by using ``create_sync``/ ``create_async`` methods, so SQLAlchemy ORM creates them.

Association Proxy
+++++++++++++++++

By default, ``__set_association_proxy__`` is set to ``False``. If it is ``True``, all SQLAlchemy fields mapped to ORM `Association Proxy <Association Proxy_>`_ class will be included in the result created by ``build`` method.

.. literalinclude:: /examples/library_factories/sqlalchemy_factory/test_example_association_proxy.py
:caption: Setting association_proxy
:language: python

.. note::
If ``__set_relationships__ = True``, the Polyfactory will create both fields from a particular SQLAlchemy model (association_proxy and its relationship), but eventually a relationship field will be overwritten by using ``create_sync``/ ``create_async`` methods via SQLAlchemy ORM with a proper instance from an Association Proxy relation.


Persistence
------------------------------
-----------

A handler is provided to allow persistence. This can be used by setting ``__session__`` attribute on a factory.

Expand All @@ -38,7 +56,7 @@ Similarly for ``__async_session__`` and ``create_async``.


Adding global overrides
------------------------------
-----------------------

By combining the above and using other settings, a global base factory can be set up for other factories.

Expand All @@ -48,5 +66,5 @@ By combining the above and using other settings, a global base factory can be se


API reference
------------------------------
-------------
Full API docs are available :class:`here <polyfactory.factories.sqlalchemy_factory.SQLAlchemyFactory>`.
21 changes: 21 additions & 0 deletions polyfactory/factories/sqlalchemy_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from sqlalchemy import ARRAY, Column, Numeric, String, inspect, types
from sqlalchemy.dialects import mssql, mysql, postgresql, sqlite
from sqlalchemy.exc import NoInspectionAvailable
from sqlalchemy.ext.associationproxy import AssociationProxy
from sqlalchemy.orm import InstanceState, Mapper
except ImportError as e:
msg = "sqlalchemy is not installed"
Expand Down Expand Up @@ -78,6 +79,8 @@ class SQLAlchemyFactory(Generic[T], BaseFactory[T]):
"""Configuration to consider columns with foreign keys as a field or not."""
__set_relationships__: ClassVar[bool] = False
"""Configuration to consider relationships property as a model field or not."""
__set_association_proxy__: ClassVar[bool] = False
"""Configuration to consider AssociationProxy property as a model field or not."""

__session__: ClassVar[Session | Callable[[], Session] | None] = None
__async_session__: ClassVar[AsyncSession | Callable[[], AsyncSession] | None] = None
Expand All @@ -87,6 +90,7 @@ class SQLAlchemyFactory(Generic[T], BaseFactory[T]):
"__set_primary_key__",
"__set_foreign_keys__",
"__set_relationships__",
"__set_association_proxy__",
)

@classmethod
Expand Down Expand Up @@ -215,6 +219,23 @@ def get_model_fields(cls) -> list[FieldMeta]:
random=cls.__random__,
),
)
if cls.__set_association_proxy__:
for name, attr in table.all_orm_descriptors.items():
if isinstance(attr, AssociationProxy):
target_collection = table.relationships.get(attr.target_collection)
if target_collection:
target_class = target_collection.entity.class_
target_attr = getattr(target_class, attr.value_attr)
if target_attr:
class_ = target_attr.entity.class_
annotation = class_ if not target_collection.uselist else List[class_] # type: ignore[valid-type]
fields_meta.append(
FieldMeta.from_type(
name=name,
annotation=annotation,
random=cls.__random__,
)
)

return fields_meta

Expand Down
79 changes: 79 additions & 0 deletions tests/sqlalchemy_factory/test_association_proxy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
from typing import Optional

from sqlalchemy import Column, ForeignKey, Integer, String
from sqlalchemy.ext.associationproxy import association_proxy
from sqlalchemy.orm import relationship
from sqlalchemy.orm.decl_api import DeclarativeMeta, registry

from polyfactory.factories.sqlalchemy_factory import SQLAlchemyFactory

_registry = registry()


class Base(metaclass=DeclarativeMeta):
__abstract__ = True
__allow_unmapped__ = True

registry = _registry
metadata = _registry.metadata


class User(Base):
__tablename__ = "users"

id = Column(Integer, primary_key=True)
name = Column(String)

user_keyword_associations = relationship(
"UserKeywordAssociation",
back_populates="user",
)
keywords = association_proxy(
"user_keyword_associations", "keyword", creator=lambda keyword_obj: UserKeywordAssociation(keyword=keyword_obj)
)


class UserKeywordAssociation(Base):
__tablename__ = "user_keyword"

user_id = Column(Integer, ForeignKey("users.id"), primary_key=True)
keyword_id = Column(Integer, ForeignKey("keywords.id"), primary_key=True)

user = relationship(User, back_populates="user_keyword_associations")
keyword = relationship("Keyword")

# for prevent mypy error: Unexpected keyword argument "keyword" for "UserKeywordAssociation" [call-arg]
def __init__(self, keyword: Optional["Keyword"] = None):
self.keyword = keyword


class Keyword(Base):
__tablename__ = "keywords"

id = Column(Integer, primary_key=True)
keyword = Column(String)


def test_association_proxy() -> None:
class UserFactory(SQLAlchemyFactory[User]):
__set_association_proxy__ = True

user = UserFactory.build()
assert isinstance(user.keywords[0], Keyword)
assert isinstance(user.user_keyword_associations[0], UserKeywordAssociation)


def test_complex_association_proxy() -> None:
class KeywordFactory(SQLAlchemyFactory[Keyword]): ...

class ComplexUserFactory(SQLAlchemyFactory[User]):
__set_association_proxy__ = True

keywords = KeywordFactory.batch(3)

user = ComplexUserFactory.build()
assert isinstance(user, User)
assert isinstance(user.keywords[0], Keyword)
assert len(user.keywords) == 3
assert isinstance(user.user_keyword_associations[0], UserKeywordAssociation)
assert len(user.user_keyword_associations) == 3
Loading