Allow dict[str, Any] retriever response (#14)
samuelcolvin authored Oct 20, 2024
1 parent e937a22 · commit 0975afb
Showing 13 changed files with 105 additions and 39 deletions.
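In short, the change lets a retriever hand back structured data directly instead of JSON-encoding it by hand. A minimal sketch of the widened return contract, based on the demos/weather.py change below (the surrounding Agent setup from the demo is omitted; only the signatures matter):

import json

# Before: only str (or Awaitable[str]) was accepted, so dicts had to be encoded by hand.
async def get_location(location_description: str) -> str:
    return json.dumps({'lat': 51.1, 'lng': -0.1})

# After: dict[str, Any] (or Awaitable[dict[str, Any]]) is accepted as well,
# per the widened RetrieverReturnValue alias in pydantic_ai/_retriever.py.
async def get_lat_lng(location_description: str) -> dict[str, float]:
    return {'lat': 51.1, 'lng': -0.1}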
2 changes: 2 additions & 0 deletions .github/workflows/ci.yml
@@ -31,6 +31,8 @@ jobs:
env:
SKIP: no-commit-to-branch

- run: make typecheck-mypy

test:
name: test on ${{ matrix.python-version }}
runs-on: ubuntu-latest
26 changes: 18 additions & 8 deletions demos/weather.py
@@ -1,5 +1,3 @@
import json

from devtools import debug

from pydantic_ai import Agent
@@ -15,18 +13,30 @@


@weather_agent.retriever_plain
async def get_location(location_description: str) -> str:
async def get_lat_lng(location_description: str) -> dict[str, float]:
"""
Get the latitude and longitude of a location.
Args:
location_description: A description of a location.
"""
if 'london' in location_description.lower():
lat_lng = {'lat': 51.1, 'lng': -0.1}
return {'lat': 51.1, 'lng': -0.1}
elif 'wiltshire' in location_description.lower():
lat_lng = {'lat': 51.1, 'lng': -2.11}
return {'lat': 51.1, 'lng': -2.11}
else:
lat_lng = {'lat': 0, 'lng': 0}
return json.dumps(lat_lng)
return {'lat': 0, 'lng': 0}


@weather_agent.retriever_plain
async def get_whether(lat: float, lng: float):
async def get_whether(lat: float, lng: float) -> str:
"""
Get the weather at a location.
Args:
lat: Latitude of the location.
lng: Longitude of the location.
"""
if abs(lat - 51.1) < 0.1 and abs(lng + 0.1) < 0.1:
# it always rains in London
return 'Raining'
22 changes: 20 additions & 2 deletions pydantic_ai/_pydantic.py
@@ -10,9 +10,9 @@

from _griffe.enumerations import DocstringSectionKind
from _griffe.models import Docstring, Object as GriffeObject
from pydantic import ConfigDict, TypeAdapter
from pydantic._internal import _decorators, _generate_schema, _typing_extra
from pydantic._internal._config import ConfigWrapper
from pydantic.config import ConfigDict
from pydantic.fields import FieldInfo
from pydantic.json_schema import GenerateJsonSchema
from pydantic.plugin._schema_validator import create_schema_validator
@@ -25,7 +25,7 @@
from .shared import AgentDeps


__all__ = ('function_schema',)
__all__ = 'function_schema', 'LazyTypeAdapter'


class FunctionSchema(TypedDict):
@@ -233,3 +233,21 @@ def _is_call_ctx(annotation: Any) -> bool:
return annotation is CallContext or (
_typing_extra.is_generic_alias(annotation) and get_origin(annotation) is CallContext
)


if TYPE_CHECKING:
LazyTypeAdapter = TypeAdapter
else:

class LazyTypeAdapter:
__slots__ = '_args', '_kwargs', '_type_adapter'

def __init__(self, *args, **kwargs):
self._args = args
self._kwargs = kwargs
self._type_adapter = None

def __getattr__(self, item):
if self._type_adapter is None:
self._type_adapter = TypeAdapter(*self._args, **self._kwargs)
return getattr(self._type_adapter, item)
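The LazyTypeAdapter above defers building the underlying pydantic TypeAdapter (and its schema) until an attribute is first accessed, presumably to avoid paying that cost at import time for module-level adapters. A small usage sketch, assuming only the names exported above:

from typing import Any

from pydantic_ai._pydantic import LazyTypeAdapter

adapter = LazyTypeAdapter(dict[str, Any])  # no TypeAdapter or schema built yet
adapter.validate_python({'lat': 51.1})     # first access builds TypeAdapter(dict[str, Any]) and delegates to it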
5 changes: 3 additions & 2 deletions pydantic_ai/_retriever.py
@@ -17,10 +17,11 @@
P = ParamSpec('P')


RetrieverReturnValue = Union[str, Awaitable[str], dict[str, Any], Awaitable[dict[str, Any]]]
# Usage `RetrieverContextFunc[AgentDependencies, P]`
RetrieverContextFunc = Callable[Concatenate[CallContext[AgentDeps], P], Union[str, Awaitable[str]]]
RetrieverContextFunc = Callable[Concatenate[CallContext[AgentDeps], P], RetrieverReturnValue]
# Usage `RetrieverPlainFunc[P]`
RetrieverPlainFunc = Callable[P, Union[str, Awaitable[str]]]
RetrieverPlainFunc = Callable[P, RetrieverReturnValue]
# Usage `RetrieverEitherFunc[AgentDependencies, P]`
RetrieverEitherFunc = _utils.Either[RetrieverContextFunc[AgentDeps, P], RetrieverPlainFunc[P]]

23 changes: 19 additions & 4 deletions pydantic_ai/messages.py
@@ -8,6 +8,8 @@
import pydantic
import pydantic_core

from . import _pydantic


@dataclass
class SystemPrompt:
@@ -22,16 +24,29 @@ class UserPrompt:
role: Literal['user'] = 'user'


return_value_object = _pydantic.LazyTypeAdapter(dict[str, Any])


@dataclass
class ToolReturn:
tool_name: str
content: str
content: str | dict[str, Any]
tool_id: str | None = None
timestamp: datetime = field(default_factory=datetime.now)
role: Literal['tool-return'] = 'tool-return'

def llm_response(self) -> str:
return self.content
def model_response_str(self) -> str:
if isinstance(self.content, str):
return self.content
else:
content = return_value_object.validate_python(self.content)
return return_value_object.dump_json(content).decode()

def model_response_object(self) -> dict[str, Any]:
if isinstance(self.content, str):
return {'return_value': self.content}
else:
return return_value_object.validate_python(self.content)


@dataclass
@@ -42,7 +57,7 @@ class ToolRetry:
timestamp: datetime = field(default_factory=datetime.now)
role: Literal['tool-retry'] = 'tool-retry'

def llm_response(self) -> str:
def model_response(self) -> str:
if isinstance(self.content, str):
description = self.content
else:
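To illustrate the new ToolReturn methods above: string content passes through unchanged, while dict content is validated against dict[str, Any] and either dumped to compact JSON (for string-only tool messages) or passed through as an object. A quick sketch using only names visible in this diff:

from pydantic_ai.messages import ToolReturn

ret = ToolReturn(tool_name='get_lat_lng', content={'lat': 51.1, 'lng': -0.1})
ret.model_response_str()     # '{"lat":51.1,"lng":-0.1}' (used by the OpenAI mapping below)
ret.model_response_object()  # {'lat': 51.1, 'lng': -0.1} (used by the Gemini function response)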
18 changes: 10 additions & 8 deletions pydantic_ai/models/gemini.py
@@ -26,10 +26,10 @@
from typing import Annotated, Any, Literal, Union, cast

from httpx import AsyncClient as AsyncHTTPClient
from pydantic import Field, TypeAdapter
from pydantic import Field
from typing_extensions import assert_never

from .. import _utils, shared
from .. import _pydantic, _utils, shared
from ..messages import (
ArgsObject,
LLMMessage,
@@ -229,14 +229,12 @@ def function_call(cls, m: LLMToolCalls) -> _GeminiContent:

@classmethod
def function_return(cls, m: ToolReturn) -> _GeminiContent:
# TODO non string responses
response = {'return_value': m.llm_response()}
f_response = _GeminiFunctionResponsePart.from_response(m.tool_name, response)
f_response = _GeminiFunctionResponsePart.from_response(m.tool_name, m.model_response_object())
return _GeminiContent(role='user', parts=[f_response])

@classmethod
def function_retry(cls, m: ToolRetry) -> _GeminiContent:
response = {'call_error': m.llm_response()}
response = {'call_error': m.model_response()}
f_response = _GeminiFunctionResponsePart.from_response(m.tool_name, response)
return _GeminiContent(role='user', parts=[f_response])

@@ -390,8 +388,8 @@ class _GeminiPromptFeedback:
safety_ratings: Annotated[list[_GeminiSafetyRating], Field(alias='safetyRatings')]


_gemini_request_ta = TypeAdapter(_GeminiRequest)
_gemini_response_ta = TypeAdapter(_GeminiResponse)
_gemini_request_ta = _pydantic.LazyTypeAdapter(_GeminiRequest)
_gemini_response_ta = _pydantic.LazyTypeAdapter(_GeminiResponse)


class _GeminiJsonSchema:
@@ -436,6 +434,10 @@ def _simplify(self, schema: dict[str, Any], allow_ref: bool) -> None:
return self._array(schema, allow_ref)

def _object(self, schema: dict[str, Any], allow_ref: bool) -> None:
ad_props = schema.pop('additionalProperties', None)
if ad_props:
raise shared.UserError('Additional properties in JSON Schema are not supported by Gemini')

if properties := schema.get('properties'): # pragma: no branch
for value in properties.values():
self._simplify(value, allow_ref)
13 changes: 10 additions & 3 deletions pydantic_ai/models/openai.py
@@ -128,12 +128,19 @@ def map_message(message: Message) -> chat.ChatCompletionMessageParam:
elif message.role == 'user':
# UserPrompt ->
return chat.ChatCompletionUserMessageParam(role='user', content=message.content)
elif message.role == 'tool-return' or message.role == 'tool-retry':
# ToolReturn or ToolRetry ->
elif message.role == 'tool-return':
# ToolReturn ->
return chat.ChatCompletionToolMessageParam(
role='tool',
tool_call_id=_guard_tool_id(message),
content=message.llm_response(),
content=message.model_response_str(),
)
elif message.role == 'tool-retry':
# ToolRetry ->
return chat.ChatCompletionToolMessageParam(
role='tool',
tool_call_id=_guard_tool_id(message),
content=message.model_response(),
)
elif message.role == 'llm-response':
# LLMResponse ->
7 changes: 4 additions & 3 deletions pydantic_ai/models/test.py
@@ -5,13 +5,14 @@

from __future__ import annotations as _annotations

import json
import re
import string
from collections.abc import Mapping
from dataclasses import dataclass
from typing import Any, Literal

import pydantic_core

from .. import _utils
from ..messages import LLMMessage, LLMResponse, LLMToolCalls, Message, ToolCall, ToolRetry, ToolReturn
from . import AbstractToolDefinition, AgentModel, Model
@@ -103,11 +104,11 @@ async def request(self, messages: list[Message]) -> LLMMessage:
self.step += 1
if response_text.value is None:
# build up details of retriever responses
output: dict[str, str] = {}
output: dict[str, Any] = {}
for message in messages:
if isinstance(message, ToolReturn):
output[message.tool_name] = message.content
return LLMResponse(content=json.dumps(output))
return LLMResponse(content=pydantic_core.to_json(output).decode())
else:
return LLMResponse(content=response_text.value)
else:
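A side note on why the expected strings in the tests below lose their spaces: pydantic_core.to_json uses compact separators, unlike json.dumps with its defaults. For example:

import json

import pydantic_core

json.dumps({'my_ret': '2'})                      # '{"my_ret": "2"}'
pydantic_core.to_json({'my_ret': '2'}).decode()  # '{"my_ret":"2"}'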
6 changes: 3 additions & 3 deletions tests/models/test_model_function.py
@@ -91,7 +91,7 @@ def whether_model(messages: list[Message], info: AgentInfo) -> LLMMessage: # pr
)
elif last.role == 'tool-return':
if last.tool_name == 'get_location':
return LLMToolCalls(calls=[ToolCall.from_json('get_whether', last.content)])
return LLMToolCalls(calls=[ToolCall.from_json('get_whether', last.model_response_str())])
elif last.tool_name == 'get_whether':
location_name = next(m.content for m in messages if m.role == 'user')
return LLMResponse(f'{last.content} in {location_name}')
@@ -312,7 +312,7 @@ def f(messages: list[Message], info: AgentInfo) -> LLMMessage:

def test_call_all():
result = agent_all.run_sync('Hello', model=TestModel())
assert result.response == snapshot('{"foo": "1", "bar": "2", "baz": "3", "qux": "4", "quz": "a"}')
assert result.response == snapshot('{"foo":"1","bar":"2","baz":"3","qux":"4","quz":"a"}')
assert result.message_history == snapshot(
[
SystemPrompt(content='foobar'),
@@ -332,6 +332,6 @@ def test_call_all():
ToolReturn(tool_name='baz', content='3', timestamp=IsNow()),
ToolReturn(tool_name='qux', content='4', timestamp=IsNow()),
ToolReturn(tool_name='quz', content='a', timestamp=IsNow()),
LLMResponse(content='{"foo": "1", "bar": "2", "baz": "3", "qux": "4", "quz": "a"}', timestamp=IsNow()),
LLMResponse(content='{"foo":"1","bar":"2","baz":"3","qux":"4","quz":"a"}', timestamp=IsNow()),
]
)
6 changes: 3 additions & 3 deletions tests/models/test_model_test.py
@@ -32,7 +32,7 @@ async def ret_b(x: str) -> str:
return f'{x}-b'

result = agent.run_sync('x', model=TestModel(call_retrievers=['ret_a']))
assert result.response == snapshot('{"ret_a": "a-a"}')
assert result.response == snapshot('{"ret_a":"a-a"}')
assert calls == ['a']


@@ -82,7 +82,7 @@ async def my_ret(x: int) -> str:

result = agent.run_sync('Hello', model=TestModel())
assert call_count == 2
assert result.response == snapshot('{"my_ret": "2"}')
assert result.response == snapshot('{"my_ret":"2"}')
assert result.message_history == snapshot(
[
UserPrompt(content='Hello', timestamp=IsNow()),
@@ -93,7 +93,7 @@ async def my_ret(x: int) -> str:
ToolRetry(tool_name='my_ret', content='First call failed', timestamp=IsNow()),
LLMToolCalls(calls=[ToolCall.from_object('my_ret', {'x': 1})], timestamp=IsNow()),
ToolReturn(tool_name='my_ret', content='2', timestamp=IsNow()),
LLMResponse(content='{"my_ret": "2"}', timestamp=IsNow()),
LLMResponse(content='{"my_ret":"2"}', timestamp=IsNow()),
]
)

2 changes: 1 addition & 1 deletion tests/test_logfire.py
@@ -66,7 +66,7 @@ async def my_ret(x: int) -> str:
return str(x + 1)

result = agent.run_sync('Hello')
assert result.response == snapshot('{"my_ret": "1"}')
assert result.response == snapshot('{"my_ret":"1"}')

summary = get_logfire_summary()
assert summary.traces == snapshot(
Expand Down
4 changes: 2 additions & 2 deletions tests/test_retrievers.py
@@ -239,7 +239,7 @@ def takes_just_model(model: Foo) -> str:
)

result = agent.run_sync('', model=TestModel())
assert result.response == snapshot('{"takes_just_model": "0 a"}')
assert result.response == snapshot('{"takes_just_model":"0 a"}')


def test_takes_model_and_int():
@@ -279,4 +279,4 @@ def takes_just_model(model: Foo, z: int) -> str:
)

result = agent.run_sync('', model=TestModel())
assert result.response == snapshot('{"takes_just_model": "0 a 0"}')
assert result.response == snapshot('{"takes_just_model":"0 a 0"}')
10 changes: 10 additions & 0 deletions tests/typed_agent.py
@@ -38,6 +38,11 @@ async def ok_retriever(ctx: CallContext[MyDeps], x: str) -> str:
return f'{x} {total}'


@typed_agent1.retriever_plain
def ok_retriever_plain(x: str) -> dict[str, str]:
return {'x': x}


@typed_agent1.retriever_context
async def bad_retriever1(ctx: CallContext[MyDeps], x: str) -> str:
total = ctx.deps.foo + ctx.deps.spam # type: ignore[attr-defined]
@@ -49,6 +54,11 @@ async def bad_retriever2(ctx: CallContext[int], x: str) -> str:
return f'{x} {ctx.deps}'


@typed_agent1.retriever_plain # type: ignore[arg-type]
async def bad_retriever_return(x: int) -> list[int]:
return [x]


with expect_error(ValueError):

@typed_agent1.retriever_context # type: ignore[arg-type]
