Skip to content

Commit

Permalink
Documentation fix & discussion changes. (#143)
Browse files Browse the repository at this point in the history
* Small documentation fix.

* Added pre-commit configuration to just script.

* Migrated existing documetation to google format.

* Added docstrings and method overrides

* Fixed entering async context with multiple async threads

* Refactored ContextResources to not act as a ContextManager

* Added additional documentation and fixed override imports.

* Added threading test for context resources.

* Improved docstrings for providers.
  • Loading branch information
alexanderlazarev0 authored Jan 17, 2025
1 parent b8b6a83 commit 9781e3d
Show file tree
Hide file tree
Showing 21 changed files with 778 additions and 61 deletions.
1 change: 1 addition & 0 deletions Justfile
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ default: install lint test
install:
uv lock --upgrade
uv sync --only-dev --frozen
uv run pre-commit install --overwrite

lint:
uv run ruff format
Expand Down
1 change: 0 additions & 1 deletion docs/migration/v2.md
Original file line number Diff line number Diff line change
Expand Up @@ -147,4 +147,3 @@ To resolve such issues in `2.*`, consider the following suggestions:
## Further Help

If you continue to experience issues during migration, consider creating a [discussion](https://github.com/modern-python/that-depends/discussions) or opening an [issue](https://github.com/modern-python/that-depends/issues).
```
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ extend-exclude = [
[tool.ruff.lint]
select = ["ALL"]
ignore = [
"D1", # allow missing docstrings
"D100", # ignore missing module docstrings.
"D105", # ignore missing docstrings in magic methods.
"S101", # allow asserts
"TCH", # ignore flake8-type-checking
"FBT", # allow boolean args
Expand All @@ -76,6 +77,7 @@ ignore = [
]
isort.lines-after-imports = 2
isort.no-lines-before = ["standard-library", "local-folder"]
per-file-ignores = { "tests/*"= ["D1", "SLF001"]}

[tool.pytest.ini_options]
addopts = "--cov=. --cov-report term-missing"
Expand Down
7 changes: 3 additions & 4 deletions tests/integrations/fastapi/test_fastapi_di.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,11 @@
def fastapi_app(request: pytest.FixtureRequest) -> fastapi.FastAPI:
app = fastapi.FastAPI()
if request.param:
app.add_middleware(DIContextMiddleware, request.param, global_context=_GLOBAL_CONTEXT)
else:
app.add_middleware(
DIContextMiddleware,
global_context=_GLOBAL_CONTEXT,
DIContextMiddleware, request.param, global_context=_GLOBAL_CONTEXT, reset_all_containers=True
)
else:
app.add_middleware(DIContextMiddleware, global_context=_GLOBAL_CONTEXT, reset_all_containers=True)

@app.get("/")
async def read_root(
Expand Down
56 changes: 47 additions & 9 deletions tests/providers/test_context_resources.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import asyncio
import datetime
import logging
import threading
import time
import typing
import uuid
from contextlib import AsyncExitStack, ExitStack
Expand Down Expand Up @@ -208,7 +210,7 @@ async def test_async_injection_when_resetting_resource_specific_context(
@async_context_resource.context
@inject
async def _async_injected(val: str = Provide[async_context_resource]) -> str:
assert isinstance(async_context_resource._fetch_context().context_stack, AsyncExitStack) # noqa: SLF001
assert isinstance(async_context_resource._fetch_context().context_stack, AsyncExitStack)
return val

async_result = await _async_injected()
Expand All @@ -224,13 +226,13 @@ async def test_sync_injection_when_resetting_resource_specific_context(
@sync_context_resource.context
@inject
async def _async_injected(val: str = Provide[sync_context_resource]) -> str:
assert isinstance(sync_context_resource._fetch_context().context_stack, ExitStack) # noqa: SLF001
assert isinstance(sync_context_resource._fetch_context().context_stack, ExitStack)
return val

@sync_context_resource.context
@inject
def _sync_injected(val: str = Provide[sync_context_resource]) -> str:
assert isinstance(sync_context_resource._fetch_context().context_stack, ExitStack) # noqa: SLF001
assert isinstance(sync_context_resource._fetch_context().context_stack, ExitStack)
return val

async_result = await _async_injected()
Expand Down Expand Up @@ -290,7 +292,7 @@ async def test_async_injection_when_explicitly_resetting_resource_specific_conte
@async_context_resource.async_context()
@inject
async def _async_injected(val: str = Provide[async_context_resource]) -> str:
assert isinstance(async_context_resource._fetch_context().context_stack, AsyncExitStack) # noqa: SLF001
assert isinstance(async_context_resource._fetch_context().context_stack, AsyncExitStack)
return val

async_result = await _async_injected()
Expand All @@ -306,13 +308,13 @@ async def test_sync_injection_when_explicitly_resetting_resource_specific_contex
@sync_context_resource.async_context()
@inject
async def _async_injected(val: str = Provide[sync_context_resource]) -> str:
assert isinstance(sync_context_resource._fetch_context().context_stack, ExitStack) # noqa: SLF001
assert isinstance(sync_context_resource._fetch_context().context_stack, ExitStack)
return val

@sync_context_resource.sync_context()
@inject
def _sync_injected(val: str = Provide[sync_context_resource]) -> str:
assert isinstance(sync_context_resource._fetch_context().context_stack, ExitStack) # noqa: SLF001
assert isinstance(sync_context_resource._fetch_context().context_stack, ExitStack)
return val

async_result = await _async_injected()
Expand Down Expand Up @@ -570,19 +572,19 @@ def test_enter_sync_context_for_async_resource_should_throw(
async_context_resource: providers.ContextResource[str],
) -> None:
with pytest.raises(RuntimeError):
async_context_resource.__enter__()
async_context_resource._enter_sync_context()


def test_exit_sync_context_before_enter_should_throw(sync_context_resource: providers.ContextResource[str]) -> None:
with pytest.raises(RuntimeError):
sync_context_resource.__exit__(None, None, None)
sync_context_resource._exit_sync_context()


async def test_exit_async_context_before_enter_should_throw(
async_context_resource: providers.ContextResource[str],
) -> None:
with pytest.raises(RuntimeError):
await async_context_resource.__aexit__(None, None, None)
await async_context_resource._exit_async_context()


def test_enter_sync_context_from_async_resource_should_throw(
Expand All @@ -608,3 +610,39 @@ async def test_preserve_globals_and_initial_context() -> None:
assert fetch_context_item(key) == item
for key in new_context:
assert fetch_context_item(key) is None


async def test_async_context_switching_with_asyncio() -> None:
async def slow_async_creator() -> typing.AsyncIterator[str]:
await asyncio.sleep(0.1)
yield str(uuid.uuid4())

class MyContainer(BaseContainer):
slow_provider = providers.ContextResource(slow_async_creator)

async def _injected() -> str:
async with MyContainer.slow_provider.async_context():
return await MyContainer.slow_provider.async_resolve()

await asyncio.gather(*[_injected() for _ in range(10)])


def test_sync_context_switching_with_threads() -> None:
def slow_sync_creator() -> typing.Iterator[str]:
time.sleep(0.1)
yield str(uuid.uuid4())

class MyContainer(BaseContainer):
slow_provider = providers.ContextResource(slow_sync_creator)

def _injected() -> str:
with MyContainer.slow_provider.sync_context():
return MyContainer.slow_provider.sync_resolve()

threads = [threading.Thread(target=_injected) for _ in range(10)]

for thread in threads:
thread.start()

for thread in threads:
thread.join()
8 changes: 4 additions & 4 deletions tests/test_multiple_containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,14 @@ async def test_included_container() -> None:
assert all(isinstance(x, datetime.datetime) for x in sequence)

await OuterContainer.tear_down()
assert InnerContainer.sync_resource._context.instance is None # noqa: SLF001
assert InnerContainer.async_resource._context.instance is None # noqa: SLF001
assert InnerContainer.sync_resource._context.instance is None
assert InnerContainer.async_resource._context.instance is None

await OuterContainer.init_resources()
sync_resource_context = InnerContainer.sync_resource._context # noqa: SLF001
sync_resource_context = InnerContainer.sync_resource._context
assert sync_resource_context
assert sync_resource_context.instance is not None
async_resource_context = InnerContainer.async_resource._context # noqa: SLF001
async_resource_context = InnerContainer.async_resource._context
assert async_resource_context
assert async_resource_context.instance is not None
await OuterContainer.tear_down()
2 changes: 2 additions & 0 deletions that_depends/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Dependency injection framework for Python."""

from that_depends import providers
from that_depends.container import BaseContainer
from that_depends.injection import Provide, inject
Expand Down
36 changes: 35 additions & 1 deletion that_depends/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import typing
from contextlib import AsyncExitStack, ExitStack, asynccontextmanager, contextmanager

from typing_extensions import override

from that_depends.meta import BaseContainerMeta
from that_depends.providers import AbstractProvider, Resource, Singleton
from that_depends.providers.context_resources import ContextResource, SupportsContext
Expand All @@ -16,19 +18,24 @@


class BaseContainer(SupportsContext[None], metaclass=BaseContainerMeta):
"""Base container class."""

providers: dict[str, AbstractProvider[typing.Any]]
containers: list[type["BaseContainer"]]

def __new__(cls, *_: typing.Any, **__: typing.Any) -> "typing_extensions.Self": # noqa: ANN401
@override
def __new__(cls, *_: typing.Any, **__: typing.Any) -> "typing_extensions.Self":
msg = f"{cls.__name__} should not be instantiated"
raise RuntimeError(msg)

@classmethod
@override
def supports_sync_context(cls) -> bool:
return True

@classmethod
@contextmanager
@override
def sync_context(cls) -> typing.Iterator[None]:
with ExitStack() as stack:
for container in cls.get_containers():
Expand All @@ -40,6 +47,7 @@ def sync_context(cls) -> typing.Iterator[None]:

@classmethod
@asynccontextmanager
@override
async def async_context(cls) -> typing.AsyncIterator[None]:
async with AsyncExitStack() as stack:
for container in cls.get_containers():
Expand All @@ -50,6 +58,7 @@ async def async_context(cls) -> typing.AsyncIterator[None]:
yield

@classmethod
@override
def context(cls, func: typing.Callable[P, T]) -> typing.Callable[P, T]:
if inspect.iscoroutinefunction(func):

Expand Down Expand Up @@ -79,19 +88,22 @@ def connect_containers(cls, *containers: type["BaseContainer"]) -> None:

@classmethod
def get_providers(cls) -> dict[str, AbstractProvider[typing.Any]]:
"""Get all connected providers."""
if not hasattr(cls, "providers"):
cls.providers = {k: v for k, v in cls.__dict__.items() if isinstance(v, AbstractProvider)}
return cls.providers

@classmethod
def get_containers(cls) -> list[type["BaseContainer"]]:
"""Get all connected containers."""
if not hasattr(cls, "containers"):
cls.containers = []

return cls.containers

@classmethod
async def init_resources(cls) -> None:
"""Initialize all resources."""
for provider in cls.get_providers().values():
if isinstance(provider, Resource):
await provider.async_resolve()
Expand All @@ -101,6 +113,7 @@ async def init_resources(cls) -> None:

@classmethod
async def tear_down(cls) -> None:
"""Tear down all resources."""
for provider in reversed(cls.get_providers().values()):
if isinstance(provider, Resource | Singleton):
await provider.tear_down()
Expand All @@ -110,18 +123,30 @@ async def tear_down(cls) -> None:

@classmethod
def reset_override(cls) -> None:
"""Reset all provider overrides."""
for v in cls.get_providers().values():
v.reset_override()

@classmethod
def resolver(cls, item: typing.Callable[P, T]) -> typing.Callable[[], typing.Awaitable[T]]:
"""Decorate a function to automatically resolve dependencies on call by name.
Args:
item: objects for which the dependencies should be resolved.
Returns:
Async wrapped callable with auto-injected dependencies.
"""

async def _inner() -> T:
return await cls.resolve(item)

return _inner

@classmethod
async def resolve(cls, object_to_resolve: typing.Callable[..., T]) -> T:
"""Inject dependencies into an object automatically by name."""
signature: typing.Final = inspect.signature(object_to_resolve)
kwargs = {}
providers: typing.Final = cls.get_providers()
Expand All @@ -140,6 +165,15 @@ async def resolve(cls, object_to_resolve: typing.Callable[..., T]) -> T:
@classmethod
@contextmanager
def override_providers(cls, providers_for_overriding: dict[str, typing.Any]) -> typing.Iterator[None]:
"""Override several providers with mocks simultaneously.
Args:
providers_for_overriding: {provider_name: mock} dictionary.
Returns:
None
"""
current_providers: typing.Final = cls.get_providers()
current_provider_names: typing.Final = set(current_providers.keys())
given_provider_names: typing.Final = set(providers_for_overriding.keys())
Expand Down
1 change: 1 addition & 0 deletions that_depends/entities/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Entities."""
12 changes: 10 additions & 2 deletions that_depends/entities/resource_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,18 @@


class ResourceContext(typing.Generic[T_co]):
"""Class to manage a resources' context."""

__slots__ = "asyncio_lock", "context_stack", "instance", "is_async", "threading_lock"

def __init__(self, is_async: bool) -> None:
"""Create a new ResourceContext instance.
:param is_async: Whether the ResourceContext was created in an async context.
Args:
is_async (bool): Whether the ResourceContext was created in
an async context.
For example within a ``async with container_context(): ...`` statement.
:type is_async: bool
"""
self.instance: T_co | None = None
self.asyncio_lock: typing.Final = asyncio.Lock()
Expand All @@ -27,15 +31,18 @@ def __init__(self, is_async: bool) -> None:
def is_context_stack_async(
context_stack: contextlib.AsyncExitStack | contextlib.ExitStack | None,
) -> typing.TypeGuard[contextlib.AsyncExitStack]:
"""Check if the context stack is an async context stack."""
return isinstance(context_stack, contextlib.AsyncExitStack)

@staticmethod
def is_context_stack_sync(
context_stack: contextlib.AsyncExitStack | contextlib.ExitStack,
) -> typing.TypeGuard[contextlib.ExitStack]:
"""Check if the context stack is a sync context stack."""
return isinstance(context_stack, contextlib.ExitStack)

async def tear_down(self) -> None:
"""Tear down the async context stack."""
if self.context_stack is None:
return

Expand All @@ -47,6 +54,7 @@ async def tear_down(self) -> None:
self.instance = None

def sync_tear_down(self) -> None:
"""Tear down the sync context stack."""
if self.context_stack is None:
return

Expand Down
Loading

0 comments on commit 9781e3d

Please sign in to comment.