-
Notifications
You must be signed in to change notification settings - Fork 14
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor using context managers classes for resources (#115)
- Loading branch information
Showing
12 changed files
with
258 additions
and
378 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
import datetime | ||
import logging | ||
import types | ||
import typing | ||
|
||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
async def create_async_resource() -> typing.AsyncIterator[datetime.datetime]: | ||
logger.debug("Async resource initiated") | ||
try: | ||
yield datetime.datetime.now(tz=datetime.timezone.utc) | ||
finally: | ||
logger.debug("Async resource destructed") | ||
|
||
|
||
def create_sync_resource() -> typing.Iterator[datetime.datetime]: | ||
logger.debug("Resource initiated") | ||
try: | ||
yield datetime.datetime.now(tz=datetime.timezone.utc) | ||
finally: | ||
logger.debug("Resource destructed") | ||
|
||
|
||
class ContextManagerResource(typing.ContextManager[datetime.datetime]): | ||
def __enter__(self) -> datetime.datetime: | ||
logger.debug("Resource initiated") | ||
return datetime.datetime.now(tz=datetime.timezone.utc) | ||
|
||
def __exit__( | ||
self, | ||
exc_type: type[BaseException] | None, | ||
exc_value: BaseException | None, | ||
traceback: types.TracebackType | None, | ||
) -> None: | ||
logger.debug("Resource destructed") | ||
|
||
|
||
class AsyncContextManagerResource(typing.AsyncContextManager[datetime.datetime]): | ||
async def __aenter__(self) -> datetime.datetime: | ||
logger.debug("Async resource initiated") | ||
return datetime.datetime.now(tz=datetime.timezone.utc) | ||
|
||
async def __aexit__( | ||
self, | ||
exc_type: type[BaseException] | None, | ||
exc_value: BaseException | None, | ||
traceback: types.TracebackType | None, | ||
) -> None: | ||
logger.debug("Async resource destructed") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,174 +1,109 @@ | ||
import asyncio | ||
import typing | ||
from contextlib import asynccontextmanager, contextmanager | ||
|
||
import pytest | ||
|
||
from that_depends import providers | ||
from tests.creators import ( | ||
AsyncContextManagerResource, | ||
ContextManagerResource, | ||
create_async_resource, | ||
create_sync_resource, | ||
) | ||
from that_depends import BaseContainer, providers | ||
|
||
|
||
_VALUE = 42 | ||
class DIContainer(BaseContainer): | ||
async_resource = providers.Resource(create_async_resource) | ||
sync_resource = providers.Resource(create_sync_resource) | ||
async_resource_from_class = providers.Resource(AsyncContextManagerResource) | ||
sync_resource_from_class = providers.Resource(ContextManagerResource) | ||
|
||
|
||
async def _switch_routines() -> None: | ||
await asyncio.sleep(0.0) | ||
@pytest.fixture(autouse=True) | ||
async def _tear_down() -> typing.AsyncIterator[None]: | ||
try: | ||
yield | ||
finally: | ||
await DIContainer.tear_down() | ||
|
||
|
||
class SimpleCM(typing.AsyncContextManager[int]): | ||
async def __aenter__(self) -> int: | ||
await _switch_routines() | ||
return _VALUE | ||
async def test_async_resource() -> None: | ||
async_resource1 = await DIContainer.async_resource.async_resolve() | ||
async_resource2 = DIContainer.async_resource.sync_resolve() | ||
assert async_resource1 is async_resource2 | ||
|
||
async def __aexit__(self, exc_type: object, exc_value: object, traceback: object, /) -> bool | None: | ||
await _switch_routines() | ||
return None | ||
|
||
async def test_async_resource_from_class() -> None: | ||
async_resource1 = await DIContainer.async_resource_from_class.async_resolve() | ||
async_resource2 = DIContainer.async_resource_from_class.sync_resolve() | ||
assert async_resource1 is async_resource2 | ||
|
||
class SimpleCMSync(typing.ContextManager[int]): | ||
def __enter__(self) -> int: | ||
return _VALUE | ||
|
||
def __exit__(self, exc_type: object, exc_value: object, traceback: object, /) -> bool | None: | ||
return None | ||
async def test_sync_resource() -> None: | ||
sync_resource1 = await DIContainer.sync_resource.async_resolve() | ||
sync_resource2 = await DIContainer.sync_resource.async_resolve() | ||
assert sync_resource1 is sync_resource2 | ||
|
||
|
||
@asynccontextmanager | ||
async def do_stuff_cm() -> typing.AsyncIterator[int]: | ||
await _switch_routines() | ||
yield _VALUE | ||
await _switch_routines() | ||
async def test_sync_resource_from_class() -> None: | ||
sync_resource1 = await DIContainer.sync_resource_from_class.async_resolve() | ||
sync_resource2 = await DIContainer.sync_resource_from_class.async_resolve() | ||
assert sync_resource1 is sync_resource2 | ||
|
||
|
||
@contextmanager | ||
def do_stuff_cm_sync() -> typing.Iterator[int]: | ||
yield _VALUE | ||
async def test_async_resource_overridden() -> None: | ||
async_resource1 = await DIContainer.sync_resource.async_resolve() | ||
|
||
DIContainer.sync_resource.override("override") | ||
|
||
async def do_stuff_it() -> typing.AsyncIterator[int]: | ||
await _switch_routines() | ||
yield _VALUE | ||
await _switch_routines() | ||
async_resource2 = DIContainer.sync_resource.sync_resolve() | ||
async_resource3 = await DIContainer.sync_resource.async_resolve() | ||
|
||
DIContainer.sync_resource.reset_override() | ||
|
||
def do_stuff_it_sync() -> typing.Iterator[int]: | ||
yield _VALUE | ||
async_resource4 = DIContainer.sync_resource.sync_resolve() | ||
|
||
assert async_resource2 is not async_resource1 | ||
assert async_resource2 is async_resource3 | ||
assert async_resource4 is async_resource1 | ||
|
||
@pytest.mark.parametrize( | ||
"resource", | ||
[ | ||
pytest.param(providers.Resource(SimpleCM()), id="cm_simple"), | ||
pytest.param(providers.Resource(SimpleCMSync()), id="cm_simple_sync"), | ||
pytest.param(providers.Resource(do_stuff_cm), id="cm_factory"), | ||
pytest.param(providers.Resource(do_stuff_cm_sync), id="cm_sync_factory"), | ||
pytest.param(providers.Resource(do_stuff_it), id="cm_iterator"), | ||
pytest.param(providers.Resource(do_stuff_it_sync), id="cm_sync_iterator"), | ||
], | ||
) | ||
async def test_resource_async_resolve_works(resource: providers.Resource[int]) -> None: | ||
instance = await resource.async_resolve() | ||
assert instance == _VALUE | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"resource", | ||
[ | ||
pytest.param(providers.Resource(SimpleCMSync()), id="cm_simple_sync"), | ||
pytest.param(providers.Resource(do_stuff_cm_sync), id="cm_sync_factory"), | ||
pytest.param(providers.Resource(do_stuff_it_sync), id="cm_sync_iterator"), | ||
], | ||
) | ||
def test_resource_sync_resolve_works(resource: providers.Resource[int]) -> None: | ||
instance = resource.sync_resolve() | ||
assert instance == _VALUE | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"resource", | ||
[ | ||
pytest.param(providers.Resource(SimpleCM()), id="cm_simple"), | ||
pytest.param(providers.Resource(do_stuff_cm), id="cm_factory"), | ||
pytest.param(providers.Resource(do_stuff_it), id="cm_iterator"), | ||
], | ||
) | ||
def test_resource_sync_resolve_is_not_possible_for_async_context_manager(resource: providers.Resource[int]) -> None: | ||
with pytest.raises(TypeError, match="A ContextManager type was expected in synchronous resolve"): | ||
resource.sync_resolve() | ||
|
||
|
||
async def do_invalid_creator_stuff_simple_coro_func() -> None: | ||
pass | ||
|
||
|
||
async def do_invalid_creator_stuff_inner_func() -> typing.Callable[[], typing.Awaitable[int]]: | ||
async def do_stuff_inner() -> int: | ||
return 42 | ||
|
||
return do_stuff_inner | ||
|
||
|
||
# NOTE: this is a special case for resource creator normalizer, it has to be invalid, because return type annotation is | ||
# not specified here. | ||
@asynccontextmanager | ||
async def do_invalid_creator_stuff_cm_without_annotation(): # type: ignore[no-untyped-def] # noqa: ANN201 | ||
await _switch_routines() | ||
yield _VALUE | ||
await _switch_routines() | ||
|
||
|
||
@pytest.mark.parametrize( | ||
("creator", "args", "kwargs", "error_msg"), | ||
[ | ||
pytest.param( | ||
42, | ||
(), | ||
{}, | ||
"Creator is not of a valid type", | ||
id="int", | ||
), | ||
pytest.param( | ||
do_invalid_creator_stuff_simple_coro_func, | ||
(), | ||
{}, | ||
"Creator is not of a valid type", | ||
id="simple coroutine func", | ||
), | ||
pytest.param( | ||
do_invalid_creator_stuff_inner_func, | ||
(), | ||
{}, | ||
"Creator is not of a valid type", | ||
id="inner coroutine func", | ||
), | ||
pytest.param( | ||
do_invalid_creator_stuff_cm_without_annotation, | ||
(), | ||
{}, | ||
"Creator is not of a valid type", | ||
id="cm without annotation", | ||
), | ||
pytest.param( | ||
SimpleCM(), | ||
(), | ||
{"param": "not acceptable for CM"}, | ||
"AsyncContextManager does not accept any arguments", | ||
id="CM with param", | ||
), | ||
pytest.param( | ||
SimpleCMSync(), | ||
(), | ||
{"param": "not acceptable for CM"}, | ||
"ContextManager does not accept any arguments", | ||
id="sync CM with param", | ||
), | ||
], | ||
) | ||
async def test_resource_init_raises_type_error_on_invalid_arguments( | ||
# NOTE: testing inappropriate types here, so using Any in annotations. | ||
creator: typing.Any, # noqa: ANN401 | ||
args: typing.Sequence[object], | ||
kwargs: typing.Mapping[str, object], | ||
error_msg: str, | ||
) -> None: | ||
with pytest.raises(TypeError, match=error_msg): | ||
providers.Resource(creator, *args, **kwargs) | ||
|
||
async def test_sync_resource_overridden() -> None: | ||
sync_resource1 = await DIContainer.sync_resource.async_resolve() | ||
|
||
DIContainer.sync_resource.override("override") | ||
|
||
sync_resource2 = DIContainer.sync_resource.sync_resolve() | ||
sync_resource3 = await DIContainer.sync_resource.async_resolve() | ||
|
||
DIContainer.sync_resource.reset_override() | ||
|
||
sync_resource4 = DIContainer.sync_resource.sync_resolve() | ||
|
||
assert sync_resource2 is not sync_resource1 | ||
assert sync_resource2 is sync_resource3 | ||
assert sync_resource4 is sync_resource1 | ||
|
||
|
||
async def test_async_resource_race_condition() -> None: | ||
calls: int = 0 | ||
|
||
async def create_resource() -> typing.AsyncIterator[str]: | ||
nonlocal calls | ||
calls += 1 | ||
await asyncio.sleep(0) | ||
yield "" | ||
|
||
resource = providers.Resource(create_resource) | ||
|
||
async def resolve_resource() -> str: | ||
return await resource.async_resolve() | ||
|
||
await asyncio.gather(resolve_resource(), resolve_resource()) | ||
|
||
assert calls == 1 | ||
|
||
|
||
async def test_resource_unsupported_creator() -> None: | ||
with pytest.raises(TypeError, match="Unsupported resource type"): | ||
providers.Resource(None) # type: ignore[arg-type] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
Oops, something went wrong.