Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(sqla_factory): added __set_association_proxy__ attribute #629

Merged
merged 7 commits into from
Jan 24, 2025
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,12 @@ class AuthorFactoryWithRelationship(SQLAlchemyFactory[Author]):
__set_relationships__ = True


def test_sqla_factory_without_relationship() -> None:
def test_sqla_factory() -> None:
author = AuthorFactory.build()
assert author.books == []


def test_sqla_factory() -> None:
def test_sqla_factory_with_relationship() -> None:
author = AuthorFactoryWithRelationship.build()
assert isinstance(author, Author)
assert isinstance(author.books[0], Book)
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from __future__ import annotations

from sqlalchemy import ForeignKey
from sqlalchemy.ext.associationproxy import AssociationProxy, association_proxy
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship

from polyfactory.factories.sqlalchemy_factory import SQLAlchemyFactory


class Base(DeclarativeBase): ...


class User(Base):
__tablename__ = "users"

id: Mapped[int] = mapped_column(primary_key=True)
name: Mapped[str]

user_keyword_associations: Mapped[list["UserKeywordAssociation"]] = relationship(
back_populates="user",
)
keywords: AssociationProxy[list["Keyword"]] = association_proxy(
"user_keyword_associations",
"keyword",
creator=lambda keyword_obj: UserKeywordAssociation(keyword=keyword_obj),
)


class UserKeywordAssociation(Base):
__tablename__ = "user_keyword"
user_id: Mapped[int] = mapped_column(ForeignKey("users.id"), primary_key=True)
keyword_id: Mapped[int] = mapped_column(ForeignKey("keywords.id"), primary_key=True)

user: Mapped[User] = relationship(back_populates="user_keyword_associations")
keyword: Mapped["Keyword"] = relationship()


class Keyword(Base):
__tablename__ = "keywords"
id: Mapped[int] = mapped_column(primary_key=True)
keyword: Mapped[str]


class UserFactory(SQLAlchemyFactory[User]): ...


class UserFactoryWithAssociation(SQLAlchemyFactory[User]):
__set_association_proxy__ = True


def test_sqla_factory() -> None:
user = UserFactory.build()
assert not user.user_keyword_associations
assert not user.keywords


def test_sqla_factory_with_association() -> None:
user = UserFactoryWithAssociation.build()
assert isinstance(user.user_keyword_associations[0], UserKeywordAssociation)
assert isinstance(user.keywords[0], Keyword)
32 changes: 25 additions & 7 deletions docs/usage/library_factories/sqlalchemy_factory.rst
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
SQLAlchemyFactory
===================
=================

Basic usage is like other factories

Expand All @@ -10,21 +10,39 @@ Basic usage is like other factories
.. note::
The examples here require SQLAlchemy 2 to be installed. The factory itself supports both 1.4 and 2.


Configuration
------------------------------
-------------

SQLAlchemyFactory allows to override some configuration attributes so that a described factory can use a behavior from SQLAlchemy ORM such as `relationship() <https://docs.sqlalchemy.org/en/20/orm/relationship_api.html#sqlalchemy.orm.relationship>`_ or `Association Proxy <https://docs.sqlalchemy.org/en/20/orm/extensions/associationproxy.html#module-sqlalchemy.ext.associationproxy>`_.

Relationship
++++++++++++

By default, relationships will not be set. This can be overridden via ``__set_relationships__``.
By default, ``__set_relationships__`` is set to ``False``. If it is ``True``, all fields with the SQLAlchemy `relationship() <relationship()_>`_ will be included in the resulting mock dictionary created by ``build`` method.
nisemenov marked this conversation as resolved.
Show resolved Hide resolved

.. literalinclude:: /examples/library_factories/sqlalchemy_factory/test_example_2.py
:caption: Setting relationships
:language: python

.. note::
In general, foreign keys are not automatically generated by ``.build``. This can be resolved by setting the fields yourself and/or using ``create_sync``/ ``create_async`` so models can be added to a SQLA session so these are set.
In general, ForeignKey fields are automatically generated by ``build`` method because :class:`__set_foreign_keys__ <polyfactory.factories.sqlalchemy_factory.SQLAlchemyFactory.__set_foreign_keys__>` is set to ``True`` by default. But their values can be overwritten by using ``create_sync``/ ``create_async`` methods, so SQLAlchemy ORM creates them.
adhtruong marked this conversation as resolved.
Show resolved Hide resolved

