Skip to content

Commit

Permalink
better support for allow_text_result (#21)
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelcolvin authored Oct 25, 2024
1 parent 73a1053 commit 7febaae
Show file tree
Hide file tree
Showing 5 changed files with 210 additions and 11 deletions.
14 changes: 11 additions & 3 deletions pydantic_ai/_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,21 +96,29 @@ def build(cls, response_type: type[ResultData], name: str, description: str) ->
if response_type is str:
return None

allow_text_result = False
if _utils.is_model_like(response_type):
type_adapter = TypeAdapter(response_type)
outer_typed_dict_key: str | None = None
json_schema = _utils.check_object_json_schema(type_adapter.json_schema())
else:
# noinspection PyTypedDict
if response_type_option := _utils.extract_str_from_union(response_type):
response_type = response_type_option.value
allow_text_result = True

response_data_typed_dict = TypedDict('response_data_typed_dict', {'response': response_type}) # noqa
type_adapter = TypeAdapter(response_data_typed_dict)
outer_typed_dict_key = 'response'
json_schema = _utils.check_object_json_schema(type_adapter.json_schema())
# including `response_data_typed_dict` as a title here doesn't add anything and could confuse the LLM
json_schema.pop('title') # pyright: ignore[reportCallIssue,reportArgumentType]

return cls(
name=name,
description=description,
type_adapter=type_adapter,
json_schema=_utils.check_object_json_schema(type_adapter.json_schema()),
allow_text_result=_utils.allow_plain_str(response_type),
json_schema=json_schema,
allow_text_result=allow_text_result,
outer_typed_dict_key=outer_typed_dict_key,
)

Expand Down
28 changes: 20 additions & 8 deletions pydantic_ai/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,21 @@ async def run_in_executor(func: Callable[_P, _R], *args: _P.args, **kwargs: _P.k
_UnionType = type(Union[int, str])


def allow_plain_str(response_type: Any) -> bool:
"""Check if the response type allows plain strings."""
return isinstance(response_type, _UnionType) and any(t is str for t in get_args(response_type))
def extract_str_from_union(response_type: Any) -> Option[Any]:
"""Extract the string type from a Union, return the remaining union or remaining type."""
if isinstance(response_type, _UnionType) and any(t is str for t in get_args(response_type)):
remain_args: list[Any] = []
includes_str = False
for arg in get_args(response_type):
if arg is str:
includes_str = True
else:
remain_args.append(arg)
if includes_str:
if len(remain_args) == 1:
return Some(remain_args[0])
else:
return Some(Union[tuple(remain_args)])


def is_model_like(type_: Any) -> bool:
Expand Down Expand Up @@ -61,18 +73,18 @@ def check_object_json_schema(schema: JsonSchemaValue) -> ObjectJsonSchema:
raise ValueError('Schema must be an object')


_T = TypeVar('_T')
T = TypeVar('T')


@dataclass
class Some(Generic[_T]):
class Some(Generic[T]):
"""Analogous to Rust's `Option::Some` type."""

value: _T
value: T


# Analogous to Rust's `Option` type, usage: `Option[Thing]` is equivalent to `Some[Thing] | None`
Option: TypeAlias = Union[Some[_T], None]
Option: TypeAlias = Union[Some[T], None]
"""Analogous to Rust's `Option` type, usage: `Option[Thing]` is equivalent to `Some[Thing] | None`."""


Left = TypeVar('Left')
Expand Down
10 changes: 10 additions & 0 deletions pydantic_ai/models/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,21 @@ class TestModel(Model):
call_retrievers: list[str] | Literal['all'] = 'all'
custom_result_text: str | None = None
custom_result_args: Any | None = None
# these three fields are all set by calling `agent_model`
agent_model_retrievers: Mapping[str, AbstractToolDefinition] | None = None
agent_model_allow_text_result: bool | None = None
agent_model_result_tool: AbstractToolDefinition | None = None

def agent_model(
self,
retrievers: Mapping[str, AbstractToolDefinition],
allow_text_result: bool,
result_tool: AbstractToolDefinition | None,
) -> AgentModel:
self.agent_model_retrievers = retrievers
self.agent_model_allow_text_result = allow_text_result
self.agent_model_result_tool = result_tool

if self.call_retrievers == 'all':
retriever_calls = [(r.name, r) for r in retrievers.values()]
else:
Expand All @@ -59,6 +67,8 @@ def agent_model(
result = _utils.Either(right={k: self.custom_result_args})
else:
result = _utils.Either(right=self.custom_result_args)
elif allow_text_result:
result = _utils.Either(left=None)
elif result_tool is not None:
result = _utils.Either(right=None)
else:
Expand Down
144 changes: 144 additions & 0 deletions tests/test_agent.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Union

from inline_snapshot import snapshot
from pydantic import BaseModel

Expand All @@ -13,6 +15,7 @@
UserPrompt,
)
from pydantic_ai.models.function import AgentInfo, FunctionModel
from pydantic_ai.models.test import TestModel
from tests.conftest import IsNow


Expand Down Expand Up @@ -151,3 +154,144 @@ def return_tuple(_: list[Message], info: AgentInfo) -> LLMMessage:
),
]
)


