From f96c737b8c050aa7cb064d5f441630c913069a1f Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Fri, 25 Oct 2024 12:01:17 +0100 Subject: [PATCH] better support for "allow_text_result" --- pydantic_ai/_result.py | 14 +++- pydantic_ai/_utils.py | 28 +++++--- pydantic_ai/models/test.py | 10 +++ tests/test_agent.py | 144 +++++++++++++++++++++++++++++++++++++ tests/typed_agent.py | 25 +++++++ 5 files changed, 210 insertions(+), 11 deletions(-) diff --git a/pydantic_ai/_result.py b/pydantic_ai/_result.py index c586e005b..ee45a93b2 100644 --- a/pydantic_ai/_result.py +++ b/pydantic_ai/_result.py @@ -96,21 +96,29 @@ def build(cls, response_type: type[ResultData], name: str, description: str) -> if response_type is str: return None + allow_text_result = False if _utils.is_model_like(response_type): type_adapter = TypeAdapter(response_type) outer_typed_dict_key: str | None = None + json_schema = _utils.check_object_json_schema(type_adapter.json_schema()) else: - # noinspection PyTypedDict + if response_type_option := _utils.extract_str_from_union(response_type): + response_type = response_type_option.value + allow_text_result = True + response_data_typed_dict = TypedDict('response_data_typed_dict', {'response': response_type}) # noqa type_adapter = TypeAdapter(response_data_typed_dict) outer_typed_dict_key = 'response' + json_schema = _utils.check_object_json_schema(type_adapter.json_schema()) + # including `response_data_typed_dict` as a title here doesn't add anything and could confuse the LLM + json_schema.pop('title') # pyright: ignore[reportCallIssue,reportArgumentType] return cls( name=name, description=description, type_adapter=type_adapter, - json_schema=_utils.check_object_json_schema(type_adapter.json_schema()), - allow_text_result=_utils.allow_plain_str(response_type), + json_schema=json_schema, + allow_text_result=allow_text_result, outer_typed_dict_key=outer_typed_dict_key, ) diff --git a/pydantic_ai/_utils.py b/pydantic_ai/_utils.py index a42897c77..18c8905e9 100644 --- a/pydantic_ai/_utils.py +++ b/pydantic_ai/_utils.py @@ -24,9 +24,21 @@ async def run_in_executor(func: Callable[_P, _R], *args: _P.args, **kwargs: _P.k _UnionType = type(Union[int, str]) -def allow_plain_str(response_type: Any) -> bool: - """Check if the response type allows plain strings.""" - return isinstance(response_type, _UnionType) and any(t is str for t in get_args(response_type)) +def extract_str_from_union(response_type: Any) -> Option[Any]: + """Extract the string type from a Union, return the remaining union or remaining type.""" + if isinstance(response_type, _UnionType) and any(t is str for t in get_args(response_type)): + remain_args: list[Any] = [] + includes_str = False + for arg in get_args(response_type): + if arg is str: + includes_str = True + else: + remain_args.append(arg) + if includes_str: + if len(remain_args) == 1: + return Some(remain_args[0]) + else: + return Some(Union[tuple(remain_args)]) def is_model_like(type_: Any) -> bool: @@ -61,18 +73,18 @@ def check_object_json_schema(schema: JsonSchemaValue) -> ObjectJsonSchema: raise ValueError('Schema must be an object') -_T = TypeVar('_T') +T = TypeVar('T') @dataclass -class Some(Generic[_T]): +class Some(Generic[T]): """Analogous to Rust's `Option::Some` type.""" - value: _T + value: T -# Analogous to Rust's `Option` type, usage: `Option[Thing]` is equivalent to `Some[Thing] | None` -Option: TypeAlias = Union[Some[_T], None] +Option: TypeAlias = Union[Some[T], None] +"""Analogous to Rust's `Option` type, usage: `Option[Thing]` is equivalent to `Some[Thing] | None`.""" Left = TypeVar('Left') diff --git a/pydantic_ai/models/test.py b/pydantic_ai/models/test.py index 80444e44b..149a8b603 100644 --- a/pydantic_ai/models/test.py +++ b/pydantic_ai/models/test.py @@ -35,6 +35,10 @@ class TestModel(Model): call_retrievers: list[str] | Literal['all'] = 'all' custom_result_text: str | None = None custom_result_args: Any | None = None + # these three fields are all set by calling `agent_model` + agent_model_retrievers: Mapping[str, AbstractToolDefinition] | None = None + agent_model_allow_text_result: bool | None = None + agent_model_result_tool: AbstractToolDefinition | None = None def agent_model( self, @@ -42,6 +46,10 @@ def agent_model( allow_text_result: bool, result_tool: AbstractToolDefinition | None, ) -> AgentModel: + self.agent_model_retrievers = retrievers + self.agent_model_allow_text_result = allow_text_result + self.agent_model_result_tool = result_tool + if self.call_retrievers == 'all': retriever_calls = [(r.name, r) for r in retrievers.values()] else: @@ -59,6 +67,8 @@ def agent_model( result = _utils.Either(right={k: self.custom_result_args}) else: result = _utils.Either(right=self.custom_result_args) + elif allow_text_result: + result = _utils.Either(left=None) elif result_tool is not None: result = _utils.Either(right=None) else: diff --git a/tests/test_agent.py b/tests/test_agent.py index 4740cf343..f461d3ed5 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -1,3 +1,5 @@ +from typing import Union + from inline_snapshot import snapshot from pydantic import BaseModel @@ -13,6 +15,7 @@ UserPrompt, ) from pydantic_ai.models.function import AgentInfo, FunctionModel +from pydantic_ai.models.test import TestModel from tests.conftest import IsNow @@ -151,3 +154,144 @@ def return_tuple(_: list[Message], info: AgentInfo) -> LLMMessage: ), ] ) + + +def test_response_tuple(): + m = TestModel() + + agent = Agent(m, deps=None, result_type=tuple[str, str]) + assert agent._result_schema.allow_text_result is False # pyright: ignore[reportPrivateUsage,reportOptionalMemberAccess] + + result = agent.run_sync('Hello') + assert result.response == snapshot(('b', 'b')) + + assert m.agent_model_retrievers == snapshot({}) + assert m.agent_model_allow_text_result is False + + assert m.agent_model_result_tool is not None + + # to match the protocol, we just extract the attributes we care about + fields = 'name', 'description', 'json_schema', 'outer_typed_dict_key' + agent_model_result_tool = {f: getattr(m.agent_model_result_tool, f) for f in fields} + assert agent_model_result_tool == snapshot( + { + 'name': 'final_result', + 'description': 'The final response which ends this conversation', + 'json_schema': { + 'properties': { + 'response': { + 'maxItems': 2, + 'minItems': 2, + 'prefixItems': [{'type': 'string'}, {'type': 'string'}], + 'title': 'Response', + 'type': 'array', + } + }, + 'required': ['response'], + 'type': 'object', + }, + 'outer_typed_dict_key': 'response', + } + ) + + +def test_response_union_allow_str(): + m = TestModel() + agent: Agent[None, Union[str, Foo]] = Agent( + m, + result_type=Union[str, Foo], # pyright: ignore[reportArgumentType] + ) + assert agent._result_schema.allow_text_result is True # pyright: ignore[reportPrivateUsage,reportOptionalMemberAccess] + + result = agent.run_sync('Hello') + assert result.response == snapshot('{}') + + assert m.agent_model_retrievers == snapshot({}) + assert m.agent_model_allow_text_result is True + + assert m.agent_model_result_tool is not None + + # to match the protocol, we just extract the attributes we care about + fields = 'name', 'description', 'json_schema', 'outer_typed_dict_key' + agent_model_result_tool = {f: getattr(m.agent_model_result_tool, f) for f in fields} + assert agent_model_result_tool == snapshot( + { + 'name': 'final_result', + 'description': 'The final response which ends this conversation', + 'json_schema': { + '$defs': { + 'Foo': { + 'properties': { + 'a': {'title': 'A', 'type': 'integer'}, + 'b': {'title': 'B', 'type': 'string'}, + }, + 'required': ['a', 'b'], + 'title': 'Foo', + 'type': 'object', + } + }, + 'properties': {'response': {'$ref': '#/$defs/Foo'}}, + 'required': ['response'], + 'type': 'object', + }, + 'outer_typed_dict_key': 'response', + } + ) + + +class Bar(BaseModel): + b: str + + +def test_response_multiple_return_tools(): + m = TestModel() + agent: Agent[None, Union[Foo, Bar]] = Agent( + m, + result_type=Union[Foo, Bar], # pyright: ignore[reportArgumentType] + ) + + result = agent.run_sync('Hello') + assert result.response == Foo(a=1, b='b') + + assert m.agent_model_retrievers == snapshot({}) + assert m.agent_model_allow_text_result is False + + assert m.agent_model_result_tool is not None + + # to match the protocol, we just extract the attributes we care about + fields = 'name', 'description', 'json_schema', 'outer_typed_dict_key' + agent_model_result_tool = {f: getattr(m.agent_model_result_tool, f) for f in fields} + assert agent_model_result_tool == snapshot( + { + 'name': 'final_result', + 'description': 'The final response which ends this conversation', + 'json_schema': { + '$defs': { + 'Bar': { + 'properties': {'b': {'title': 'B', 'type': 'string'}}, + 'required': ['b'], + 'title': 'Bar', + 'type': 'object', + }, + 'Foo': { + 'properties': { + 'a': {'title': 'A', 'type': 'integer'}, + 'b': {'title': 'B', 'type': 'string'}, + }, + 'required': ['a', 'b'], + 'title': 'Foo', + 'type': 'object', + }, + }, + 'properties': { + 'response': { + 'anyOf': [{'$ref': '#/$defs/Foo'}, {'$ref': '#/$defs/Bar'}], + 'title': 'Response', + } + }, + 'required': ['response'], + 'type': 'object', + }, + 'outer_typed_dict_key': 'response', + } + ) diff --git a/tests/typed_agent.py b/tests/typed_agent.py index a60fa2d59..84b0c01b2 100644 --- a/tests/typed_agent.py +++ b/tests/typed_agent.py @@ -3,6 +3,7 @@ from collections.abc import Iterator from contextlib import contextmanager from dataclasses import dataclass +from typing import Union from pydantic_ai import Agent, CallContext @@ -80,3 +81,27 @@ async def ok_retriever2(ctx: CallContext[MyDeps], x: str) -> str: if never(): typed_agent2.run_sync('testing', model='openai:gpt-4o', deps=MyDeps(foo=1, bar=2)) typed_agent2.run_sync('testing', deps=123) # type: ignore[arg-type] + + +@dataclass +class Foo: + a: int + + +@dataclass +class Bar: + b: str + + +union_agent: Agent[None, Union[Foo, Bar]] = Agent( + result_type=Union[Foo, Bar], # type: ignore[arg-type] +) + + +def foo_result(response: Union[Foo, Bar]) -> str: + return str(response) + + +if never(): + result = union_agent.run_sync('testing') + foo_result(result.response)