Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Depends like behaviour #32

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion pydantic_ai/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,20 @@
from importlib.metadata import version

from ._depends import Depends, DependsType, Provider, dependency_provider
from .agent import Agent
from .dependencies import CallContext
from .exceptions import ModelRetry, UnexpectedModelBehaviour, UserError

__all__ = 'Agent', 'CallContext', 'ModelRetry', 'UnexpectedModelBehaviour', 'UserError', '__version__'
__all__ = (
'Agent',
'CallContext',
'ModelRetry',
'UnexpectedModelBehaviour',
'UserError',
'__version__',
'Depends',
'DependsType',
'Provider',
'dependency_provider',
)
__version__ = version('pydantic_ai')
16 changes: 16 additions & 0 deletions pydantic_ai/_depends/__init__.py
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would suggest we move /_depends to /depends to make it public, then only export Depends form the root __init__.py

Making this file public should also make the API documentation easier.

Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
"""Draws heavily from the fast_depends library, but with some important differences in behavior.

In particular:
* No pydantic validation is performed on inputs to/outputs from function calls
* No support for extra_dependencies
* No support for custom field types
* You can call injected functions and pass values for arguments that would have been injected.
When this happens, the dependency function is not called and the passed value is used instead.
In fast_depends, the dependency function is always called and provided arguments are ignored.
"""

from .depends import Depends, inject
from .models import Depends as DependsType
from .provider import Provider, dependency_provider

__all__ = ('Depends', 'inject', 'DependsType', 'Provider', 'dependency_provider')
118 changes: 118 additions & 0 deletions pydantic_ai/_depends/build.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
import inspect
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

one line docstring would be nice.

from collections.abc import Awaitable
from typing import (
Annotated,
Any,
Callable,
TypeVar,
Union,
)

from typing_extensions import (
ParamSpec,
get_args,
get_origin,
)

from .models import CallModel, Depends as DependsType
from .utils import (
get_evaluated_signature,
is_async_gen_callable,
is_coroutine_callable,
is_gen_callable,
)

P = ParamSpec('P')
T = TypeVar('T')


def build_call_model( # noqa C901
call: Union[
Callable[P, T],
Callable[P, Awaitable[T]],
],
*,
use_cache: bool = True,
is_sync: bool | None = None,
) -> CallModel[P, T]:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

quick docstring even if this is private.

name = getattr(call, '__name__', type(call).__name__)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we assume elsewhere that call.__name__ is always available for callables.


is_call_async = is_coroutine_callable(call) or is_async_gen_callable(call)
if is_sync is None:
is_sync = not is_call_async
else:
assert not (is_sync and is_call_async), f'You cannot use async dependency `{name}` at sync main'
is_call_generator = is_gen_callable(call)

signature = get_evaluated_signature(call)

class_fields: dict[str, tuple[Any, Any]] = {}
dependencies: dict[str, CallModel[..., Any]] = {}
positional_args: list[str] = []
keyword_args: list[str] = []
var_positional_arg: str | None = None
var_keyword_arg: str | None = None

for param_name, param in signature.parameters.items():
dep: DependsType | None = None

if param.annotation is inspect.Parameter.empty:
annotation = Any
else:
annotation = param.annotation

if get_origin(param.annotation) is Annotated:
annotated_args = get_args(param.annotation)
for arg in annotated_args[1:]:
if isinstance(arg, DependsType):
if dep:
raise ValueError(f'Cannot specify multiple `Depends` arguments for `{param_name}`!')
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should probably use our UserError exception, or remove it and use plain ValueError everywhere.

dep = arg

default: Any
if param.kind == inspect.Parameter.VAR_POSITIONAL:
default = ()
var_positional_arg = param_name
elif param.kind == inspect.Parameter.VAR_KEYWORD:
default = {}
var_keyword_arg = param_name
elif param.default is inspect.Parameter.empty:
default = inspect.Parameter.empty
else:
default = param.default

if isinstance(default, DependsType):
if dep:
raise ValueError(f'Cannot use `Depends` with `Annotated` and a default value for `{param_name}`!')
dep, default = default, inspect.Parameter.empty

else:
class_fields[param_name] = (annotation, default)

if dep:
dependencies[param_name] = build_call_model(
dep.dependency,
use_cache=dep.use_cache,
is_sync=is_sync,
)

keyword_args.append(param_name)

else:
if param.kind is param.KEYWORD_ONLY:
keyword_args.append(param_name)
elif param.kind not in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):
positional_args.append(param_name)

return CallModel(
call=call,
params=class_fields,
use_cache=use_cache,
is_async=is_call_async,
is_generator=is_call_generator,
dependencies=dependencies,
positional_args=positional_args,
keyword_args=keyword_args,
var_positional_arg=var_positional_arg,
var_keyword_arg=var_keyword_arg,
)
13 changes: 13 additions & 0 deletions pydantic_ai/_depends/compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import sys
from importlib.metadata import version as get_version

__all__ = ('ExceptionGroup',)
ANYIO_V3 = get_version('anyio').startswith('3.')

