Skip to content

Commit

Permalink
refactor using context managers classes for resources (#115)
Browse files Browse the repository at this point in the history
  • Loading branch information
lesnik512 authored Nov 9, 2024
1 parent 69bb400 commit 3657c23
Show file tree
Hide file tree
Showing 12 changed files with 258 additions and 378 deletions.
51 changes: 51 additions & 0 deletions tests/creators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import datetime
import logging
import types
import typing


logger = logging.getLogger(__name__)


async def create_async_resource() -> typing.AsyncIterator[datetime.datetime]:
logger.debug("Async resource initiated")
try:
yield datetime.datetime.now(tz=datetime.timezone.utc)
finally:
logger.debug("Async resource destructed")


def create_sync_resource() -> typing.Iterator[datetime.datetime]:
logger.debug("Resource initiated")
try:
yield datetime.datetime.now(tz=datetime.timezone.utc)
finally:
logger.debug("Resource destructed")


class ContextManagerResource(typing.ContextManager[datetime.datetime]):
def __enter__(self) -> datetime.datetime:
logger.debug("Resource initiated")
return datetime.datetime.now(tz=datetime.timezone.utc)

def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: types.TracebackType | None,
) -> None:
logger.debug("Resource destructed")


class AsyncContextManagerResource(typing.AsyncContextManager[datetime.datetime]):
async def __aenter__(self) -> datetime.datetime:
logger.debug("Async resource initiated")
return datetime.datetime.now(tz=datetime.timezone.utc)

async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: types.TracebackType | None,
) -> None:
logger.debug("Async resource destructed")
2 changes: 1 addition & 1 deletion tests/providers/test_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ async def test_list_provider() -> None:


def test_list_failed_sync_resolve() -> None:
with pytest.raises(TypeError, match="A ContextManager type was expected"):
with pytest.raises(RuntimeError, match="AsyncResource cannot be resolved synchronously"):
DIContainer.sequence.sync_resolve()


Expand Down
4 changes: 2 additions & 2 deletions tests/providers/test_context_resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
import pytest

from that_depends import BaseContainer, Provide, fetch_context_item, inject, providers
from that_depends.entities.resource_context import ResourceContext
from that_depends.providers import container_context
from that_depends.providers.base import ResourceContext


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -133,7 +133,7 @@ async def test_context_resources_init_and_tear_down() -> None:


def test_context_resources_wrong_providers_init() -> None:
with pytest.raises(TypeError, match="Creator is not of a valid type"):
with pytest.raises(TypeError, match="Unsupported resource type"):
providers.ContextResource(lambda: None) # type: ignore[arg-type,return-value]


Expand Down
2 changes: 1 addition & 1 deletion tests/providers/test_inject_factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ async def test_async_provider() -> None:

async def test_sync_provider() -> None:
injected_factories = await DIContainer.injected_factories()
with pytest.raises(TypeError, match="A ContextManager type was expected"):
with pytest.raises(RuntimeError, match="AsyncResource cannot be resolved synchronously"):
injected_factories.sync_factory()

await DIContainer.init_resources()
Expand Down
4 changes: 2 additions & 2 deletions tests/providers/test_main_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,12 @@ def test_failed_sync_resolve() -> None:
with pytest.raises(RuntimeError, match="AsyncFactory cannot be resolved synchronously"):
DIContainer.async_factory.sync_resolve()

with pytest.raises(TypeError, match="A ContextManager type was expected"):
with pytest.raises(RuntimeError, match="AsyncResource cannot be resolved synchronously"):
DIContainer.async_resource.sync_resolve()


def test_wrong_providers_init() -> None:
with pytest.raises(TypeError, match="Creator is not of a valid type"):
with pytest.raises(TypeError, match="Unsupported resource type"):
providers.Resource(lambda: None) # type: ignore[arg-type,return-value]


Expand Down
233 changes: 84 additions & 149 deletions tests/providers/test_resources.py
Original file line number Diff line number Diff line change
@@ -1,174 +1,109 @@
import asyncio
import typing
from contextlib import asynccontextmanager, contextmanager

import pytest

from that_depends import providers
from tests.creators import (
AsyncContextManagerResource,
ContextManagerResource,
create_async_resource,
create_sync_resource,
)
from that_depends import BaseContainer, providers