Association Proxy
+++++++++++++++++

By default, ``__set_association_proxy__`` is set to ``False``. If it is ``True``, all SQLAlchemy fields mapped to ORM `Association Proxy <Association Proxy_>`_ class will be included in the resulting mock dictionary created by ``build`` method.

.. literalinclude:: /examples/library_factories/sqlalchemy_factory/test_example_association_proxy.py
:caption: Setting association_proxy
:language: python

.. note::
If ``__set_relationships__`` attribute is set to ``True``, the Polyfactory will create both fields from a particular SQLAlchemy model (association_proxy and its relationship), but eventually a relationship field will be overwritten by using ``create_sync``/ ``create_async`` methods via SQLAlchemy ORM with a proper instance from an Association Proxy relation.


Persistence
------------------------------
-----------

A handler is provided to allow persistence. This can be used by setting ``__session__`` attribute on a factory.

Expand All @@ -38,7 +56,7 @@ Similarly for ``__async_session__`` and ``create_async``.


Adding global overrides
------------------------------
-----------------------

By combining the above and using other settings, a global base factory can be set up for other factories.

Expand All @@ -48,5 +66,5 @@ By combining the above and using other settings, a global base factory can be se


API reference
------------------------------
-------------
Full API docs are available :class:`here <polyfactory.factories.sqlalchemy_factory.SQLAlchemyFactory>`.
21 changes: 21 additions & 0 deletions polyfactory/factories/sqlalchemy_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from sqlalchemy import ARRAY, Column, Numeric, String, inspect, types
from sqlalchemy.dialects import mssql, mysql, postgresql, sqlite
from sqlalchemy.exc import NoInspectionAvailable
from sqlalchemy.ext.associationproxy import AssociationProxy
from sqlalchemy.orm import InstanceState, Mapper
except ImportError as e:
msg = "sqlalchemy is not installed"
Expand Down Expand Up @@ -76,6 +77,8 @@ class SQLAlchemyFactory(Generic[T], BaseFactory[T]):
"""Configuration to consider columns with foreign keys as a field or not."""
__set_relationships__: ClassVar[bool] = False
"""Configuration to consider relationships property as a model field or not."""
__set_association_proxy__: ClassVar[bool] = False
"""Configuration to consider AssociationProxy property as a model field or not."""

__session__: ClassVar[Session | Callable[[], Session] | None] = None
__async_session__: ClassVar[AsyncSession | Callable[[], AsyncSession] | None] = None
Expand All @@ -85,6 +88,7 @@ class SQLAlchemyFactory(Generic[T], BaseFactory[T]):
"__set_primary_key__",
"__set_foreign_keys__",
"__set_relationships__",
"__set_association_proxy__",
)

@classmethod
Expand Down Expand Up @@ -213,6 +217,23 @@ def get_model_fields(cls) -> list[FieldMeta]:
random=cls.__random__,
),
)
if cls.__set_association_proxy__:
for name, attr in table.all_orm_descriptors.items():
if isinstance(attr, AssociationProxy):
target_collection = table.relationships.get(attr.target_collection)
if target_collection:
target_class = target_collection.entity.class_
target_attr = getattr(target_class, attr.value_attr)
if target_attr:
class_ = target_attr.entity.class_
annotation = class_ if not target_collection.uselist else List[class_] # type: ignore[valid-type]
fields_meta.append(
FieldMeta.from_type(
name=name,
annotation=annotation,
random=cls.__random__,
)
)

return fields_meta

Expand Down
83 changes: 83 additions & 0 deletions tests/sqlalchemy_factory/test_association_proxy_v1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
from typing import Optional

import pytest
from sqlalchemy import Column, ForeignKey, Integer, String, __version__
from sqlalchemy.ext.associationproxy import association_proxy
from sqlalchemy.orm import relationship
from sqlalchemy.orm.decl_api import DeclarativeMeta, registry

from polyfactory.factories.sqlalchemy_factory import SQLAlchemyFactory

if __version__.startswith("2"):
adhtruong marked this conversation as resolved.
Show resolved Hide resolved
pytest.skip(allow_module_level=True)

_registry = registry()


class Base(metaclass=DeclarativeMeta):
__abstract__ = True
__allow_unmapped__ = True

