Commit b6693ed

working whales examples

samuelcolvin committed Nov 7, 2024
1 parent edc9753 commit b6693ed

Showing 6 changed files with 130 additions and 81 deletions.
24 changes: 17 additions & 7 deletions pydantic_ai/_result.py
@@ -178,9 +178,16 @@ def build(cls, response_type: type[ResultData], name: str, description: str | No
             outer_typed_dict_key=outer_typed_dict_key,
         )

-    def validate(self, tool_call: messages.ToolCall, allow_partial: bool = False) -> ResultData:
+    def validate(
+        self, tool_call: messages.ToolCall, allow_partial: bool = False, wrap_validation_errors: bool = True
+    ) -> ResultData:
         """Validate a result message.

+        Args:
+            tool_call: The tool call from the LLM to validate.
+            allow_partial: If true, allow partial validation.
+            wrap_validation_errors: If true, wrap the validation errors in a retry message.
+
         Returns:
             Either the validated result data (left) or a retry message (right).
         """
@@ -194,12 +201,15 @@ def validate(self, tool_call: messages.ToolCall, allow_partial: bool = False) ->
                 tool_call.args.args_object, experimental_allow_partial=allow_partial
             )
         except ValidationError as e:
-            m = messages.RetryPrompt(
-                tool_name=tool_call.tool_name,
-                content=e.errors(include_url=False),
-                tool_id=tool_call.tool_id,
-            )
-            raise ToolRetryError(m) from e
+            if wrap_validation_errors:
+                m = messages.RetryPrompt(
+                    tool_name=tool_call.tool_name,
+                    content=e.errors(include_url=False),
+                    tool_id=tool_call.tool_id,
+                )
+                raise ToolRetryError(m) from e
+            else:
+                raise
         else:
             if k := self.outer_typed_dict_key:
                 result = result[k]
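The new flag gives callers a choice between the agent-loop behaviour (wrap the failure in a retry message for the model) and surfacing the raw pydantic error. A rough sketch of the two call patterns, reusing names from the hunk above (the surrounding setup is assumed):

    try:
        data = result_tool.validate(tool_call)  # default: wrap_validation_errors=True
    except ToolRetryError:
        ...  # carries a RetryPrompt the agent loop can send back to the model

    try:
        data = result_tool.validate(tool_call, wrap_validation_errors=False)
    except ValidationError:
        ...  # raw pydantic error, e.g. so a streaming consumer can tolerate partial output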
4 changes: 2 additions & 2 deletions pydantic_ai/models/openai.py
@@ -230,9 +230,9 @@ async def __anext__(self) -> None:
             if choice.finish_reason is not None:
                 raise StopAsyncIteration()

-        assert choice.delta.tool_calls is not None, f'Expected delta with tool calls, invalid chunk: {chunk!r}'
+        assert choice.delta.content is None, f'Expected tool calls, got content instead, invalid chunk: {chunk!r}'

-        for new in choice.delta.tool_calls:
+        for new in choice.delta.tool_calls or []:
             if current := self._delta_tool_calls.get(new.index):
                 if current.function is None:
                     current.function = new.function
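For context: a streamed chunk's delta can carry neither content nor tool calls (e.g. a trailing chunk with an empty delta), so iterating `choice.delta.tool_calls` directly would raise. The `or []` turns such chunks into a no-op; a minimal illustration with a hypothetical delta value:

    tool_calls = None  # e.g. an empty trailing delta
    for new in tool_calls or []:  # safely skips instead of raising TypeError
        ...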
108 changes: 38 additions & 70 deletions pydantic_ai/result.py
@@ -3,7 +3,7 @@
 from abc import ABC, abstractmethod
 from collections.abc import AsyncIterator
 from dataclasses import dataclass
-from typing import Generic, TypeVar
+from typing import Generic, TypeVar, cast

 from . import _result, _utils, exceptions, messages, models
 from .call_typing import AgentDeps
@@ -88,64 +88,6 @@ def cost(self) -> Cost:
         return self._cost


-# @dataclass
-# class StreamedTextRunResult(_BaseRunResult[str], Generic[AgentDeps]):
-#     """Text result of a streamed run."""
-#
-#     cost_so_far: Cost
-#     """Cost up until the last request."""
-#     _stream_response: models.StreamTextResponse
-#     _deps: AgentDeps
-#     _result_validators: list[_result.ResultValidator[AgentDeps, str]]
-#
-#     async def stream(self, text_delta: bool = False, debounce_by: float | None = 0.1) -> AsyncIterator[str]:
-#         """Stream the response text as an async iterable.
-#
-#         Result validators are called on each iteration, if `text_delta=False`.
-#
-#         !!!
-#         Note this means that the result validators will NOT be called on the final result if `text_delta=True`.
-#
-#         Args:
-#             text_delta: if `True`, yield each chunk of text as it is received, if `False` (default), yield the full text
-#                 up to the current point.
-#             debounce_by: by how much (if at all) to debounce/group the response chunks by. if `AUTO` (default),
-#                 the response stream is debounced by 0.1 seconds unless `text_delta` is `True`, in which case it
-#                 doesn't make sense to debounce. `None` means no debouncing. Debouncing is important particularly
-#                 for long structured responses to reduce the overhead of performing validation as each token is received.
-#
-#         Returns: An async iterable of the response data.
-#         """
-#         if text_delta:
-#             async for chunks in _utils.group_by_temporal(self._stream_response, debounce_by):
-#                 yield ''.join(chunks)
-#         else:
-#             # a quick benchmark shows it's faster build up a string with concat when we're yielding at each step
-#             combined = ''
-#             async for chunks in _utils.group_by_temporal(self._stream_response, debounce_by):
-#                 combined += ''.join(chunks)
-#                 combined = await self._validate_result(combined)
-#                 yield combined
-#
-#     async def get_response(self) -> str:
-#         """Stream the whole response, validate and return it."""
-#         text = ''.join([chunk async for chunk in self._stream_response])
-#         return await self._validate_result(text)
-#
-#     def cost(self) -> Cost:
-#         """Return the cost of the whole run.
-#
-#         NOTE: this won't return the full cost until the stream is finished.
-#         """
-#         return self.cost_so_far + self._stream_response.cost()
-#
-#     async def _validate_result(self, text: str) -> str:
-#         for validator in self._result_validators:
-#             text = await validator.validate(text, self._deps, 0, None)
-#         return text
-#
-
-
 @dataclass
 class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultData]):
     """Result of a streamed run that returns structured data via a tool call."""
@@ -157,7 +99,12 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultDat
     _deps: AgentDeps
     _result_validators: list[_result.ResultValidator[AgentDeps, ResultData]]

-    async def stream(self, *, text_delta: bool = False, debounce_by: float | None = 0.1) -> AsyncIterator[ResultData]:
+    async def stream(
+        self,
+        *,
+        text_delta: bool = False,
+        debounce_by: float | None = 0.1,
+    ) -> AsyncIterator[ResultData]:
         """Stream the response as an async iterable.

         Result validators are called on each iteration, if `text_delta=False` (the default) or for structured
@@ -183,32 +130,51 @@ async def stream(self, *, text_delta: bool = False, debounce_by: float | None =
             async for chunks in _utils.group_by_temporal(self._stream_response, debounce_by):
                 yield ''.join(chunks)  # pyright: ignore[reportReturnType]
         else:
-            # a quick benchmark shows it's faster build up a string with concat when we're yielding at each step
+            # a quick benchmark shows it's faster to build up a string with concat when we're yielding at each step
             combined = ''
             async for chunks in _utils.group_by_temporal(self._stream_response, debounce_by):
                 combined += ''.join(chunks)
                 combined = await self._validate_text_result(combined)
-                yield combined  # pyright: ignore[reportReturnType]
+                yield cast(ResultData, combined)
         else:
             assert not text_delta, 'Cannot use `text_delta=True` for structured responses'
             async for _ in _utils.group_by_temporal(self._stream_response, debounce_by):
                 tool_message = self._stream_response.get()
-                yield await self._validate_tool_result(tool_message, True)
+                yield await self.validate_structured_result(tool_message, allow_partial=True)
+
+    async def stream_structured(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[messages.LLMToolCalls]:
+        """Stream the response as an async iterable of Structured LLM Messages.
+
+        !!! note
+            This method is only available for structured responses, e.g. if `is_structured()` returns `True`.
+
+        Args:
+            debounce_by: by how much (if at all) to debounce/group the response chunks by. if `AUTO` (default),
+                the response stream is debounced by 0.2 seconds unless `text_delta` is `True`, in which case it
+                doesn't make sense to debounce. `None` means no debouncing. Debouncing is important particularly
+                for long structured responses to reduce the overhead of performing validation as each token is received.
+        """
+        if isinstance(self._stream_response, models.StreamTextResponse):
+            raise exceptions.UserError('stream_messages() can only be used with structured responses')
+        else:
+            async for _ in _utils.group_by_temporal(self._stream_response, debounce_by):
+                yield self._stream_response.get()

     async def get_response(self) -> ResultData:
         """Stream the whole response, validate and return it."""
         if isinstance(self._stream_response, models.StreamTextResponse):
             text = ''.join([chunk async for chunk in self._stream_response])
-            return await self._validate_text_result(text)  # pyright: ignore[reportReturnType]
+            text = await self._validate_text_result(text)
+            return cast(ResultData, text)
         else:
             async for _ in self._stream_response:
                 pass
             tool_message = self._stream_response.get()
-            return await self._validate_tool_result(tool_message, False)
+            return await self.validate_structured_result(tool_message)

-    def is_text(self) -> bool:
-        """Return whether the stream response is text."""
-        return isinstance(self._stream_response, models.StreamTextResponse)
+    def is_structured(self) -> bool:
+        """Return whether the stream response contains structured data (as opposed to text)."""
+        return isinstance(self._stream_response, models.StreamToolCallResponse)

     def cost(self) -> Cost:
         """Return the cost of the whole run.
@@ -227,7 +193,9 @@ async def _validate_text_result(self, text: str) -> str:
             )
         return text

-    async def _validate_tool_result(self, message: messages.LLMToolCalls, allow_partial: bool) -> ResultData:
+    async def validate_structured_result(
+        self, message: messages.LLMToolCalls, *, allow_partial: bool = False
+    ) -> ResultData:
         assert self._result_schema is not None, 'Expected _result_schema to not be None'
         match = self._result_schema.find_tool(message)
         if match is None:
@@ -236,7 +204,7 @@ async def _validate_tool_result(self, message: messages.LLMToolCalls, allow_part
             )

         call, result_tool = match
-        result_data = result_tool.validate(call, allow_partial=allow_partial)
+        result_data = result_tool.validate(call, allow_partial=allow_partial, wrap_validation_errors=False)

         for validator in self._result_validators:
             result_data = await validator.validate(result_data, self._deps, 0, call)
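Together, `stream_structured()` and the now-public `validate_structured_result()` let a consumer watch the raw tool-call messages and decide per chunk whether a partial validation failure is tolerable. A hedged sketch of the consuming pattern (essentially what the whales example below does; `render` is a hypothetical callback):

    result = await agent.run_stream('...')
    async for message in result.stream_structured(debounce_by=0.01):
        try:
            data = await result.validate_structured_result(message, allow_partial=True)
        except ValidationError:
            continue  # not enough of the response has arrived to validate yet
        render(data)  # hypothetical: redraw the UI with the partial data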
2 changes: 1 addition & 1 deletion pydantic_ai_examples/sql_gen.py
@@ -61,7 +61,7 @@ class Success(BaseModel):
     """Response when SQL could be successfully generated."""

     sql_query: Annotated[str, MinLen(1)]
-    explanation: str = Field(None, description='Explanation of the SQL query, as markdown')
+    explanation: str = Field('', description='Explanation of the SQL query, as markdown')


 class InvalidRequest(BaseModel):
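This one-character fix matters because pydantic v2 does not validate defaults unless `validate_default=True`: `Field(None, ...)` quietly lets a `str`-annotated field hold `None`. A minimal reproduction of the old behaviour, assuming stock pydantic v2 settings:

    from pydantic import BaseModel, Field

    class Success(BaseModel):
        explanation: str = Field(None, description='...')

    print(Success().explanation)  # None, despite the str annotation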
71 changes: 71 additions & 0 deletions pydantic_ai_examples/whales.py
@@ -0,0 +1,71 @@
+from typing import Annotated, NotRequired, TypedDict
+
+import devtools
+from pydantic import Field, ValidationError
+from rich.console import Console
+from rich.live import Live
+from rich.table import Table
+
+from pydantic_ai import Agent
+
+
+class Whale(TypedDict):
+    name: str
+    length: Annotated[float, Field(description='Average length of an adult whale in meters.')]
+    ocean: NotRequired[str]
+    description: NotRequired[Annotated[str, Field(description='Short Description')]]
+
+
+agent = Agent('openai:gpt-4', result_type=list[Whale], deps=None)
+
+
+def check_validation_error(e: ValidationError) -> bool:
+    devtools.debug(e.errors())
+    return False
+
+
+async def main():
+    console = Console()
+    with Live('\n' * 36, console=console) as live:
+        console.print('Requesting data...', style='cyan')
+        result = await agent.run_stream(
+            'Generate me details of 30 species of Whale.',
+        )
+
+        console.print('Response:', style='green')
+
+        async for message in result.stream_structured(debounce_by=0.01):
+            try:
+                whales = await result.validate_structured_result(message, allow_partial=True)
+            except ValidationError as exc:
+                if all(e['type'] == 'missing' and e['loc'] == ('response',) for e in exc.errors()):
+                    continue
+                else:
+                    raise
+
+            table = Table(
+                title='Species of Whale',
+                caption='Streaming Structured responses from GPT-4',
+                width=120,
+            )
+            table.add_column('ID', justify='right')
+            table.add_column('Name')
+            table.add_column('Avg. Length (m)', justify='right')
+            table.add_column('Ocean')
+            table.add_column('Description', justify='right')
+
+            for wid, whale in enumerate(whales, start=1):
+                table.add_row(
+                    str(wid),
+                    whale['name'],
+                    f'{whale["length"]:0.0f}',
+                    whale.get('ocean') or '…',
+                    whale.get('description') or '…',
+                )
+            live.update(table)
+
+
+if __name__ == '__main__':
+    import asyncio
+
+    asyncio.run(main())
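The `allow_partial=True` path relies on pydantic's partial-validation mode (visible as `experimental_allow_partial` in the first hunk of this commit): given truncated JSON mid-stream, the validator keeps the complete leading items and drops the unfinished tail. A standalone sketch of the same idea, assuming a pydantic version that exposes the flag on `TypeAdapter.validate_json`:

    from pydantic import TypeAdapter

    ta = TypeAdapter(list[dict[str, str]])
    # truncated JSON, as it might arrive mid-stream
    partial = ta.validate_json('[{"name": "Blue Whale"}, {"na', experimental_allow_partial=True)
    print(partial)  # expected: [{'name': 'Blue Whale'}] — the incomplete trailing item is dropped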
2 changes: 1 addition & 1 deletion tests/models/test_model_test.py
@@ -132,7 +132,7 @@ class TestModel(BaseModel):
         nested: NestedModel
         union: int | list[int]
         optional: str | None
-        with_example: int = Field(..., json_schema_extra={'examples': [1234]})
+        with_example: int = Field(json_schema_extra={'examples': [1234]})
         max_len_zero: Annotated[str, MaxLen(0)]
         is_null: None
         not_required: str = 'default'
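Both spellings leave the field required in pydantic v2 (a field with no default is required whether or not the explicit `...` is passed), so this change is purely cosmetic. A quick check:

    from pydantic import BaseModel, Field

    class M(BaseModel):
        a: int = Field(..., json_schema_extra={'examples': [1234]})
        b: int = Field(json_schema_extra={'examples': [1234]})

    assert M.model_fields['a'].is_required() and M.model_fields['b'].is_required()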
