From 21cd229103f9f3fdb5df399b32db7090f63941d5 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Fri, 8 Nov 2024 19:45:10 +0000 Subject: [PATCH] improve coverage --- pydantic_ai/_utils.py | 23 +++++- pydantic_ai/models/function.py | 123 ++++++++++++++++++++++++---- pydantic_ai/models/openai.py | 15 +--- pydantic_ai/models/test.py | 10 +-- tests/models/test_model_function.py | 49 ++++++++++- tests/models/test_openai.py | 12 ++- tests/test_agent.py | 22 ++++- tests/test_streaming.py | 83 ++++++++++++++++++- 8 files changed, 291 insertions(+), 46 deletions(-) diff --git a/pydantic_ai/_utils.py b/pydantic_ai/_utils.py index d870174a8..4530a7d47 100644 --- a/pydantic_ai/_utils.py +++ b/pydantic_ai/_utils.py @@ -2,7 +2,7 @@ import asyncio import time -from collections.abc import AsyncIterator +from collections.abc import AsyncIterator, Iterator from dataclasses import dataclass, is_dataclass from functools import partial from types import GenericAlias @@ -183,3 +183,24 @@ async def group_by_temporal(aiter: AsyncIterator[T], soft_max_interval: float | yield buffer buffer = [] group_start_time = None + + +def add_optional(a: str | None, b: str | None) -> str | None: + """Add two optional strings.""" + if a is None: + return b + elif b is None: + return a + else: + return a + b + + +def sync_anext(iterator: Iterator[T]) -> T: + """Get the next item from a sync iterator, raising `StopAsyncIteration` if it's exhausted. + + Useful when iterating over a sync iterator in an async context. + """ + try: + return next(iterator) + except StopIteration as e: + raise StopAsyncIteration() from e diff --git a/pydantic_ai/models/function.py b/pydantic_ai/models/function.py index 2ba3a6aea..18b137da8 100644 --- a/pydantic_ai/models/function.py +++ b/pydantic_ai/models/function.py @@ -1,17 +1,22 @@ from __future__ import annotations as _annotations -from collections.abc import Mapping, Sequence +from collections.abc import Iterable, Iterator, Mapping, Sequence from dataclasses import dataclass -from typing import TYPE_CHECKING, Callable +from itertools import chain +from typing import Callable, cast -from typing_extensions import TypeAlias +from typing_extensions import TypeAlias, overload -from .. import result -from ..messages import LLMMessage, Message -from . import AbstractToolDefinition, AgentModel, Model - -if TYPE_CHECKING: - from .._utils import ObjectJsonSchema +from .. import _utils, result +from ..messages import LLMMessage, LLMToolCalls, Message, ToolCall +from . import ( + AbstractToolDefinition, + AgentModel, + EitherStreamedResponse, + Model, + StreamTextResponse, + StreamToolCallResponse, +) @dataclass(frozen=True) @@ -23,22 +28,44 @@ class AgentInfo: result_tools: list[AbstractToolDefinition] | None +@dataclass +class DeltaToolCall: + name: str | None = None + args: str | None = None + + +DeltaToolCalls = dict[int, DeltaToolCall] + FunctionDef: TypeAlias = Callable[[list[Message], AgentInfo], LLMMessage] +StreamFunctionDef: TypeAlias = Callable[[list[Message], AgentInfo], Iterable[str] | Iterable[DeltaToolCalls]] @dataclass class ToolDescription: name: str description: str - json_schema: ObjectJsonSchema + json_schema: _utils.ObjectJsonSchema -@dataclass +@dataclass(init=False) class FunctionModel(Model): # NOTE: Avoid test discovery by pytest. __test__ = False - function: FunctionDef + function: FunctionDef | None = None + stream_function: StreamFunctionDef | None = None + + @overload + def __init__(self, function: FunctionDef) -> None: ... + + @overload + def __init__(self, *, stream_function: StreamFunctionDef) -> None: ... + + def __init__(self, function: FunctionDef | None = None, *, stream_function: StreamFunctionDef | None = None): + if function is None and stream_function is None: + raise TypeError('Either `function` or `stream_function` must be provided') + self.function = function + self.stream_function = stream_function def agent_model( self, @@ -47,16 +74,82 @@ def agent_model( result_tools: Sequence[AbstractToolDefinition] | None, ) -> AgentModel: result_tools = list(result_tools) if result_tools is not None else None - return FunctionAgentModel(self.function, AgentInfo(retrievers, allow_text_result, result_tools)) + return FunctionAgentModel( + self.function, self.stream_function, AgentInfo(retrievers, allow_text_result, result_tools) + ) def name(self) -> str: - return f'function:{self.function.__name__}' + labels: list[str] = [] + if self.function is not None: + labels.append(self.function.__name__) + if self.stream_function is not None: + labels.append(f'stream-{self.stream_function.__name__}') + return f'function:{",".join(labels)}' @dataclass class FunctionAgentModel(AgentModel): - function: FunctionDef + function: FunctionDef | None + stream_function: StreamFunctionDef | None agent_info: AgentInfo async def request(self, messages: list[Message]) -> tuple[LLMMessage, result.Cost]: + assert self.function is not None, 'FunctionModel must receive a `function` to support non-streamed requests' return self.function(messages, self.agent_info), result.Cost() + + async def request_stream(self, messages: list[Message]) -> EitherStreamedResponse: + assert ( + self.stream_function is not None + ), 'FunctionModel must receive a `stream_function` to support streamed requests' + response_data = iter(self.stream_function(messages, self.agent_info)) + try: + first = next(response_data) + except StopIteration as e: + raise ValueError('Stream function must return at least one item') from e + + if isinstance(first, str): + text_stream = cast(Iterable[str], response_data) + return FunctionStreamTextResponse(iter(chain([first], text_stream))) + else: + structured_stream = cast(Iterable[DeltaToolCalls], response_data) + # noinspection PyTypeChecker + return FunctionStreamToolCallResponse(iter(chain([first], structured_stream)), {}) + + +@dataclass +class FunctionStreamTextResponse(StreamTextResponse): + _iter: Iterator[str] + + async def __anext__(self) -> str: + return _utils.sync_anext(self._iter) + + def cost(self) -> result.Cost: + return result.Cost() + + +@dataclass +class FunctionStreamToolCallResponse(StreamToolCallResponse): + _iter: Iterator[DeltaToolCalls] + _delta_tool_calls: dict[int, DeltaToolCall] + + async def __anext__(self) -> None: + tool_call = _utils.sync_anext(self._iter) + + for key, new in tool_call.items(): + if current := self._delta_tool_calls.get(key): + current.name = _utils.add_optional(current.name, new.name) + current.args = _utils.add_optional(current.args, new.args) + else: + self._delta_tool_calls[key] = new + + def get(self) -> LLMToolCalls: + """Map tool call deltas to a `LLMToolCalls`.""" + calls: list[ToolCall] = [] + for c in self._delta_tool_calls.values(): + if c.name is not None and c.args is not None: + calls.append(ToolCall.from_json(c.name, c.args)) + + return LLMToolCalls(calls) + + def cost(self) -> result.Cost: + return result.Cost() diff --git a/pydantic_ai/models/openai.py b/pydantic_ai/models/openai.py index 617ee6284..3515fcc6d 100644 --- a/pydantic_ai/models/openai.py +++ b/pydantic_ai/models/openai.py @@ -12,7 +12,7 @@ from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall from typing_extensions import assert_never -from .. import UnexpectedModelBehaviour, result +from .. import UnexpectedModelBehaviour, _utils, result from ..messages import ( ArgsJson, LLMMessage, @@ -273,8 +273,8 @@ async def __anext__(self) -> None: if current.function is None: current.function = new.function elif new.function is not None: - current.function.name = _add_optional(current.function.name, new.function.name) - current.function.arguments = _add_optional(current.function.arguments, new.function.arguments) + current.function.name = _utils.add_optional(current.function.name, new.function.name) + current.function.arguments = _utils.add_optional(current.function.arguments, new.function.arguments) else: self._delta_tool_calls[new.index] = new @@ -292,15 +292,6 @@ def cost(self) -> Cost: return self._cost -def _add_optional(a: str | None, b: str | None) -> str | None: - if a is None: - return b - elif b is None: - return a - else: - return a + b - - def _guard_tool_id(t: ToolCall | ToolReturn | RetryPrompt) -> str: """Type guard that checks a `tool_id` is not None both for static typing and runtime.""" assert t.tool_id is not None, f'OpenAI requires `tool_id` to be set: {t}' diff --git a/pydantic_ai/models/test.py b/pydantic_ai/models/test.py index 20fb1553b..3f1b0c624 100644 --- a/pydantic_ai/models/test.py +++ b/pydantic_ai/models/test.py @@ -184,10 +184,7 @@ def __post_init__(self): self._iter = iter(words) async def __anext__(self) -> str: - try: - return next(self._iter) - except StopIteration as e: - raise StopAsyncIteration() from e + return _utils.sync_anext(self._iter) def cost(self) -> Cost: return self._cost @@ -200,10 +197,7 @@ class TestStreamToolCallResponse(StreamToolCallResponse): _iter: Iterator[None] = field(default_factory=lambda: iter([None])) async def __anext__(self) -> None: - try: - return next(self._iter) - except StopIteration as e: - raise StopAsyncIteration() from e + return _utils.sync_anext(self._iter) def get(self) -> LLMToolCalls: return self._structured_response diff --git a/tests/models/test_model_function.py b/tests/models/test_model_function.py index 36d141f78..ec02457ab 100644 --- a/tests/models/test_model_function.py +++ b/tests/models/test_model_function.py @@ -1,5 +1,6 @@ import json import re +from collections.abc import Iterable from dataclasses import asdict from datetime import timezone @@ -20,10 +21,13 @@ ToolReturn, UserPrompt, ) -from pydantic_ai.models.function import AgentInfo, FunctionModel +from pydantic_ai.models.function import AgentInfo, DeltaToolCall, DeltaToolCalls, FunctionModel from pydantic_ai.models.test import TestModel +from pydantic_ai.result import Cost from tests.conftest import IsNow +pytestmark = pytest.mark.anyio + def return_last(messages: list[Message], _: AgentInfo) -> LLMMessage: last = messages[-1] @@ -393,3 +397,46 @@ async def validate_result(r: Foo) -> Foo: result = agent.run_sync('') assert result.response == snapshot(Foo(x=2)) + + +def stream_text_function(_messages: list[Message], _: AgentInfo) -> Iterable[str]: + yield 'hello ' + yield 'world' + + +async def test_stream_text(): + agent = Agent(FunctionModel(stream_function=stream_text_function), deps=None) + result = await agent.run_stream('') + assert await result.get_response() == snapshot('hello world') + assert result.all_messages() == snapshot([UserPrompt(content='', timestamp=IsNow(tz=timezone.utc))]) + assert result.cost() == snapshot(Cost()) + + +class Foo(BaseModel): + x: int + + +async def test_stream_structure(): + def stream_structured_function(_messages: list[Message], agent_info: AgentInfo) -> Iterable[DeltaToolCalls]: + assert agent_info.result_tools is not None + assert len(agent_info.result_tools) == 1 + name = agent_info.result_tools[0].name + yield {0: DeltaToolCall(name=name)} + yield {0: DeltaToolCall(args='{"x": ')} + yield {0: DeltaToolCall(args='1}')} + + agent = Agent(FunctionModel(stream_function=stream_structured_function), deps=None, result_type=Foo) + result = await agent.run_stream('') + assert await result.get_response() == snapshot(Foo(x=1)) + assert result.cost() == snapshot(Cost()) + + +async def test_pass_neither(): + with pytest.raises(TypeError, match='Either `function` or `stream_function` must be provided'): + FunctionModel() # pyright: ignore[reportCallIssue] + + +async def test_return_empty(): + agent = Agent(FunctionModel(stream_function=lambda _, __: []), deps=None) + with pytest.raises(ValueError, match='Stream function must return at least one item'): + await agent.run_stream('') diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py index bd6c5a77a..358714bf0 100644 --- a/tests/models/test_openai.py +++ b/tests/models/test_openai.py @@ -23,7 +23,7 @@ from openai.types.completion_usage import CompletionUsage, PromptTokensDetails from typing_extensions import TypedDict -from pydantic_ai import Agent, AgentError, ModelRetry +from pydantic_ai import Agent, AgentError, ModelRetry, _utils from pydantic_ai.messages import ( ArgsJson, LLMResponse, @@ -52,10 +52,7 @@ class MockAsyncStream: _iter: Iterator[chat.ChatCompletionChunk] async def __anext__(self) -> chat.ChatCompletionChunk: - try: - return next(self._iter) - except StopIteration as e: - raise StopAsyncIteration() from e + return _utils.sync_anext(self._iter) @dataclass @@ -87,13 +84,14 @@ async def chat_completions_create( self, *_args: Any, stream: bool = False, **_kwargs: Any ) -> chat.ChatCompletion | MockAsyncStream: if stream: - assert self.stream is not None + assert self.stream is not None, 'you can only used `stream=True` if `stream` is provided' + # noinspection PyUnresolvedReferences if isinstance(self.stream[0], list): response = MockAsyncStream(iter(self.stream[self.index])) # type: ignore else: response = MockAsyncStream(iter(self.stream)) # type: ignore else: - assert self.completions is not None + assert self.completions is not None, 'you can only used `stream=False` if `completions` are provided' if isinstance(self.completions, list): response = self.completions[self.index] else: diff --git a/tests/test_agent.py b/tests/test_agent.py index 1b0514c7e..1bb555f9c 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -6,7 +6,7 @@ from inline_snapshot import snapshot from pydantic import BaseModel -from pydantic_ai import Agent, CallContext, ModelRetry +from pydantic_ai import Agent, AgentError, CallContext, ModelRetry from pydantic_ai.messages import ( ArgsJson, ArgsObject, @@ -437,3 +437,23 @@ async def ret_a(x: str) -> str: _cost=Cost(), ) ) + + +def test_empty_tool_calls(): + def empty(_: list[Message], _info: AgentInfo) -> LLMMessage: + return LLMToolCalls(calls=[]) + + agent = Agent(FunctionModel(empty), deps=None) + + with pytest.raises(AgentError, match='caused by unexpected model behavior: Received empty tool call message'): + agent.run_sync('Hello') + + +def test_unknown_retriever(): + def empty(_: list[Message], _info: AgentInfo) -> LLMMessage: + return LLMToolCalls(calls=[ToolCall.from_json('foobar', '{}')]) + + agent = Agent(FunctionModel(empty), deps=None) + + with pytest.raises(AgentError, match="caused by unexpected model behavior: Unknown function name: 'foobar'"): + agent.run_sync('Hello') diff --git a/tests/test_streaming.py b/tests/test_streaming.py index 4e71e8178..8807b84a8 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -1,7 +1,12 @@ +import json +from collections.abc import Iterable + import pytest from inline_snapshot import snapshot -from pydantic_ai import Agent +from pydantic_ai import Agent, AgentError +from pydantic_ai.messages import Message, ToolReturn, UserPrompt +from pydantic_ai.models.function import AgentInfo, DeltaToolCall, DeltaToolCalls, FunctionModel from pydantic_ai.models.test import TestModel pytestmark = pytest.mark.anyio @@ -68,3 +73,79 @@ async def test_streamed_text_stream(): assert [c async for c in result.stream(text_delta=True, debounce_by=None)] == snapshot( ['The ', 'cat ', 'sat ', 'on ', 'the ', 'mat.'] ) + + +async def test_plain_response(): + call_index = 0 + + def text_stream(_messages: list[Message], _: AgentInfo) -> list[str]: + nonlocal call_index + + call_index += 1 + return ['hello ', 'world'] + + agent = Agent(FunctionModel(stream_function=text_stream), deps=None, result_type=tuple[str, str]) + + with pytest.raises(AgentError) as exc_info: + await agent.run_stream('') + + assert str(exc_info.value) == snapshot( + 'Error while running model function:stream-text_stream after 2 messages\n' + ' caused by unexpected model behavior: Exceeded maximum retries (1) for result validation' + ) + + +async def test_call_retriever(): + def stream_structured_function( + messages: list[Message], agent_info: AgentInfo + ) -> Iterable[DeltaToolCalls] | Iterable[str]: + if len(messages) == 1: + assert agent_info.retrievers is not None + assert len(agent_info.retrievers) == 1 + name = next(iter(agent_info.retrievers)) + first = messages[0] + assert isinstance(first, UserPrompt) + json_string = json.dumps({'x': first.content}) + yield {0: DeltaToolCall(name=name)} + yield {0: DeltaToolCall(args=json_string[:3])} + yield {0: DeltaToolCall(args=json_string[3:])} + else: + last = messages[-1] + assert isinstance(last, ToolReturn) + assert agent_info.result_tools is not None + assert len(agent_info.result_tools) == 1 + name = agent_info.result_tools[0].name + json_data = json.dumps({'response': [last.content, 2]}) + yield {0: DeltaToolCall(name=name)} + yield {0: DeltaToolCall(args=json_data[:5])} + yield {0: DeltaToolCall(args=json_data[5:])} + + agent = Agent(FunctionModel(stream_function=stream_structured_function), deps=None, result_type=tuple[str, int]) + + @agent.retriever_plain + async def ret_a(x: str) -> str: + assert x == 'hello' + return f'{x} world' + + result = await agent.run_stream('hello') + assert await result.get_response() == snapshot(('hello world', 2)) + + +async def test_call_retriever_empty(): + def stream_structured_function(_messages: list[Message], _: AgentInfo) -> Iterable[DeltaToolCalls] | Iterable[str]: + yield {} + + agent = Agent(FunctionModel(stream_function=stream_structured_function), deps=None, result_type=tuple[str, int]) + + with pytest.raises(AgentError, match='caused by unexpected model behavior: Received empty tool call message'): + await agent.run_stream('hello') + + +async def test_call_retriever_wrong_name(): + def stream_structured_function(_messages: list[Message], _: AgentInfo) -> Iterable[DeltaToolCalls] | Iterable[str]: + yield {0: DeltaToolCall(name='foobar', args='{}')} + + agent = Agent(FunctionModel(stream_function=stream_structured_function), deps=None, result_type=tuple[str, int]) + + with pytest.raises(AgentError, match="caused by unexpected model behavior: Unknown function name: 'foobar'"): + await agent.run_stream('hello')