registry = _registry
metadata = _registry.metadata


class User(Base):
__tablename__ = "users"

id = Column(Integer, primary_key=True)
name = Column(String)

user_keyword_associations = relationship(
"UserKeywordAssociation",
back_populates="user",
)
keywords = association_proxy(
"user_keyword_associations", "keyword", creator=lambda keyword_obj: UserKeywordAssociation(keyword=keyword_obj)
)


class UserKeywordAssociation(Base):
__tablename__ = "user_keyword"

user_id = Column(Integer, ForeignKey("users.id"), primary_key=True)
keyword_id = Column(Integer, ForeignKey("keywords.id"), primary_key=True)

user = relationship(User, back_populates="user_keyword_associations")
keyword = relationship("Keyword")

# for prevent mypy error: Unexpected keyword argument "keyword" for "UserKeywordAssociation" [call-arg]
def __init__(self, keyword: Optional["Keyword"] = None):
self.keyword = keyword


class Keyword(Base):
__tablename__ = "keywords"

id = Column(Integer, primary_key=True)
keyword = Column(String)


def test_association_proxy() -> None:
class UserFactory(SQLAlchemyFactory[User]):
__set_association_proxy__ = True

user = UserFactory.build()
assert isinstance(user.keywords[0], Keyword)
assert isinstance(user.user_keyword_associations[0], UserKeywordAssociation)


async def test_complex_association_proxy() -> None:
class KeywordFactory(SQLAlchemyFactory[Keyword]): ...

class ComplexUserFactory(SQLAlchemyFactory[User]):
__set_association_proxy__ = True

keywords = KeywordFactory.batch(3)

user = ComplexUserFactory.build()
assert isinstance(user, User)
assert isinstance(user.keywords[0], Keyword)
assert len(user.keywords) == 3
assert isinstance(user.user_keyword_associations[0], UserKeywordAssociation)
assert len(user.user_keyword_associations) == 3
71 changes: 71 additions & 0 deletions tests/sqlalchemy_factory/test_association_proxy_v2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
from typing import List

import pytest
from sqlalchemy import ForeignKey, __version__, orm
from sqlalchemy.ext.associationproxy import AssociationProxy, association_proxy
from sqlalchemy.orm import Mapped, relationship

from polyfactory.factories.sqlalchemy_factory import SQLAlchemyFactory

if __version__.startswith("1"):
pytest.skip(allow_module_level=True)


class Base(orm.DeclarativeBase):
pass


class User(Base):
__tablename__ = "users"

id: Mapped[int] = orm.mapped_column(primary_key=True)
name: Mapped[str]

user_keyword_associations: Mapped[List["UserKeywordAssociation"]] = relationship(
back_populates="user",
)
keywords: AssociationProxy[List["Keyword"]] = association_proxy(
"user_keyword_associations",
"keyword",
creator=lambda keyword_obj: UserKeywordAssociation(keyword=keyword_obj),
)


class UserKeywordAssociation(Base):
__tablename__ = "user_keyword"
user_id: Mapped[int] = orm.mapped_column(ForeignKey("users.id"), primary_key=True)
keyword_id: Mapped[int] = orm.mapped_column(ForeignKey("keywords.id"), primary_key=True)

user: Mapped[User] = relationship(back_populates="user_keyword_associations")
keyword: Mapped["Keyword"] = relationship()


class Keyword(Base):
__tablename__ = "keywords"
id: Mapped[int] = orm.mapped_column(primary_key=True)
keyword: Mapped[str]


def test_association_proxy() -> None:
class UserFactory(SQLAlchemyFactory[User]):
__set_association_proxy__ = True

user = UserFactory.build()
assert isinstance(user.keywords[0], Keyword)
assert isinstance(user.user_keyword_associations[0], UserKeywordAssociation)


async def test_complex_association_proxy() -> None:
class KeywordFactory(SQLAlchemyFactory[Keyword]): ...

class ComplexUserFactory(SQLAlchemyFactory[User]):
__set_association_proxy__ = True

keywords = KeywordFactory.batch(3)

user = ComplexUserFactory.build()
assert isinstance(user, User)
assert isinstance(user.keywords[0], Keyword)
assert len(user.keywords) == 3
assert isinstance(user.user_keyword_associations[0], UserKeywordAssociation)
assert len(user.user_keyword_associations) == 3
Loading