-
Notifications
You must be signed in to change notification settings - Fork 522
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add Depends
like behaviour
#32
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,20 @@ | ||
from importlib.metadata import version | ||
|
||
from ._depends import Depends, DependsType, Provider, dependency_provider | ||
from .agent import Agent | ||
from .dependencies import CallContext | ||
from .exceptions import ModelRetry, UnexpectedModelBehaviour, UserError | ||
|
||
__all__ = 'Agent', 'CallContext', 'ModelRetry', 'UnexpectedModelBehaviour', 'UserError', '__version__' | ||
__all__ = ( | ||
'Agent', | ||
'CallContext', | ||
'ModelRetry', | ||
'UnexpectedModelBehaviour', | ||
'UserError', | ||
'__version__', | ||
'Depends', | ||
'DependsType', | ||
'Provider', | ||
'dependency_provider', | ||
) | ||
__version__ = version('pydantic_ai') |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
"""Draws heavily from the fast_depends library, but with some important differences in behavior. | ||
|
||
In particular: | ||
* No pydantic validation is performed on inputs to/outputs from function calls | ||
* No support for extra_dependencies | ||
* No support for custom field types | ||
* You can call injected functions and pass values for arguments that would have been injected. | ||
When this happens, the dependency function is not called and the passed value is used instead. | ||
In fast_depends, the dependency function is always called and provided arguments are ignored. | ||
""" | ||
|
||
from .depends import Depends, inject | ||
from .models import Depends as DependsType | ||
from .provider import Provider, dependency_provider | ||
|
||
__all__ = ('Depends', 'inject', 'DependsType', 'Provider', 'dependency_provider') |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
import inspect | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. one line docstring would be nice. |
||
from collections.abc import Awaitable | ||
from typing import ( | ||
Annotated, | ||
Any, | ||
Callable, | ||
TypeVar, | ||
Union, | ||
) | ||
|
||
from typing_extensions import ( | ||
ParamSpec, | ||
get_args, | ||
get_origin, | ||
) | ||
|
||
from .models import CallModel, Depends as DependsType | ||
from .utils import ( | ||
get_evaluated_signature, | ||
is_async_gen_callable, | ||
is_coroutine_callable, | ||
is_gen_callable, | ||
) | ||
|
||
P = ParamSpec('P') | ||
T = TypeVar('T') | ||
|
||
|
||
def build_call_model( # noqa C901 | ||
call: Union[ | ||
Callable[P, T], | ||
Callable[P, Awaitable[T]], | ||
], | ||
*, | ||
use_cache: bool = True, | ||
is_sync: bool | None = None, | ||
) -> CallModel[P, T]: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. quick docstring even if this is private. |
||
name = getattr(call, '__name__', type(call).__name__) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we assume elsewhere that |
||
|
||
is_call_async = is_coroutine_callable(call) or is_async_gen_callable(call) | ||
if is_sync is None: | ||
is_sync = not is_call_async | ||
else: | ||
assert not (is_sync and is_call_async), f'You cannot use async dependency `{name}` at sync main' | ||
is_call_generator = is_gen_callable(call) | ||
|
||
signature = get_evaluated_signature(call) | ||
|
||
class_fields: dict[str, tuple[Any, Any]] = {} | ||
dependencies: dict[str, CallModel[..., Any]] = {} | ||
positional_args: list[str] = [] | ||
keyword_args: list[str] = [] | ||
var_positional_arg: str | None = None | ||
var_keyword_arg: str | None = None | ||
|
||
for param_name, param in signature.parameters.items(): | ||
dep: DependsType | None = None | ||
|
||
if param.annotation is inspect.Parameter.empty: | ||
annotation = Any | ||
else: | ||
annotation = param.annotation | ||
|
||
if get_origin(param.annotation) is Annotated: | ||
annotated_args = get_args(param.annotation) | ||
for arg in annotated_args[1:]: | ||
if isinstance(arg, DependsType): | ||
if dep: | ||
raise ValueError(f'Cannot specify multiple `Depends` arguments for `{param_name}`!') | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should probably use our |
||
dep = arg | ||
|
||
default: Any | ||
if param.kind == inspect.Parameter.VAR_POSITIONAL: | ||
default = () | ||
var_positional_arg = param_name | ||
elif param.kind == inspect.Parameter.VAR_KEYWORD: | ||
default = {} | ||
var_keyword_arg = param_name | ||
elif param.default is inspect.Parameter.empty: | ||
default = inspect.Parameter.empty | ||
else: | ||
default = param.default | ||
|
||
if isinstance(default, DependsType): | ||
if dep: | ||
raise ValueError(f'Cannot use `Depends` with `Annotated` and a default value for `{param_name}`!') | ||
dep, default = default, inspect.Parameter.empty | ||
|
||
else: | ||
class_fields[param_name] = (annotation, default) | ||
|
||
if dep: | ||
dependencies[param_name] = build_call_model( | ||
dep.dependency, | ||
use_cache=dep.use_cache, | ||
is_sync=is_sync, | ||
) | ||
|
||
keyword_args.append(param_name) | ||
|
||
else: | ||
if param.kind is param.KEYWORD_ONLY: | ||
keyword_args.append(param_name) | ||
elif param.kind not in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD): | ||
positional_args.append(param_name) | ||
|
||
return CallModel( | ||
call=call, | ||
params=class_fields, | ||
use_cache=use_cache, | ||
is_async=is_call_async, | ||
is_generator=is_call_generator, | ||
dependencies=dependencies, | ||
positional_args=positional_args, | ||
keyword_args=keyword_args, | ||
var_positional_arg=var_positional_arg, | ||
var_keyword_arg=var_keyword_arg, | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
import sys | ||
from importlib.metadata import version as get_version | ||
|
||
__all__ = ('ExceptionGroup',) | ||
ANYIO_V3 = get_version('anyio').startswith('3.') | ||
|
||
if ANYIO_V3: | ||
from anyio import ExceptionGroup as ExceptionGroup # type: ignore | ||
else: | ||
if sys.version_info < (3, 11): | ||
from exceptiongroup import ExceptionGroup as ExceptionGroup | ||
else: | ||
ExceptionGroup = ExceptionGroup |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,207 @@ | ||
from collections.abc import AsyncIterator, Iterator | ||
from contextlib import AsyncExitStack, ExitStack | ||
from functools import partial, wraps | ||
from typing import ( | ||
Any, | ||
Callable, | ||
TypeVar, | ||
cast, | ||
overload, | ||
) | ||
|
||
from typing_extensions import ParamSpec | ||
|
||
from .build import build_call_model | ||
from .models import CallModel, Depends as DependsType | ||
from .provider import Provider, dependency_provider | ||
|
||
P = ParamSpec('P') | ||
T = TypeVar('T') | ||
|
||
|
||
def Depends( | ||
dependency: Callable[P, T], | ||
*, | ||
use_cache: bool = True, | ||
) -> T: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. docstring. This will be public so needs to be fairly useful description. |
||
result = DependsType(dependency=dependency, use_cache=use_cache) | ||
# We lie about the return type here to get better type-checking | ||
return result # type: ignore | ||
|
||
|
||
@overload | ||
def inject( | ||
*, | ||
dependency_overrides_provider: Provider | None = dependency_provider, | ||
) -> Callable[[Callable[P, T]], Callable[P, T]]: ... | ||
|
||
|
||
@overload | ||
def inject( | ||
func: Callable[P, T], | ||
) -> Callable[P, T]: ... | ||
|
||
|
||
def inject( | ||
func: Callable[P, T] | None = None, | ||
dependency_overrides_provider: Provider | None = dependency_provider, | ||
) -> Callable[[Callable[P, T]], Callable[P, T]] | Callable[P, T]: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same. |
||
if func is None: | ||
|
||
def decorator(func: Callable[P, T]) -> Callable[P, T]: | ||
return _inject_decorator(func, dependency_overrides_provider) | ||
|
||
return decorator | ||
|
||
return _inject_decorator(func, dependency_overrides_provider) | ||
|
||
|
||
def _inject_decorator( | ||
func: Callable[P, T], dependency_overrides_provider: Provider | None = dependency_provider | ||
) -> Callable[P, T]: | ||
overrides: dict[Callable[..., Any], Callable[..., Any]] | None = ( | ||
dependency_overrides_provider.dependency_overrides if dependency_overrides_provider else None | ||
) | ||
|
||
def func_wrapper(func: Callable[P, T]) -> Callable[P, T]: | ||
call_model = build_call_model(call=func) | ||
|
||
if call_model.is_async: | ||
if call_model.is_generator: | ||
return partial(solve_async_gen, call_model, overrides) # type: ignore[assignment] | ||
|
||
else: | ||
|
||
@wraps(func) | ||
async def async_injected_wrapper(*args: P.args, **kwargs: P.kwargs): | ||
async with AsyncExitStack() as stack: | ||
r = await call_model.asolve( | ||
args=args, | ||
kwargs=kwargs, | ||
stack=stack, | ||
dependency_overrides=overrides, | ||
cache_dependencies={}, | ||
nested=False, | ||
) | ||
return r | ||
raise AssertionError('unreachable') | ||
|
||
return async_injected_wrapper # type: ignore # | ||
|
||
else: | ||
if call_model.is_generator: | ||
return partial(solve_gen, call_model, overrides) # type: ignore[assignment] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think |
||
|
||
else: | ||
|
||
@wraps(func) | ||
def sync_injected_wrapper(*args: P.args, **kwargs: P.kwargs) -> T: | ||
with ExitStack() as stack: | ||
r = call_model.solve( | ||
args=args, | ||
kwargs=kwargs, | ||
stack=stack, | ||
dependency_overrides=overrides, | ||
cache_dependencies={}, | ||
nested=False, | ||
) | ||
return r | ||
raise AssertionError('unreachable') | ||
|
||
return sync_injected_wrapper | ||
|
||
return func_wrapper(func) | ||
|
||
|
||
class solve_async_gen: | ||
_iter: AsyncIterator[Any] | None = None | ||
|
||
def __init__( | ||
self, | ||
model: 'CallModel[..., Any]', | ||
overrides: dict[Any, Any] | None, | ||
*args: Any, | ||
**kwargs: Any, | ||
): | ||
self.call = model | ||
self.args = args | ||
self.kwargs = kwargs | ||
self.overrides = overrides | ||
|
||
def __aiter__(self) -> 'solve_async_gen': | ||
self._iter = None | ||
self.stack = AsyncExitStack() | ||
return self | ||
|
||
async def __anext__(self) -> Any: | ||
if self._iter is None: | ||
stack = self.stack = AsyncExitStack() | ||
await self.stack.__aenter__() | ||
self._iter = cast( | ||
AsyncIterator[Any], | ||
( | ||
await self.call.asolve( | ||
*self.args, | ||
stack=stack, | ||
dependency_overrides=self.overrides, | ||
cache_dependencies={}, | ||
nested=False, | ||
**self.kwargs, | ||
) | ||
).__aiter__(), | ||
) | ||
|
||
try: | ||
r = await self._iter.__anext__() | ||
except StopAsyncIteration as e: | ||
await self.stack.__aexit__(None, None, None) | ||
raise e | ||
else: | ||
return r | ||
|
||
|
||
class solve_gen: | ||
_iter: Iterator[Any] | None = None | ||
|
||
def __init__( | ||
self, | ||
model: 'CallModel[..., Any]', | ||
overrides: dict[Any, Any] | None, | ||
*args: Any, | ||
**kwargs: Any, | ||
): | ||
self.call = model | ||
self.args = args | ||
self.kwargs = kwargs | ||
self.overrides = overrides | ||
|
||
def __iter__(self) -> 'solve_gen': | ||
self._iter = None | ||
self.stack = ExitStack() | ||
return self | ||
|
||
def __next__(self) -> Any: | ||
if self._iter is None: | ||
stack = self.stack = ExitStack() | ||
self.stack.__enter__() | ||
self._iter = cast( | ||
Iterator[Any], | ||
iter( | ||
self.call.solve( | ||
args=self.args, | ||
kwargs=self.kwargs, | ||
stack=stack, | ||
dependency_overrides=self.overrides, | ||
cache_dependencies={}, | ||
nested=False, | ||
) | ||
), | ||
) | ||
|
||
try: | ||
r = next(self._iter) | ||
except StopIteration as e: | ||
self.stack.__exit__(None, None, None) | ||
raise e | ||
else: | ||
return r |
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is this intentional? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would suggest we move
/_depends
to/depends
to make it public, then only exportDepends
form the root__init__.py
Making this file public should also make the API documentation easier.