improve coverage
samuelcolvin committed Nov 8, 2024
1 parent 240778c commit 21cd229
Showing 8 changed files with 291 additions and 46 deletions.
23 changes: 22 additions & 1 deletion pydantic_ai/_utils.py
@@ -2,7 +2,7 @@

import asyncio
import time
from collections.abc import AsyncIterator
from collections.abc import AsyncIterator, Iterator
from dataclasses import dataclass, is_dataclass
from functools import partial
from types import GenericAlias
@@ -183,3 +183,24 @@ async def group_by_temporal(aiter: AsyncIterator[T], soft_max_interval: float |
yield buffer
buffer = []
group_start_time = None


def add_optional(a: str | None, b: str | None) -> str | None:
"""Add two optional strings."""
if a is None:
return b
elif b is None:
return a
else:
return a + b


def sync_anext(iterator: Iterator[T]) -> T:
"""Get the next item from a sync iterator, raising `StopAsyncIteration` if it's exhausted.
Useful when iterating over a sync iterator in an async context.
"""
try:
return next(iterator)
except StopIteration as e:
raise StopAsyncIteration() from e
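
For context, a minimal, illustrative sketch of how the two new helpers behave; the asyncio.run driver below is not part of the commit and is only there to make the snippet self-contained:

import asyncio

from pydantic_ai import _utils

# add_optional concatenates two optional strings, treating None as "absent".
assert _utils.add_optional(None, 'world') == 'world'
assert _utils.add_optional('hello ', 'world') == 'hello world'


async def drain(iterator):
    # sync_anext converts StopIteration into StopAsyncIteration, so a plain
    # sync iterator can back an async __anext__ implementation.
    items = []
    while True:
        try:
            items.append(_utils.sync_anext(iterator))
        except StopAsyncIteration:
            break
    return items


assert asyncio.run(drain(iter(['a', 'b']))) == ['a', 'b']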
123 changes: 108 additions & 15 deletions pydantic_ai/models/function.py
@@ -1,17 +1,22 @@
from __future__ import annotations as _annotations

from collections.abc import Mapping, Sequence
from collections.abc import Iterable, Iterator, Mapping, Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING, Callable
from itertools import chain
from typing import Callable, cast

from typing_extensions import TypeAlias
from typing_extensions import TypeAlias, overload

from .. import result
from ..messages import LLMMessage, Message
from . import AbstractToolDefinition, AgentModel, Model

if TYPE_CHECKING:
from .._utils import ObjectJsonSchema
from .. import _utils, result
from ..messages import LLMMessage, LLMToolCalls, Message, ToolCall
from . import (
AbstractToolDefinition,
AgentModel,
EitherStreamedResponse,
Model,
StreamTextResponse,
StreamToolCallResponse,
)


@dataclass(frozen=True)
@@ -23,22 +28,44 @@ class AgentInfo:
result_tools: list[AbstractToolDefinition] | None


@dataclass
class DeltaToolCall:
name: str | None = None
args: str | None = None


DeltaToolCalls = dict[int, DeltaToolCall]

FunctionDef: TypeAlias = Callable[[list[Message], AgentInfo], LLMMessage]
StreamFunctionDef: TypeAlias = Callable[[list[Message], AgentInfo], Iterable[str] | Iterable[DeltaToolCalls]]


@dataclass
class ToolDescription:
name: str
description: str
json_schema: ObjectJsonSchema
json_schema: _utils.ObjectJsonSchema


@dataclass
@dataclass(init=False)
class FunctionModel(Model):
# NOTE: Avoid test discovery by pytest.
__test__ = False

function: FunctionDef
function: FunctionDef | None = None
stream_function: StreamFunctionDef | None = None

@overload
def __init__(self, function: FunctionDef) -> None: ...

@overload
def __init__(self, *, stream_function: StreamFunctionDef) -> None: ...

def __init__(self, function: FunctionDef | None = None, *, stream_function: StreamFunctionDef | None = None):
if function is None and stream_function is None:
raise TypeError('Either `function` or `stream_function` must be provided')
self.function = function
self.stream_function = stream_function

def agent_model(
self,
Expand All @@ -47,16 +74,82 @@ def agent_model(
result_tools: Sequence[AbstractToolDefinition] | None,
) -> AgentModel:
result_tools = list(result_tools) if result_tools is not None else None
return FunctionAgentModel(self.function, AgentInfo(retrievers, allow_text_result, result_tools))
return FunctionAgentModel(
self.function, self.stream_function, AgentInfo(retrievers, allow_text_result, result_tools)
)

def name(self) -> str:
return f'function:{self.function.__name__}'
labels: list[str] = []
if self.function is not None:
labels.append(self.function.__name__)
if self.stream_function is not None:
labels.append(f'stream-{self.stream_function.__name__}')
return f'function:{",".join(labels)}'


@dataclass
class FunctionAgentModel(AgentModel):
function: FunctionDef
function: FunctionDef | None
stream_function: StreamFunctionDef | None
agent_info: AgentInfo

async def request(self, messages: list[Message]) -> tuple[LLMMessage, result.Cost]:
assert self.function is not None, 'FunctionModel must receive a `function` to support non-streamed requests'
return self.function(messages, self.agent_info), result.Cost()

async def request_stream(self, messages: list[Message]) -> EitherStreamedResponse:
assert (
self.stream_function is not None
), 'FunctionModel must receive a `stream_function` to support streamed requests'
response_data = iter(self.stream_function(messages, self.agent_info))
try:
first = next(response_data)
except StopIteration as e:
raise ValueError('Stream function must return at least one item') from e

if isinstance(first, str):
text_stream = cast(Iterable[str], response_data)
return FunctionStreamTextResponse(iter(chain([first], text_stream)))
else:
structured_stream = cast(Iterable[DeltaToolCalls], response_data)
# noinspection PyTypeChecker
return FunctionStreamToolCallResponse(iter(chain([first], structured_stream)), {})


@dataclass
class FunctionStreamTextResponse(StreamTextResponse):
_iter: Iterator[str]

async def __anext__(self) -> str:
return _utils.sync_anext(self._iter)

def cost(self) -> result.Cost:
return result.Cost()


@dataclass
class FunctionStreamToolCallResponse(StreamToolCallResponse):
_iter: Iterator[DeltaToolCalls]
_delta_tool_calls: dict[int, DeltaToolCall]

async def __anext__(self) -> None:
tool_call = _utils.sync_anext(self._iter)

for key, new in tool_call.items():
if current := self._delta_tool_calls.get(key):
current.name = _utils.add_optional(current.name, new.name)
current.args = _utils.add_optional(current.args, new.args)
else:
self._delta_tool_calls[key] = new

def get(self) -> LLMToolCalls:
"""Map tool call deltas to a `LLMToolCalls`."""
calls: list[ToolCall] = []
for c in self._delta_tool_calls.values():
if c.name is not None and c.args is not None:
calls.append(ToolCall.from_json(c.name, c.args))

return LLMToolCalls(calls)

def cost(self) -> result.Cost:
return result.Cost()
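
Putting the pieces of this file together: a FunctionModel can now be driven purely by a sync generator that yields either text chunks or DeltaToolCalls deltas, which are merged per call index via add_optional. A rough usage sketch of the text path, modeled on the tests added later in this commit (the generator name is illustrative):

import asyncio
from collections.abc import Iterable

from pydantic_ai import Agent
from pydantic_ai.messages import Message
from pydantic_ai.models.function import AgentInfo, FunctionModel


def stream_greeting(_messages: list[Message], _info: AgentInfo) -> Iterable[str]:
    # Each yielded string is one text delta of the streamed response.
    yield 'hello '
    yield 'world'


async def main() -> None:
    agent = Agent(FunctionModel(stream_function=stream_greeting), deps=None)
    result = await agent.run_stream('')
    print(await result.get_response())  # hello world


asyncio.run(main())

For structured results, the same stream_function can instead yield DeltaToolCalls dicts (see test_stream_structure below); FunctionStreamToolCallResponse then accumulates the partial names and JSON arguments before building the final LLMToolCalls.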
15 changes: 3 additions & 12 deletions pydantic_ai/models/openai.py
@@ -12,7 +12,7 @@
from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall
from typing_extensions import assert_never

from .. import UnexpectedModelBehaviour, result
from .. import UnexpectedModelBehaviour, _utils, result
from ..messages import (
ArgsJson,
LLMMessage,
@@ -273,8 +273,8 @@ async def __anext__(self) -> None:
if current.function is None:
current.function = new.function
elif new.function is not None:
current.function.name = _add_optional(current.function.name, new.function.name)
current.function.arguments = _add_optional(current.function.arguments, new.function.arguments)
current.function.name = _utils.add_optional(current.function.name, new.function.name)
current.function.arguments = _utils.add_optional(current.function.arguments, new.function.arguments)
else:
self._delta_tool_calls[new.index] = new

@@ -292,15 +292,6 @@ def cost(self) -> Cost:
return self._cost


def _add_optional(a: str | None, b: str | None) -> str | None:
if a is None:
return b
elif b is None:
return a
else:
return a + b


def _guard_tool_id(t: ToolCall | ToolReturn | RetryPrompt) -> str:
"""Type guard that checks a `tool_id` is not None both for static typing and runtime."""
assert t.tool_id is not None, f'OpenAI requires `tool_id` to be set: {t}'
10 changes: 2 additions & 8 deletions pydantic_ai/models/test.py
@@ -184,10 +184,7 @@ def __post_init__(self):
self._iter = iter(words)

async def __anext__(self) -> str:
try:
return next(self._iter)
except StopIteration as e:
raise StopAsyncIteration() from e
return _utils.sync_anext(self._iter)

def cost(self) -> Cost:
return self._cost
@@ -200,10 +197,7 @@ class TestStreamToolCallResponse(StreamToolCallResponse):
_iter: Iterator[None] = field(default_factory=lambda: iter([None]))

async def __anext__(self) -> None:
try:
return next(self._iter)
except StopIteration as e:
raise StopAsyncIteration() from e
return _utils.sync_anext(self._iter)

def get(self) -> LLMToolCalls:
return self._structured_response
49 changes: 48 additions & 1 deletion tests/models/test_model_function.py
@@ -1,5 +1,6 @@
import json
import re
from collections.abc import Iterable
from dataclasses import asdict
from datetime import timezone

@@ -20,10 +21,13 @@
ToolReturn,
UserPrompt,
)
from pydantic_ai.models.function import AgentInfo, FunctionModel
from pydantic_ai.models.function import AgentInfo, DeltaToolCall, DeltaToolCalls, FunctionModel
from pydantic_ai.models.test import TestModel
from pydantic_ai.result import Cost
from tests.conftest import IsNow

pytestmark = pytest.mark.anyio


def return_last(messages: list[Message], _: AgentInfo) -> LLMMessage:
last = messages[-1]
@@ -393,3 +397,46 @@ async def validate_result(r: Foo) -> Foo:

result = agent.run_sync('')
assert result.response == snapshot(Foo(x=2))


def stream_text_function(_messages: list[Message], _: AgentInfo) -> Iterable[str]:
yield 'hello '
yield 'world'


async def test_stream_text():
agent = Agent(FunctionModel(stream_function=stream_text_function), deps=None)
result = await agent.run_stream('')
assert await result.get_response() == snapshot('hello world')
assert result.all_messages() == snapshot([UserPrompt(content='', timestamp=IsNow(tz=timezone.utc))])
assert result.cost() == snapshot(Cost())


class Foo(BaseModel):
x: int


async def test_stream_structure():
def stream_structured_function(_messages: list[Message], agent_info: AgentInfo) -> Iterable[DeltaToolCalls]:
assert agent_info.result_tools is not None
assert len(agent_info.result_tools) == 1
name = agent_info.result_tools[0].name
yield {0: DeltaToolCall(name=name)}
yield {0: DeltaToolCall(args='{"x": ')}
yield {0: DeltaToolCall(args='1}')}

agent = Agent(FunctionModel(stream_function=stream_structured_function), deps=None, result_type=Foo)
result = await agent.run_stream('')
assert await result.get_response() == snapshot(Foo(x=1))
assert result.cost() == snapshot(Cost())


async def test_pass_neither():
with pytest.raises(TypeError, match='Either `function` or `stream_function` must be provided'):
FunctionModel() # pyright: ignore[reportCallIssue]


async def test_return_empty():
agent = Agent(FunctionModel(stream_function=lambda _, __: []), deps=None)
with pytest.raises(ValueError, match='Stream function must return at least one item'):
await agent.run_stream('')
12 changes: 5 additions & 7 deletions tests/models/test_openai.py
@@ -23,7 +23,7 @@
from openai.types.completion_usage import CompletionUsage, PromptTokensDetails
from typing_extensions import TypedDict

from pydantic_ai import Agent, AgentError, ModelRetry
from pydantic_ai import Agent, AgentError, ModelRetry, _utils
from pydantic_ai.messages import (
ArgsJson,
LLMResponse,
@@ -52,10 +52,7 @@ class MockAsyncStream:
_iter: Iterator[chat.ChatCompletionChunk]

async def __anext__(self) -> chat.ChatCompletionChunk:
try:
return next(self._iter)
except StopIteration as e:
raise StopAsyncIteration() from e
return _utils.sync_anext(self._iter)


@dataclass
@@ -87,13 +84,14 @@ async def chat_completions_create(
self, *_args: Any, stream: bool = False, **_kwargs: Any
) -> chat.ChatCompletion | MockAsyncStream:
if stream:
assert self.stream is not None
assert self.stream is not None, 'you can only use `stream=True` if `stream` is provided'
# noinspection PyUnresolvedReferences
if isinstance(self.stream[0], list):
response = MockAsyncStream(iter(self.stream[self.index])) # type: ignore
else:
response = MockAsyncStream(iter(self.stream)) # type: ignore
else:
assert self.completions is not None
assert self.completions is not None, 'you can only use `stream=False` if `completions` are provided'
if isinstance(self.completions, list):
response = self.completions[self.index]
else:
22 changes: 21 additions & 1 deletion tests/test_agent.py
@@ -6,7 +6,7 @@
from inline_snapshot import snapshot
from pydantic import BaseModel

from pydantic_ai import Agent, CallContext, ModelRetry
from pydantic_ai import Agent, AgentError, CallContext, ModelRetry
from pydantic_ai.messages import (
ArgsJson,
ArgsObject,
@@ -437,3 +437,23 @@ async def ret_a(x: str) -> str:
_cost=Cost(),
)
)


def test_empty_tool_calls():
def empty(_: list[Message], _info: AgentInfo) -> LLMMessage:
return LLMToolCalls(calls=[])

agent = Agent(FunctionModel(empty), deps=None)

with pytest.raises(AgentError, match='caused by unexpected model behavior: Received empty tool call message'):
agent.run_sync('Hello')


def test_unknown_retriever():
def empty(_: list[Message], _info: AgentInfo) -> LLMMessage:
return LLMToolCalls(calls=[ToolCall.from_json('foobar', '{}')])

agent = Agent(FunctionModel(empty), deps=None)

with pytest.raises(AgentError, match="caused by unexpected model behavior: Unknown function name: 'foobar'"):
agent.run_sync('Hello')