def test_response_tuple():
m = TestModel()

agent = Agent(m, deps=None, result_type=tuple[str, str])
assert agent._result_schema.allow_text_result is False # pyright: ignore[reportPrivateUsage,reportOptionalMemberAccess]

result = agent.run_sync('Hello')
assert result.response == snapshot(('b', 'b'))

assert m.agent_model_retrievers == snapshot({})
assert m.agent_model_allow_text_result is False

assert m.agent_model_result_tool is not None

# to match the protocol, we just extract the attributes we care about
fields = 'name', 'description', 'json_schema', 'outer_typed_dict_key'
agent_model_result_tool = {f: getattr(m.agent_model_result_tool, f) for f in fields}
assert agent_model_result_tool == snapshot(
{
'name': 'final_result',
'description': 'The final response which ends this conversation',
'json_schema': {
'properties': {
'response': {
'maxItems': 2,
'minItems': 2,
'prefixItems': [{'type': 'string'}, {'type': 'string'}],
'title': 'Response',
'type': 'array',
}
},
'required': ['response'],
'type': 'object',
},
'outer_typed_dict_key': 'response',
}
)


def test_response_union_allow_str():
m = TestModel()
agent: Agent[None, Union[str, Foo]] = Agent(
m,
result_type=Union[str, Foo], # pyright: ignore[reportArgumentType]
)
assert agent._result_schema.allow_text_result is True # pyright: ignore[reportPrivateUsage,reportOptionalMemberAccess]

result = agent.run_sync('Hello')
assert result.response == snapshot('{}')

assert m.agent_model_retrievers == snapshot({})
assert m.agent_model_allow_text_result is True

assert m.agent_model_result_tool is not None

# to match the protocol, we just extract the attributes we care about
fields = 'name', 'description', 'json_schema', 'outer_typed_dict_key'
agent_model_result_tool = {f: getattr(m.agent_model_result_tool, f) for f in fields}
assert agent_model_result_tool == snapshot(
{
'name': 'final_result',
'description': 'The final response which ends this conversation',
'json_schema': {
'$defs': {
'Foo': {
'properties': {
'a': {'title': 'A', 'type': 'integer'},
'b': {'title': 'B', 'type': 'string'},
},
'required': ['a', 'b'],
'title': 'Foo',
'type': 'object',
}
},
'properties': {'response': {'$ref': '#/$defs/Foo'}},
'required': ['response'],
'type': 'object',
},
'outer_typed_dict_key': 'response',
}
)


class Bar(BaseModel):
b: str


def test_response_multiple_return_tools():
m = TestModel()
agent: Agent[None, Union[Foo, Bar]] = Agent(
m,
result_type=Union[Foo, Bar], # pyright: ignore[reportArgumentType]
)

result = agent.run_sync('Hello')
assert result.response == Foo(a=1, b='b')

assert m.agent_model_retrievers == snapshot({})
assert m.agent_model_allow_text_result is False

assert m.agent_model_result_tool is not None

# to match the protocol, we just extract the attributes we care about
fields = 'name', 'description', 'json_schema', 'outer_typed_dict_key'
agent_model_result_tool = {f: getattr(m.agent_model_result_tool, f) for f in fields}
assert agent_model_result_tool == snapshot(
{
'name': 'final_result',
'description': 'The final response which ends this conversation',
'json_schema': {
'$defs': {
'Bar': {
'properties': {'b': {'title': 'B', 'type': 'string'}},
'required': ['b'],
'title': 'Bar',
'type': 'object',
},
'Foo': {
'properties': {
'a': {'title': 'A', 'type': 'integer'},
'b': {'title': 'B', 'type': 'string'},
},
'required': ['a', 'b'],
'title': 'Foo',
'type': 'object',
},
},
'properties': {
'response': {
'anyOf': [{'$ref': '#/$defs/Foo'}, {'$ref': '#/$defs/Bar'}],
'title': 'Response',
}
},
'required': ['response'],
'type': 'object',
},
'outer_typed_dict_key': 'response',
}
)
25 changes: 25 additions & 0 deletions tests/typed_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from collections.abc import Iterator
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Union

from pydantic_ai import Agent, CallContext

Expand Down Expand Up @@ -80,3 +81,27 @@ async def ok_retriever2(ctx: CallContext[MyDeps], x: str) -> str:
if never():
typed_agent2.run_sync('testing', model='openai:gpt-4o', deps=MyDeps(foo=1, bar=2))
typed_agent2.run_sync('testing', deps=123) # type: ignore[arg-type]


@dataclass
class Foo:
a: int


@dataclass
class Bar:
b: str


union_agent: Agent[None, Union[Foo, Bar]] = Agent(
result_type=Union[Foo, Bar], # type: ignore[arg-type]
)


def foo_result(response: Union[Foo, Bar]) -> str:
return str(response)


if never():
result = union_agent.run_sync('testing')
foo_result(result.response)

0 comments on commit 7febaae

Please sign in to comment.