_VALUE = 42
class DIContainer(BaseContainer):
async_resource = providers.Resource(create_async_resource)
sync_resource = providers.Resource(create_sync_resource)
async_resource_from_class = providers.Resource(AsyncContextManagerResource)
sync_resource_from_class = providers.Resource(ContextManagerResource)


async def _switch_routines() -> None:
await asyncio.sleep(0.0)
@pytest.fixture(autouse=True)
async def _tear_down() -> typing.AsyncIterator[None]:
try:
yield
finally:
await DIContainer.tear_down()


class SimpleCM(typing.AsyncContextManager[int]):
async def __aenter__(self) -> int:
await _switch_routines()
return _VALUE
async def test_async_resource() -> None:
async_resource1 = await DIContainer.async_resource.async_resolve()
async_resource2 = DIContainer.async_resource.sync_resolve()
assert async_resource1 is async_resource2

async def __aexit__(self, exc_type: object, exc_value: object, traceback: object, /) -> bool | None:
await _switch_routines()
return None

async def test_async_resource_from_class() -> None:
async_resource1 = await DIContainer.async_resource_from_class.async_resolve()
async_resource2 = DIContainer.async_resource_from_class.sync_resolve()
assert async_resource1 is async_resource2

class SimpleCMSync(typing.ContextManager[int]):
def __enter__(self) -> int:
return _VALUE

def __exit__(self, exc_type: object, exc_value: object, traceback: object, /) -> bool | None:
return None
async def test_sync_resource() -> None:
sync_resource1 = await DIContainer.sync_resource.async_resolve()
sync_resource2 = await DIContainer.sync_resource.async_resolve()
assert sync_resource1 is sync_resource2


@asynccontextmanager
async def do_stuff_cm() -> typing.AsyncIterator[int]:
await _switch_routines()
yield _VALUE
await _switch_routines()
async def test_sync_resource_from_class() -> None:
sync_resource1 = await DIContainer.sync_resource_from_class.async_resolve()
sync_resource2 = await DIContainer.sync_resource_from_class.async_resolve()
assert sync_resource1 is sync_resource2


@contextmanager
def do_stuff_cm_sync() -> typing.Iterator[int]:
yield _VALUE
async def test_async_resource_overridden() -> None:
async_resource1 = await DIContainer.sync_resource.async_resolve()

DIContainer.sync_resource.override("override")

async def do_stuff_it() -> typing.AsyncIterator[int]:
await _switch_routines()
yield _VALUE
await _switch_routines()
async_resource2 = DIContainer.sync_resource.sync_resolve()
async_resource3 = await DIContainer.sync_resource.async_resolve()

DIContainer.sync_resource.reset_override()

def do_stuff_it_sync() -> typing.Iterator[int]:
yield _VALUE
async_resource4 = DIContainer.sync_resource.sync_resolve()

assert async_resource2 is not async_resource1
assert async_resource2 is async_resource3
assert async_resource4 is async_resource1

@pytest.mark.parametrize(
"resource",
[
pytest.param(providers.Resource(SimpleCM()), id="cm_simple"),
pytest.param(providers.Resource(SimpleCMSync()), id="cm_simple_sync"),
pytest.param(providers.Resource(do_stuff_cm), id="cm_factory"),
pytest.param(providers.Resource(do_stuff_cm_sync), id="cm_sync_factory"),
pytest.param(providers.Resource(do_stuff_it), id="cm_iterator"),
pytest.param(providers.Resource(do_stuff_it_sync), id="cm_sync_iterator"),
],
)
async def test_resource_async_resolve_works(resource: providers.Resource[int]) -> None:
instance = await resource.async_resolve()
assert instance == _VALUE


@pytest.mark.parametrize(
"resource",
[
pytest.param(providers.Resource(SimpleCMSync()), id="cm_simple_sync"),
pytest.param(providers.Resource(do_stuff_cm_sync), id="cm_sync_factory"),
pytest.param(providers.Resource(do_stuff_it_sync), id="cm_sync_iterator"),
],
)
def test_resource_sync_resolve_works(resource: providers.Resource[int]) -> None:
instance = resource.sync_resolve()
assert instance == _VALUE