if ANYIO_V3:
from anyio import ExceptionGroup as ExceptionGroup # type: ignore
else:
if sys.version_info < (3, 11):
from exceptiongroup import ExceptionGroup as ExceptionGroup
else:
ExceptionGroup = ExceptionGroup
207 changes: 207 additions & 0 deletions pydantic_ai/_depends/depends.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
from collections.abc import AsyncIterator, Iterator
from contextlib import AsyncExitStack, ExitStack
from functools import partial, wraps
from typing import (
Any,
Callable,
TypeVar,
cast,
overload,
)

from typing_extensions import ParamSpec

from .build import build_call_model
from .models import CallModel, Depends as DependsType
from .provider import Provider, dependency_provider

P = ParamSpec('P')
T = TypeVar('T')


def Depends(
dependency: Callable[P, T],
*,
use_cache: bool = True,
) -> T:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

docstring. This will be public so needs to be fairly useful description.

result = DependsType(dependency=dependency, use_cache=use_cache)
# We lie about the return type here to get better type-checking
return result # type: ignore


@overload
def inject(
*,
dependency_overrides_provider: Provider | None = dependency_provider,
) -> Callable[[Callable[P, T]], Callable[P, T]]: ...


@overload
def inject(
func: Callable[P, T],
) -> Callable[P, T]: ...


def inject(
func: Callable[P, T] | None = None,
dependency_overrides_provider: Provider | None = dependency_provider,
) -> Callable[[Callable[P, T]], Callable[P, T]] | Callable[P, T]:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same.

if func is None:

def decorator(func: Callable[P, T]) -> Callable[P, T]:
return _inject_decorator(func, dependency_overrides_provider)

return decorator

return _inject_decorator(func, dependency_overrides_provider)


def _inject_decorator(
func: Callable[P, T], dependency_overrides_provider: Provider | None = dependency_provider
) -> Callable[P, T]:
overrides: dict[Callable[..., Any], Callable[..., Any]] | None = (
dependency_overrides_provider.dependency_overrides if dependency_overrides_provider else None
)

def func_wrapper(func: Callable[P, T]) -> Callable[P, T]:
call_model = build_call_model(call=func)

if call_model.is_async:
if call_model.is_generator:
return partial(solve_async_gen, call_model, overrides) # type: ignore[assignment]

else:

@wraps(func)
async def async_injected_wrapper(*args: P.args, **kwargs: P.kwargs):
async with AsyncExitStack() as stack:
r = await call_model.asolve(
args=args,
kwargs=kwargs,
stack=stack,
dependency_overrides=overrides,
cache_dependencies={},
nested=False,
)
return r
raise AssertionError('unreachable')

return async_injected_wrapper # type: ignore #

else:
if call_model.is_generator:
return partial(solve_gen, call_model, overrides) # type: ignore[assignment]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think [assignment] isn't enforced by pyright, so instead I use pyright: ignore[<why>]


else:

@wraps(func)
def sync_injected_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
with ExitStack() as stack:
r = call_model.solve(
args=args,
kwargs=kwargs,
stack=stack,
dependency_overrides=overrides,
cache_dependencies={},
nested=False,
)
return r
raise AssertionError('unreachable')

return sync_injected_wrapper

return func_wrapper(func)


class solve_async_gen:
_iter: AsyncIterator[Any] | None = None

def __init__(
self,
model: 'CallModel[..., Any]',
overrides: dict[Any, Any] | None,
*args: Any,
**kwargs: Any,
):
self.call = model
self.args = args
self.kwargs = kwargs
self.overrides = overrides

def __aiter__(self) -> 'solve_async_gen':
self._iter = None
self.stack = AsyncExitStack()
return self

async def __anext__(self) -> Any:
if self._iter is None:
stack = self.stack = AsyncExitStack()
await self.stack.__aenter__()
self._iter = cast(
AsyncIterator[Any],
(
await self.call.asolve(
*self.args,
stack=stack,
dependency_overrides=self.overrides,
cache_dependencies={},
nested=False,
**self.kwargs,
)
).__aiter__(),
)

try:
r = await self._iter.__anext__()
except StopAsyncIteration as e:
await self.stack.__aexit__(None, None, None)
raise e
else:
return r


class solve_gen:
_iter: Iterator[Any] | None = None

def __init__(
self,
model: 'CallModel[..., Any]',
overrides: dict[Any, Any] | None,
*args: Any,
**kwargs: Any,
):
self.call = model
self.args = args
self.kwargs = kwargs
self.overrides = overrides

def __iter__(self) -> 'solve_gen':
self._iter = None
self.stack = ExitStack()
return self

def __next__(self) -> Any:
if self._iter is None:
stack = self.stack = ExitStack()
self.stack.__enter__()
self._iter = cast(
Iterator[Any],
iter(
self.call.solve(
args=self.args,
kwargs=self.kwargs,
stack=stack,
dependency_overrides=self.overrides,
cache_dependencies={},
nested=False,
)
),
)

try:
r = next(self._iter)
except StopIteration as e:
self.stack.__exit__(None, None, None)
raise e
else:
return r
Empty file added pydantic_ai/_depends/model.py
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this intentional?

Empty file.
Loading
Loading