remove "result" module
samuelcolvin committed Oct 20, 2024
1 parent 0975afb commit bd5ade0
Showing 4 changed files with 121 additions and 121 deletions.
68 changes: 66 additions & 2 deletions pydantic_ai/_result.py
@@ -5,9 +5,11 @@
 from dataclasses import dataclass
 from typing import Any, Callable, Generic, Union, cast

+from pydantic import TypeAdapter, ValidationError
+from typing_extensions import Self, TypedDict
+
 from . import _utils, messages
-from .result import ResultData
-from .shared import AgentDeps, CallContext, ModelRetry
+from .shared import AgentDeps, CallContext, ModelRetry, ResultData

 # A function that always takes `ResultData` and returns `ResultData`,
 # but may or may not take `CallInfo` as a first argument, and may or may not be async.
@@ -73,3 +75,65 @@ class ToolRetryError(Exception):
     def __init__(self, tool_retry: messages.ToolRetry):
         self.tool_retry = tool_retry
         super().__init__()
+
+
+@dataclass
+class ResultSchema(Generic[ResultData]):
+    """Model the final response from an agent run.
+
+    Similar to `Retriever` but for the final result of running an agent.
+    """
+
+    name: str
+    description: str
+    type_adapter: TypeAdapter[Any]
+    json_schema: _utils.ObjectJsonSchema
+    allow_text_result: bool
+    outer_typed_dict_key: str | None
+
+    @classmethod
+    def build(cls, response_type: type[ResultData], name: str, description: str) -> Self | None:
+        """Build a ResultSchema dataclass from a response type."""
+        if response_type is str:
+            return None
+
+        if _utils.is_model_like(response_type):
+            type_adapter = TypeAdapter(response_type)
+            outer_typed_dict_key: str | None = None
+        else:
+            # noinspection PyTypedDict
+            response_data_typed_dict = TypedDict('response_data_typed_dict', {'response': response_type})  # noqa
+            type_adapter = TypeAdapter(response_data_typed_dict)
+            outer_typed_dict_key = 'response'
+
+        return cls(
+            name=name,
+            description=description,
+            type_adapter=type_adapter,
+            json_schema=_utils.check_object_json_schema(type_adapter.json_schema()),
+            allow_text_result=_utils.allow_plain_str(response_type),
+            outer_typed_dict_key=outer_typed_dict_key,
+        )
+
+    def validate(self, tool_call: messages.ToolCall) -> ResultData:
+        """Validate a result message.
+
+        Returns:
+            The validated result data; raises `ToolRetryError` with a retry message if validation fails.
+        """
+        try:
+            if isinstance(tool_call.args, messages.ArgsJson):
+                result = self.type_adapter.validate_json(tool_call.args.args_json)
+            else:
+                result = self.type_adapter.validate_python(tool_call.args.args_object)
+        except ValidationError as e:
+            m = messages.ToolRetry(
+                tool_name=tool_call.tool_name,
+                content=e.errors(include_url=False),
+                tool_id=tool_call.tool_id,
+            )
+            raise ToolRetryError(m) from e
+        else:
+            if k := self.outer_typed_dict_key:
+                result = result[k]
+            return result
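
An aside on the wrapping trick `build` relies on: result types that are not already model-like get boxed in a single-key `TypedDict` so the result tool still advertises an object JSON schema, and `validate` later unwraps that key. A minimal standalone sketch using only pydantic; the `ResponseWrapper` name and the `list[int]` result type are illustrative, not part of this commit:

from pydantic import TypeAdapter
from typing_extensions import TypedDict

# "response" mirrors the outer key build() uses; list[int] stands in for any non-model result type.
ResponseWrapper = TypedDict('ResponseWrapper', {'response': list[int]})

adapter = TypeAdapter(ResponseWrapper)
print(adapter.json_schema())  # an object schema with a single required "response" property

data = adapter.validate_json('{"response": [1, 2, 3]}')
assert data['response'] == [1, 2, 3]  # validate() would unwrap this key via outer_typed_dict_key
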
46 changes: 23 additions & 23 deletions pydantic_ai/agent.py
@@ -9,8 +9,8 @@
 from pydantic import ValidationError
 from typing_extensions import assert_never

-from . import _result, _retriever as _r, _system_prompt, _utils, messages as _messages, models, result, shared
-from .shared import AgentDeps
+from . import _result, _retriever as _r, _system_prompt, _utils, messages as _messages, models, shared
+from .shared import AgentDeps, ResultData

 __all__ = ('Agent',)
 KnownModelName = Literal[
@@ -20,13 +20,13 @@


 @dataclass(init=False)
-class Agent(Generic[AgentDeps, result.ResultData]):
+class Agent(Generic[AgentDeps, ResultData]):
     """Main class for creating "agents" - a way to have a specific type of "conversation" with an LLM."""

     # slots mostly for my sanity — knowing what attributes are available
     model: models.Model | None
-    _result_tool: result.ResultSchema[result.ResultData] | None
-    _result_validators: list[_result.ResultValidator[AgentDeps, result.ResultData]]
+    _result_schema: _result.ResultSchema[ResultData] | None
+    _result_validators: list[_result.ResultValidator[AgentDeps, ResultData]]
     _allow_text_result: bool
     _system_prompts: tuple[str, ...]
     _retrievers: dict[str, _r.Retriever[AgentDeps, Any]]
@@ -39,7 +39,7 @@ class Agent(Generic[AgentDeps, result.ResultData]):
     def __init__(
         self,
         model: models.Model | KnownModelName | None = None,
-        result_type: type[result.ResultData] = str,
+        result_type: type[ResultData] = str,
         *,
         system_prompt: str | Sequence[str] = (),
         # type here looks odd, but it's required so you can avoid "partially unknown" type errors with `deps=None`
@@ -51,11 +51,11 @@ def __init__(
     ):
         self.model = models.infer_model(model) if model is not None else None

-        self._result_tool = result.ResultSchema[result_type].build(
+        self._result_schema = _result.ResultSchema[result_type].build(
             result_type, result_tool_name, result_tool_description
         )
         # if the result tool is None, or its schema allows `str`, we allow plain text results
-        self._allow_text_result = self._result_tool is None or self._result_tool.allow_text_result
+        self._allow_text_result = self._result_schema is None or self._result_schema.allow_text_result

         self._system_prompts = (system_prompt,) if isinstance(system_prompt, str) else tuple(system_prompt)
         self._retrievers: dict[str, _r.Retriever[AgentDeps, Any]] = {}
@@ -73,7 +73,7 @@ async def run(
         message_history: list[_messages.Message] | None = None,
         model: models.Model | KnownModelName | None = None,
         deps: AgentDeps | None = None,
-    ) -> result.RunResult[result.ResultData]:
+    ) -> shared.RunResult[ResultData]:
         """Run the agent with a user prompt in async mode.

         Args:
@@ -103,7 +103,7 @@ async def run(

         messages.append(_messages.UserPrompt(user_prompt))

-        agent_model = model_.agent_model(self._retrievers, self._allow_text_result, self._result_tool)
+        agent_model = model_.agent_model(self._retrievers, self._allow_text_result, self._result_schema)

         for retriever in self._retrievers.values():
             retriever.reset()
@@ -127,7 +127,7 @@ async def run(
                         run_span.set_attribute('full_messages', messages)
                         handle_span.set_attribute('result', left.value)
                         handle_span.message = 'handle model response -> final result'
-                        return result.RunResult(left.value, messages, cost=result.Cost(0))
+                        return shared.RunResult(left.value, messages, cost=shared.Cost(0))
                     else:
                         tool_responses = either.right
                         handle_span.set_attribute('tool_responses', tool_responses)
@@ -145,7 +145,7 @@ def run_sync(
         message_history: list[_messages.Message] | None = None,
         model: models.Model | KnownModelName | None = None,
         deps: AgentDeps | None = None,
-    ) -> result.RunResult[result.ResultData]:
+    ) -> shared.RunResult[ResultData]:
         """Run the agent with a user prompt synchronously.

         This is a convenience method that wraps `self.run` with `asyncio.run()`.
@@ -169,8 +169,8 @@ def system_prompt(
         return func

     def result_validator(
-        self, func: _result.ResultValidatorFunc[AgentDeps, result.ResultData]
-    ) -> _result.ResultValidatorFunc[AgentDeps, result.ResultData]:
+        self, func: _result.ResultValidatorFunc[AgentDeps, ResultData]
+    ) -> _result.ResultValidatorFunc[AgentDeps, ResultData]:
         """Decorator to register a result validator function."""
         self._result_validators.append(_result.ResultValidator(func))
         return func
@@ -226,7 +226,7 @@ def _register_retriever(
         retries_ = retries if retries is not None else self._default_retries
         retriever = _r.Retriever[AgentDeps, _r.P](func, retries_)

-        if self._result_tool and self._result_tool.name == retriever.name:
+        if self._result_schema and self._result_schema.name == retriever.name:
             raise ValueError(f'Retriever name conflicts with result schema name: {retriever.name!r}')

         if retriever.name in self._retrievers:
@@ -237,7 +237,7 @@

     async def _handle_model_response(
         self, model_response: _messages.LLMMessage, deps: AgentDeps
-    ) -> _utils.Either[result.ResultData, list[_messages.Message]]:
+    ) -> _utils.Either[ResultData, list[_messages.Message]]:
         """Process a single response from the model.

         Returns:
@@ -246,22 +246,22 @@ async def _handle_model_response(
         if model_response.role == 'llm-response':
             # plain string response
             if self._allow_text_result:
-                return _utils.Either(left=cast(result.ResultData, model_response.content))
+                return _utils.Either(left=cast(ResultData, model_response.content))
             else:
                 self._incr_result_retry()
-                assert self._result_tool is not None
+                assert self._result_schema is not None
                 response = _messages.UserPrompt(
                     content='Plain text responses are not permitted, please call one of the functions instead.',
                 )
                 return _utils.Either(right=[response])
         elif model_response.role == 'llm-tool-calls':
-            if self._result_tool is not None:
+            if self._result_schema is not None:
                 # if there's a result schema, and any of the calls match that name, return the result
                 # NOTE: this means we ignore any other tools called here
-                call = next((c for c in model_response.calls if c.tool_name == self._result_tool.name), None)
+                call = next((c for c in model_response.calls if c.tool_name == self._result_schema.name), None)
                 if call is not None:
                     try:
-                        result_data = self._result_tool.validate(call)
+                        result_data = self._result_schema.validate(call)
                         result_data = await self._validate_result(result_data, deps, call)
                     except _result.ToolRetryError as e:
                         self._incr_result_retry()
@@ -282,8 +282,8 @@ async def _handle_model_response(
             assert_never(model_response)

     async def _validate_result(
-        self, result_data: result.ResultData, deps: AgentDeps, tool_call: _messages.ToolCall
-    ) -> result.ResultData:
+        self, result_data: ResultData, deps: AgentDeps, tool_call: _messages.ToolCall
+    ) -> ResultData:
         for validator in self._result_validators:
             result_data = await validator.validate(result_data, deps, self._current_result_retry, tool_call)
         return result_data
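
Taken together, the agent.py changes mean callers parameterise `Agent` over `ResultData` and get a `shared.RunResult` back. A hedged usage sketch under assumptions not fixed by this commit: that `'openai:gpt-4o'` is a valid `KnownModelName`, that an API key is configured, and that `run_sync` takes the prompt as its first argument:

from pydantic import BaseModel

from pydantic_ai.agent import Agent
from pydantic_ai.shared import RunResult


class Answer(BaseModel):
    city: str
    country: str


# Model name and prompt are illustrative assumptions, not part of this commit.
agent = Agent('openai:gpt-4o', result_type=Answer)

run: RunResult[Answer] = agent.run_sync('What is the capital of France?')
print(run.response)                        # e.g. Answer(city='Paris', country='France')
print(run.cost, len(run.message_history))
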
95 changes: 0 additions & 95 deletions pydantic_ai/result.py

This file was deleted.

33 changes: 32 additions & 1 deletion pydantic_ai/shared.py
@@ -10,9 +10,40 @@
 from . import messages
 from .models import Model

-__all__ = 'AgentDeps', 'ModelRetry', 'CallContext', 'AgentError', 'UserError', 'UnexpectedModelBehaviour'
+__all__ = (
+    'AgentDeps',
+    'ResultData',
+    'Cost',
+    'RunResult',
+    'ModelRetry',
+    'CallContext',
+    'AgentError',
+    'UserError',
+    'UnexpectedModelBehaviour',
+)

 AgentDeps = TypeVar('AgentDeps')
+ResultData = TypeVar('ResultData')
+
+
+@dataclass
+class Cost:
+    """Cost of a run."""
+
+    total_cost: int
+
+
+@dataclass
+class RunResult(Generic[ResultData]):
+    """Result of a run."""
+
+    response: ResultData
+    message_history: list[messages.Message]
+    cost: Cost
+
+    def message_history_json(self) -> str:
+        """Return the history of messages as a JSON string."""
+        return messages.MessagesTypeAdapter.dump_json(self.message_history).decode()


 @dataclass

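The new `Cost` and `RunResult` dataclasses are plain containers, so they can be constructed directly; a small sketch, assuming `MessagesTypeAdapter` serialises a plain list of messages:

from pydantic_ai import shared

# Construct the new dataclasses by hand just to show their shape.
run_result = shared.RunResult(response='Paris', message_history=[], cost=shared.Cost(total_cost=0))

assert run_result.response == 'Paris'
print(run_result.cost)                    # Cost(total_cost=0)
print(run_result.message_history_json())  # '[]' for an empty history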