@pytest.mark.parametrize(
"resource",
[
pytest.param(providers.Resource(SimpleCM()), id="cm_simple"),
pytest.param(providers.Resource(do_stuff_cm), id="cm_factory"),
pytest.param(providers.Resource(do_stuff_it), id="cm_iterator"),
],
)
def test_resource_sync_resolve_is_not_possible_for_async_context_manager(resource: providers.Resource[int]) -> None:
with pytest.raises(TypeError, match="A ContextManager type was expected in synchronous resolve"):
resource.sync_resolve()


async def do_invalid_creator_stuff_simple_coro_func() -> None:
pass


async def do_invalid_creator_stuff_inner_func() -> typing.Callable[[], typing.Awaitable[int]]:
async def do_stuff_inner() -> int:
return 42

return do_stuff_inner


# NOTE: this is a special case for resource creator normalizer, it has to be invalid, because return type annotation is
# not specified here.
@asynccontextmanager
async def do_invalid_creator_stuff_cm_without_annotation(): # type: ignore[no-untyped-def] # noqa: ANN201
await _switch_routines()
yield _VALUE
await _switch_routines()


@pytest.mark.parametrize(
("creator", "args", "kwargs", "error_msg"),
[
pytest.param(
42,
(),
{},
"Creator is not of a valid type",
id="int",
),
pytest.param(
do_invalid_creator_stuff_simple_coro_func,
(),
{},
"Creator is not of a valid type",
id="simple coroutine func",
),
pytest.param(
do_invalid_creator_stuff_inner_func,
(),
{},
"Creator is not of a valid type",
id="inner coroutine func",
),
pytest.param(
do_invalid_creator_stuff_cm_without_annotation,
(),
{},
"Creator is not of a valid type",
id="cm without annotation",
),
pytest.param(
SimpleCM(),
(),
{"param": "not acceptable for CM"},
"AsyncContextManager does not accept any arguments",
id="CM with param",
),
pytest.param(
SimpleCMSync(),
(),
{"param": "not acceptable for CM"},
"ContextManager does not accept any arguments",
id="sync CM with param",
),
],
)
async def test_resource_init_raises_type_error_on_invalid_arguments(
# NOTE: testing inappropriate types here, so using Any in annotations.
creator: typing.Any, # noqa: ANN401
args: typing.Sequence[object],
kwargs: typing.Mapping[str, object],
error_msg: str,
) -> None:
with pytest.raises(TypeError, match=error_msg):
providers.Resource(creator, *args, **kwargs)

async def test_sync_resource_overridden() -> None:
sync_resource1 = await DIContainer.sync_resource.async_resolve()

DIContainer.sync_resource.override("override")

sync_resource2 = DIContainer.sync_resource.sync_resolve()
sync_resource3 = await DIContainer.sync_resource.async_resolve()

DIContainer.sync_resource.reset_override()

sync_resource4 = DIContainer.sync_resource.sync_resolve()

assert sync_resource2 is not sync_resource1
assert sync_resource2 is sync_resource3
assert sync_resource4 is sync_resource1


async def test_async_resource_race_condition() -> None:
calls: int = 0

async def create_resource() -> typing.AsyncIterator[str]:
nonlocal calls
calls += 1
await asyncio.sleep(0)
yield ""

resource = providers.Resource(create_resource)

async def resolve_resource() -> str:
return await resource.async_resolve()

await asyncio.gather(resolve_resource(), resolve_resource())

assert calls == 1


async def test_resource_unsupported_creator() -> None:
with pytest.raises(TypeError, match="Unsupported resource type"):
providers.Resource(None) # type: ignore[arg-type]
4 changes: 2 additions & 2 deletions tests/test_deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ async def test_init_async_resources() -> None:


def test_wrong_deprecated_providers_init() -> None:
with pytest.raises(TypeError, match="Creator is not of a valid type"):
with pytest.raises(TypeError, match="Unsupported resource type"):
providers.AsyncContextResource(lambda: None) # type: ignore[arg-type,return-value]

with pytest.raises(TypeError, match="Creator is not of a valid type"):
with pytest.raises(TypeError, match="Unsupported resource type"):
providers.AsyncResource(lambda: None) # type: ignore[arg-type,return-value]
Empty file.
Loading

0 comments on commit 3657c23

Please sign in to comment.