Commit

make streams a context manager

samuelcolvin committed Nov 9, 2024
1 parent ffefa32 commit b12cf2f
Showing 11 changed files with 212 additions and 159 deletions.
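The user-visible change in this commit: Agent.run_stream no longer returns a StreamedRunResult directly but is now an async context manager, so the model's response stream is reliably closed when the caller is done. A minimal usage sketch under the new API (the model name and agent construction are illustrative; get_response() is taken from the tests below):

import asyncio

from pydantic_ai import Agent

agent = Agent('openai:gpt-4o', deps=None)


async def main() -> None:
    # before this commit: result = await agent.run_stream('Tell me a joke.')
    # after: the stream is closed automatically when the block exits
    async with agent.run_stream('Tell me a joke.') as result:
        print(await result.get_response())


asyncio.run(main())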
31 changes: 19 additions & 12 deletions pydantic_ai/agent.py
@@ -1,7 +1,8 @@
from __future__ import annotations as _annotations

import asyncio
from collections.abc import Awaitable, Sequence
from collections.abc import AsyncIterator, Awaitable, Sequence
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import Any, Callable, Generic, Literal, cast, final, overload

@@ -168,14 +169,15 @@ def run_sync(
"""
return asyncio.run(self.run(user_prompt, message_history=message_history, model=model, deps=deps))

@asynccontextmanager
async def run_stream(
self,
user_prompt: str,
*,
message_history: list[_messages.Message] | None = None,
model: models.Model | KnownModelName | None = None,
deps: AgentDeps | None = None,
) -> result.StreamedRunResult[AgentDeps, ResultData]:
) -> AsyncIterator[result.StreamedRunResult[AgentDeps, ResultData]]:
"""Run the agent with a user prompt in async mode, returning a streamed response.
Args:
@@ -216,18 +218,23 @@ async def run_stream(

if left := either.left:
# left means return a streamed result
result_stream = left.value
run_span.set_attribute('all_messages', messages)
handle_span.set_attribute('result_type', left.value)
handle_span.set_attribute('result_type', result_stream.__class__.__name__)
handle_span.message = 'handle model response -> final result'
return result.StreamedRunResult(
messages,
new_message_index,
cost,
left.value,
self._result_schema,
deps,
self._result_validators,
)
try:
yield result.StreamedRunResult(
messages,
new_message_index,
cost,
result_stream,
self._result_schema,
deps,
self._result_validators,
)
finally:
await result_stream.close()
return
else:
# right means continue the conversation
tool_responses = either.right
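The try/finally around the yield is the heart of this change: with @asynccontextmanager, everything after the yield runs when the caller's async with block exits, so result_stream.close() is awaited even if the consumer raises mid-stream. Decorating the method means the library, not each caller, owns stream cleanup. A self-contained sketch of the same pattern (the names FakeStream and open_stream are illustrative, not pydantic_ai APIs):

import asyncio
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager


class FakeStream:
    async def close(self) -> None:
        print('stream closed')


@asynccontextmanager
async def open_stream() -> AsyncIterator[FakeStream]:
    stream = FakeStream()
    try:
        yield stream  # the caller's `async with` body runs here
    finally:
        # mirrors `await result_stream.close()` in run_stream above
        await stream.close()


async def main() -> None:
    async with open_stream() as stream:
        ...  # consume the stream; close() still runs if this raises


asyncio.run(main())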
6 changes: 6 additions & 0 deletions pydantic_ai/messages.py
@@ -102,6 +102,12 @@ def from_json(cls, tool_name: str, args_json: str, tool_id: str | None = None) -
def from_object(cls, tool_name: str, args_object: dict[str, Any]) -> ToolCall:
return cls(tool_name, ArgsObject(args_object))

def has_content(self) -> bool:
if isinstance(self.args, ArgsObject):
return any(self.args.args_object.values())
else:
return bool(self.args.args_json)


@dataclass
class LLMToolCalls:
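has_content() lets streaming code distinguish a tool call whose arguments have not arrived yet from one that already carries data; result.py below uses it to avoid yielding an empty first message. A standalone sketch of the semantics, assuming the JSON-args variant is named ArgsJson (only its args_json attribute is visible in this diff):

from dataclasses import dataclass
from typing import Any


@dataclass
class ArgsObject:
    args_object: dict[str, Any]


@dataclass
class ArgsJson:
    args_json: str


def has_content(args: ArgsObject | ArgsJson) -> bool:
    # mirror of ToolCall.has_content() in the diff above
    if isinstance(args, ArgsObject):
        return any(args.args_object.values())
    return bool(args.args_json)


assert not has_content(ArgsObject({}))
assert not has_content(ArgsObject({'x': None}))  # all-falsy values count as empty
assert has_content(ArgsJson('{"x": 1'))          # even partial JSON is content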
8 changes: 8 additions & 0 deletions pydantic_ai/models/__init__.py
@@ -81,6 +81,10 @@ def cost(self) -> Cost:
"""
raise NotImplementedError()

async def close(self) -> None:
"""Close the response stream."""
pass


class StreamToolCallResponse(ABC):
"""Streamed response from an LLM when calling a tool."""
@@ -114,6 +118,10 @@ def cost(self) -> Cost:
"""
raise NotImplementedError()

async def close(self) -> None:
"""Close the response stream."""
pass


EitherStreamedResponse = Union[StreamTextResponse, StreamToolCallResponse]

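Both response ABCs gain a close() coroutine with a no-op default: models whose streams hold no external resources (like the test FunctionModel) inherit it unchanged, while HTTP-backed models override it to release the connection. A sketch of the hook with an illustrative subclass (the aclose() call assumes an httpx-style streamed response):

class StreamResponseBase:
    async def close(self) -> None:
        """Close the response stream (no-op by default)."""


class HTTPBackedStream(StreamResponseBase):
    """Illustrative subclass that owns a live HTTP response."""

    def __init__(self, response):  # e.g. an httpx.Response opened in stream mode
        self._response = response

    async def close(self) -> None:
        await self._response.aclose()  # release the pooled connection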
38 changes: 24 additions & 14 deletions pydantic_ai/models/gemini.py
@@ -18,12 +18,13 @@

import os
import re
from collections.abc import Mapping, Sequence
from collections.abc import AsyncIterator, Mapping, Sequence
from contextlib import asynccontextmanager
from copy import deepcopy
from dataclasses import dataclass
from typing import Annotated, Any, Literal, Union, cast

from httpx import AsyncClient as AsyncHTTPClient
from httpx import AsyncClient as AsyncHTTPClient, Response as HTTPResponse
from pydantic import Field
from typing_extensions import assert_never

@@ -60,7 +61,7 @@ def __init__(
api_key: str | None = None,
http_client: AsyncHTTPClient | None = None,
# https://ai.google.dev/gemini-api/docs/quickstart?lang=rest#make-first-request
url_template: str = 'https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent',
url_template: str = 'https://generativelanguage.googleapis.com/v1beta/models/{model}:{function}',
):
self.model_name = model_name
if api_key is None:
@@ -110,14 +111,20 @@ class GeminiAgentModel(AgentModel):
url_template: str

async def request(self, messages: list[Message]) -> tuple[LLMMessage, result.Cost]:
response = await self.make_request(messages)
return self.process_response(response), response.usage_metadata.as_cost()
async with self._make_request(messages, False) as http_response:
response = _gemini_response_ta.validate_json(await http_response.aread())
return self._process_response(response), response.usage_metadata.as_cost()

async def make_request(self, messages: list[Message]) -> _GeminiResponse:
# async def request_stream(self, messages: list[Message]) -> EitherStreamedResponse:
# """Make a request to the model and return a streaming response."""
# response = await self._make_request(messages, False)

@asynccontextmanager
async def _make_request(self, messages: list[Message], streamed: bool) -> AsyncIterator[HTTPResponse]:
contents: list[_GeminiContent] = []
sys_prompt_parts: list[_GeminiTextPart] = []
for m in messages:
either_content = self.message_to_gemini(m)
either_content = self._message_to_gemini(m)
if left := either_content.left:
sys_prompt_parts.append(left.value)
else:
@@ -135,14 +142,17 @@ async def make_request(self, messages: list[Message]) -> _GeminiResponse:
'X-Goog-Api-Key': self.api_key,
'Content-Type': 'application/json',
}
url = self.url_template.format(model=self.model_name)
r = await self.http_client.post(url, content=request_json, headers=headers)
if r.status_code != 200:
raise exceptions.UnexpectedModelBehaviour(f'Unexpected response from gemini {r.status_code}', r.text)
return _gemini_response_ta.validate_json(r.content)
url = self.url_template.format(
model=self.model_name, function='streamGenerateContent' if streamed else 'generateContent'
)

async with self.http_client.stream('POST', url, content=request_json, headers=headers) as r:
if r.status_code != 200:
raise exceptions.UnexpectedModelBehaviour(f'Unexpected response from gemini {r.status_code}', r.text)
yield r

@staticmethod
def process_response(response: _GeminiResponse) -> LLMMessage:
def _process_response(response: _GeminiResponse) -> LLMMessage:
assert len(response.candidates) == 1, 'Expected exactly one candidate'
parts = response.candidates[0].content.parts
if all(isinstance(part, _GeminiFunctionCallPart) for part in parts):
@@ -158,7 +168,7 @@ def process_response(response: _GeminiResponse) -> LLMMessage:
)

@staticmethod
def message_to_gemini(m: Message) -> _utils.Either[_GeminiTextPart, _GeminiContent]:
def _message_to_gemini(m: Message) -> _utils.Either[_GeminiTextPart, _GeminiContent]:
"""Convert a message to a _GeminiTextPart for "system_instructions" or _GeminiContent for "contents"."""
if m.role == 'system':
# SystemPrompt ->
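Two coupled changes in gemini.py: the URL template gains a {function} slot so one template serves both generateContent and streamGenerateContent, and the POST moves to httpx's streaming API. Since httpx.AsyncClient.stream() is itself an async context manager that closes the response on exit, _make_request must become one too so the body stays readable for its caller. An isolated sketch of that httpx pattern (no API key is sent, so a real call would return an error status; the shape of the call is the point):

import asyncio

import httpx


async def main() -> None:
    url_template = 'https://generativelanguage.googleapis.com/v1beta/models/{model}:{function}'
    url = url_template.format(model='gemini-1.5-flash', function='generateContent')
    async with httpx.AsyncClient() as client:
        # the connection stays open only inside this block; read the body
        # here with aread() (buffered) or aiter_bytes() (incremental)
        async with client.stream('POST', url, json={'contents': []}) as r:
            body = await r.aread()
            print(r.status_code, len(body))


asyncio.run(main())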
6 changes: 6 additions & 0 deletions pydantic_ai/models/openai.py
@@ -247,6 +247,9 @@ async def __anext__(self) -> str:
def cost(self) -> Cost:
return self._cost

async def close(self) -> None:
await self._response.close()


@dataclass
class OpenAIStreamToolCallResponse(StreamToolCallResponse):
@@ -291,6 +294,9 @@ def get(self) -> LLMToolCalls:
def cost(self) -> Cost:
return self._cost

async def close(self) -> None:
await self._response.close()


def _guard_tool_id(t: ToolCall | ToolReturn | RetryPrompt) -> str:
"""Type guard that checks a `tool_id` is not None both for static typing and runtime."""
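Both OpenAI stream wrappers satisfy the new hook by delegating to the SDK stream they hold; the diff shows only that self._response exposes an awaitable close(). A sketch of that delegation shape (the Protocol and class names are illustrative):

from typing import Protocol


class Closeable(Protocol):
    async def close(self) -> None: ...


class StreamWrapper:
    """Wraps a stream object and forwards cleanup, as the OpenAI classes do."""

    def __init__(self, response: Closeable):
        self._response = response

    async def close(self) -> None:
        await self._response.close()  # release the underlying HTTP connection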
27 changes: 15 additions & 12 deletions pydantic_ai/result.py
@@ -15,6 +15,7 @@
'StreamedRunResult',
)


ResultData = TypeVar('ResultData')


@@ -165,8 +166,10 @@ async def stream_structured(self, *, debounce_by: float | None = 0.1) -> AsyncIt
if isinstance(self._stream_response, models.StreamTextResponse):
raise exceptions.UserError('stream_messages() can only be used with structured responses')
else:
# we should already have a message at this point, yield that first
yield self._stream_response.get()
# we should already have a message at this point, yield that first if it has any content
initial_msg = self._stream_response.get()
if any(call.has_content() for call in initial_msg.calls):
yield initial_msg

async for _ in _utils.group_by_temporal(self._stream_response, debounce_by):
yield self._stream_response.get()
@@ -197,16 +200,6 @@ def cost(self) -> Cost:
"""
return self.cost_so_far + self._stream_response.cost()

async def _validate_text_result(self, text: str) -> str:
for validator in self._result_validators:
text = await validator.validate( # pyright: ignore[reportAssignmentType]
text, # pyright: ignore[reportArgumentType]
self._deps,
0,
None,
)
return text

async def validate_structured_result(
self, message: messages.LLMToolCalls, *, allow_partial: bool = False
) -> ResultData:
@@ -223,3 +216,13 @@ async def validate_structured_result(
for validator in self._result_validators:
result_data = await validator.validate(result_data, self._deps, 0, call)
return result_data

async def _validate_text_result(self, text: str) -> str:
for validator in self._result_validators:
text = await validator.validate( # pyright: ignore[reportAssignmentType]
text, # pyright: ignore[reportArgumentType]
self._deps,
0,
None,
)
return text
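stream_structured previously always yielded the message received before streaming began, even when it was an argument-less tool call that validates to nothing useful; the new has_content() guard skips that empty first snapshot. A toy sketch of the resulting shape, yielding growing snapshots with an empty initial state skipped (the group_by_temporal debouncing is elided):

import asyncio
from collections.abc import AsyncIterator


async def snapshots(initial: str, chunks: list[str]) -> AsyncIterator[str]:
    buffer = initial
    if buffer:  # mirrors the `has_content()` guard on the first message
        yield buffer
    for chunk in chunks:
        await asyncio.sleep(0)  # stand-in for awaiting the model stream
        buffer += chunk
        yield buffer  # like self._stream_response.get(): a fresh full snapshot


async def main() -> None:
    async for snap in snapshots('', ['{"x"', ': 1', '}']):
        print(snap)  # {"x"  ->  {"x": 1  ->  {"x": 1}


asyncio.run(main())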
67 changes: 33 additions & 34 deletions pydantic_ai_examples/whales.py
@@ -39,41 +39,40 @@ async def main():
console = Console()
with Live('\n' * 36, console=console) as live:
console.print('Requesting data...', style='cyan')
result = await agent.run_stream('Generate me details of 20 species of Whale.')

console.print('Response:', style='green')

async for message in result.stream_structured(debounce_by=0.01):
try:
whales = await result.validate_structured_result(message, allow_partial=True)
except ValidationError as exc:
if all(e['type'] == 'missing' and e['loc'] == ('response',) for e in exc.errors()):
continue
else:
raise

table = Table(
title='Species of Whale',
caption='Streaming Structured responses from GPT-4',
width=120,
)
table.add_column('ID', justify='right')
table.add_column('Name')
table.add_column('Avg. Length (m)', justify='right')
table.add_column('Avg. Weight (kg)', justify='right')
table.add_column('Ocean')
table.add_column('Description', justify='right')

for wid, whale in enumerate(whales, start=1):
table.add_row(
str(wid),
whale['name'],
f'{whale['length']:0.0f}',
f'{w:0.0f}' if (w := whale.get('weight')) else '…',
whale.get('ocean') or '…',
whale.get('description') or '…',
async with agent.run_stream('Generate me details of 20 species of Whale.') as result:
console.print('Response:', style='green')

async for message in result.stream_structured(debounce_by=0.01):
try:
whales = await result.validate_structured_result(message, allow_partial=True)
except ValidationError as exc:
if all(e['type'] == 'missing' and e['loc'] == ('response',) for e in exc.errors()):
continue
else:
raise

table = Table(
title='Species of Whale',
caption='Streaming Structured responses from GPT-4',
width=120,
)
live.update(table)
table.add_column('ID', justify='right')
table.add_column('Name')
table.add_column('Avg. Length (m)', justify='right')
table.add_column('Avg. Weight (kg)', justify='right')
table.add_column('Ocean')
table.add_column('Description', justify='right')

for wid, whale in enumerate(whales, start=1):
table.add_row(
str(wid),
whale['name'],
f'{whale['length']:0.0f}',
f'{w:0.0f}' if (w := whale.get('weight')) else '…',
whale.get('ocean') or '…',
whale.get('description') or '…',
)
live.update(table)


if __name__ == '__main__':
17 changes: 9 additions & 8 deletions tests/models/test_model_function.py
@@ -406,10 +406,10 @@ def stream_text_function(_messages: list[Message], _: AgentInfo) -> Iterable[str

async def test_stream_text():
agent = Agent(FunctionModel(stream_function=stream_text_function), deps=None)
result = await agent.run_stream('')
assert await result.get_response() == snapshot('hello world')
assert result.all_messages() == snapshot([UserPrompt(content='', timestamp=IsNow(tz=timezone.utc))])
assert result.cost() == snapshot(Cost())
async with agent.run_stream('') as result:
assert await result.get_response() == snapshot('hello world')
assert result.all_messages() == snapshot([UserPrompt(content='', timestamp=IsNow(tz=timezone.utc))])
assert result.cost() == snapshot(Cost())


class Foo(BaseModel):
@@ -426,9 +426,9 @@ def stream_structured_function(_messages: list[Message], agent_info: AgentInfo)
yield {0: DeltaToolCall(args='1}')}

agent = Agent(FunctionModel(stream_function=stream_structured_function), deps=None, result_type=Foo)
result = await agent.run_stream('')
assert await result.get_response() == snapshot(Foo(x=1))
assert result.cost() == snapshot(Cost())
async with agent.run_stream('') as result:
assert await result.get_response() == snapshot(Foo(x=1))
assert result.cost() == snapshot(Cost())


async def test_pass_neither():
@@ -439,4 +439,5 @@ async def test_pass_neither():
async def test_return_empty():
agent = Agent(FunctionModel(stream_function=lambda _, __: []), deps=None)
with pytest.raises(ValueError, match='Stream function must return at least one item'):
await agent.run_stream('')
async with agent.run_stream(''):
pass