Skip to content

Commit

Permalink
reduce change
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelcolvin committed Oct 13, 2024
1 parent 61c4fbb commit c0d70ad
Showing 1 changed file with 11 additions and 10 deletions.
21 changes: 11 additions & 10 deletions pydantic_ai/_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,27 +10,28 @@
from pydantic_core import SchemaValidator
from typing_extensions import Concatenate, ParamSpec

from . import _pydantic, _utils, call, messages
from . import _pydantic, _utils, messages
from .call import AgentDeps, CallContext, Retry

# retrieval function parameters
P = ParamSpec('P')


# Usage `RetrieverContextFunc[AgentDependencies, P]`
RetrieverContextFunc = Callable[Concatenate[call.CallContext[call.AgentDeps], P], Union[str, Awaitable[str]]]
RetrieverContextFunc = Callable[Concatenate[CallContext[AgentDeps], P], Union[str, Awaitable[str]]]
# Usage `RetrieverPlainFunc[P]`
RetrieverPlainFunc = Callable[P, Union[str, Awaitable[str]]]
# Usage `RetrieverEitherFunc[AgentDependencies, P]`
RetrieverEitherFunc = _utils.Either[RetrieverContextFunc[call.AgentDeps, P], RetrieverPlainFunc[P]]
RetrieverEitherFunc = _utils.Either[RetrieverContextFunc[AgentDeps, P], RetrieverPlainFunc[P]]


@dataclass(init=False)
class Retriever(Generic[call.AgentDeps, P]):
class Retriever(Generic[AgentDeps, P]):
"""A retriever function for an agent."""

name: str
description: str
function: RetrieverEitherFunc[call.AgentDeps, P]
function: RetrieverEitherFunc[AgentDeps, P]
is_async: bool
takes_ctx: bool
single_arg_name: str | None
Expand All @@ -41,7 +42,7 @@ class Retriever(Generic[call.AgentDeps, P]):
max_retries: int
_current_retry: int = 0

def __init__(self, function: RetrieverEitherFunc[call.AgentDeps, P], retries: int):
def __init__(self, function: RetrieverEitherFunc[AgentDeps, P], retries: int):
"""Build a Retriever dataclass from a function."""
self.function = function
f = _pydantic.function_schema(function)
Expand All @@ -60,7 +61,7 @@ def reset(self) -> None:
"""Reset the current retry count."""
self._current_retry = 0

async def run(self, deps: call.AgentDeps, message: messages.ToolCall) -> messages.Message:
async def run(self, deps: AgentDeps, message: messages.ToolCall) -> messages.Message:
"""Run the retriever function asynchronously."""
try:
args_dict = self.validator.validate_json(message.arguments)
Expand All @@ -75,7 +76,7 @@ async def run(self, deps: call.AgentDeps, message: messages.ToolCall) -> message
else:
function = cast(Callable[[Any], str], self.function.whichever())
response_content = await _utils.run_in_executor(function, *args, **kwargs)
except call.Retry as e:
except Retry as e:
return self._on_error(e.message, message)

self._current_retry = 0
Expand All @@ -85,11 +86,11 @@ async def run(self, deps: call.AgentDeps, message: messages.ToolCall) -> message
tool_id=message.tool_id,
)

def _call_args(self, deps: call.AgentDeps, args_dict: dict[str, Any]) -> tuple[list[Any], dict[str, Any]]:
def _call_args(self, deps: AgentDeps, args_dict: dict[str, Any]) -> tuple[list[Any], dict[str, Any]]:
if self.single_arg_name:
args_dict = {self.single_arg_name: args_dict}

args = [call.CallContext(deps, self._current_retry)] if self.function.is_left() else []
args = [CallContext(deps, self._current_retry)] if self.function.is_left() else []
for positional_field in self.positional_fields:
args.append(args_dict.pop(positional_field))
if self.var_positional_field:
Expand Down

0 comments on commit c0d70ad

Please sign in to comment.