Skip to content

Commit

Permalink
fix(sqla_factory): added an async context manager in SQLAASyncPersist…
Browse files Browse the repository at this point in the history
…ence (#630)
  • Loading branch information
nisemenov authored Jan 15, 2025
1 parent 135bbc0 commit 137bfb9
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 12 deletions.
16 changes: 9 additions & 7 deletions polyfactory/factories/sqlalchemy_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,16 +52,18 @@ def __init__(self, session: AsyncSession) -> None:
self.session = session

async def save(self, data: T) -> T:
self.session.add(data)
await self.session.commit()
await self.session.refresh(data)
async with self.session as session:
session.add(data)
await session.commit()
await session.refresh(data)
return data

async def save_many(self, data: list[T]) -> list[T]:
self.session.add_all(data)
await self.session.commit()
for batch_item in data:
await self.session.refresh(batch_item)
async with self.session as session:
session.add_all(data)
await session.commit()
for batch_item in data:
await session.refresh(batch_item)
return data


Expand Down
14 changes: 9 additions & 5 deletions tests/sqlalchemy_factory/test_sqlalchemy_factory_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
func,
inspect,
orm,
select,
text,
types,
)
Expand Down Expand Up @@ -343,13 +344,15 @@ class Factory(SQLAlchemyFactory[AsyncModel]):
__async_session__ = session_config(session)
__model__ = AsyncModel

result = await Factory.create_async()
assert inspect(result).persistent # type: ignore[union-attr]
instance = await Factory.create_async()
result = await session.scalar(select(AsyncModel).where(AsyncModel.id == instance.id))
assert result

batch_result = await Factory.create_batch_async(size=2)
assert len(batch_result) == 2
for batch_item in batch_result:
assert inspect(batch_item).persistent # type: ignore[union-attr]
result = await session.scalar(select(AsyncModel).where(AsyncModel.id == batch_item.id))
assert result


@pytest.mark.parametrize(
Expand Down Expand Up @@ -392,8 +395,9 @@ class Factory(SQLAlchemyFactory[AsyncRefreshModel]):
test_int = Ignore()
test_bool = Ignore()

result = await Factory.create_async()
assert inspect(result).persistent # type: ignore[union-attr]
instance = await Factory.create_async()
result = await session.scalar(select(AsyncRefreshModel).where(AsyncRefreshModel.id == instance.id))
assert result
assert result.test_datetime is not None
assert isinstance(result.test_datetime, datetime)
assert result.test_str == "test_str"
Expand Down

0 comments on commit 137bfb9

Please sign in to comment.