From b12cf2fb295796bae79caeea0282892219b39d19 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Sat, 9 Nov 2024 11:29:16 +0000 Subject: [PATCH] make streams a context manager --- pydantic_ai/agent.py | 31 +++++---- pydantic_ai/messages.py | 6 ++ pydantic_ai/models/__init__.py | 8 +++ pydantic_ai/models/gemini.py | 38 +++++++---- pydantic_ai/models/openai.py | 6 ++ pydantic_ai/result.py | 27 ++++---- pydantic_ai_examples/whales.py | 67 +++++++++--------- tests/models/test_model_function.py | 17 ++--- tests/models/test_openai.py | 66 +++++++++--------- tests/test_streaming.py | 101 ++++++++++++++++------------ tests/typed_agent.py | 4 +- 11 files changed, 212 insertions(+), 159 deletions(-) diff --git a/pydantic_ai/agent.py b/pydantic_ai/agent.py index 3a4af966c..3f1e25a63 100644 --- a/pydantic_ai/agent.py +++ b/pydantic_ai/agent.py @@ -1,7 +1,8 @@ from __future__ import annotations as _annotations import asyncio -from collections.abc import Awaitable, Sequence +from collections.abc import AsyncIterator, Awaitable, Sequence +from contextlib import asynccontextmanager from dataclasses import dataclass from typing import Any, Callable, Generic, Literal, cast, final, overload @@ -168,6 +169,7 @@ def run_sync( """ return asyncio.run(self.run(user_prompt, message_history=message_history, model=model, deps=deps)) + @asynccontextmanager async def run_stream( self, user_prompt: str, @@ -175,7 +177,7 @@ async def run_stream( message_history: list[_messages.Message] | None = None, model: models.Model | KnownModelName | None = None, deps: AgentDeps | None = None, - ) -> result.StreamedRunResult[AgentDeps, ResultData]: + ) -> AsyncIterator[result.StreamedRunResult[AgentDeps, ResultData]]: """Run the agent with a user prompt in async mode, returning a streamed response. 
Args: @@ -216,18 +218,23 @@ async def run_stream( if left := either.left: # left means return a streamed result + result_stream = left.value run_span.set_attribute('all_messages', messages) - handle_span.set_attribute('result_type', left.value) + handle_span.set_attribute('result_type', result_stream.__class__.__name__) handle_span.message = 'handle model response -> final result' - return result.StreamedRunResult( - messages, - new_message_index, - cost, - left.value, - self._result_schema, - deps, - self._result_validators, - ) + try: + yield result.StreamedRunResult( + messages, + new_message_index, + cost, + result_stream, + self._result_schema, + deps, + self._result_validators, + ) + finally: + await result_stream.close() + return else: # right means continue the conversation tool_responses = either.right diff --git a/pydantic_ai/messages.py b/pydantic_ai/messages.py index 865533138..d3895d59c 100644 --- a/pydantic_ai/messages.py +++ b/pydantic_ai/messages.py @@ -102,6 +102,12 @@ def from_json(cls, tool_name: str, args_json: str, tool_id: str | None = None) - def from_object(cls, tool_name: str, args_object: dict[str, Any]) -> ToolCall: return cls(tool_name, ArgsObject(args_object)) + def has_content(self) -> bool: + if isinstance(self.args, ArgsObject): + return any(self.args.args_object.values()) + else: + return bool(self.args.args_json) + @dataclass class LLMToolCalls: diff --git a/pydantic_ai/models/__init__.py b/pydantic_ai/models/__init__.py index ffcc25234..9679ee54a 100644 --- a/pydantic_ai/models/__init__.py +++ b/pydantic_ai/models/__init__.py @@ -81,6 +81,10 @@ def cost(self) -> Cost: """ raise NotImplementedError() + async def close(self) -> None: + """Close the response stream.""" + pass + class StreamToolCallResponse(ABC): """Streamed response from an LLM when calling a tool.""" @@ -114,6 +118,10 @@ def cost(self) -> Cost: """ raise NotImplementedError() + async def close(self) -> None: + """Close the response stream.""" + pass + EitherStreamedResponse = Union[StreamTextResponse, StreamToolCallResponse] diff --git a/pydantic_ai/models/gemini.py b/pydantic_ai/models/gemini.py index 8db2568e0..8a961fbe1 100644 --- a/pydantic_ai/models/gemini.py +++ b/pydantic_ai/models/gemini.py @@ -18,12 +18,13 @@ import os import re -from collections.abc import Mapping, Sequence +from collections.abc import AsyncIterator, Mapping, Sequence +from contextlib import asynccontextmanager from copy import deepcopy from dataclasses import dataclass from typing import Annotated, Any, Literal, Union, cast -from httpx import AsyncClient as AsyncHTTPClient +from httpx import AsyncClient as AsyncHTTPClient, Response as HTTPResponse from pydantic import Field from typing_extensions import assert_never @@ -60,7 +61,7 @@ def __init__( api_key: str | None = None, http_client: AsyncHTTPClient | None = None, # https://ai.google.dev/gemini-api/docs/quickstart?lang=rest#make-first-request - url_template: str = 'https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent', + url_template: str = 'https://generativelanguage.googleapis.com/v1beta/models/{model}:{function}', ): self.model_name = model_name if api_key is None: @@ -110,14 +111,20 @@ class GeminiAgentModel(AgentModel): url_template: str async def request(self, messages: list[Message]) -> tuple[LLMMessage, result.Cost]: - response = await self.make_request(messages) - return self.process_response(response), response.usage_metadata.as_cost() + async with self._make_request(messages, False) as http_response: + response = 
_gemini_response_ta.validate_json(await http_response.aread()) + return self._process_response(response), response.usage_metadata.as_cost() - async def make_request(self, messages: list[Message]) -> _GeminiResponse: + # async def request_stream(self, messages: list[Message]) -> EitherStreamedResponse: + # """Make a request to the model and return a streaming response.""" + # response = await self._make_request(messages, False) + + @asynccontextmanager + async def _make_request(self, messages: list[Message], streamed: bool) -> AsyncIterator[HTTPResponse]: contents: list[_GeminiContent] = [] sys_prompt_parts: list[_GeminiTextPart] = [] for m in messages: - either_content = self.message_to_gemini(m) + either_content = self._message_to_gemini(m) if left := either_content.left: sys_prompt_parts.append(left.value) else: @@ -135,14 +142,17 @@ async def make_request(self, messages: list[Message]) -> _GeminiResponse: 'X-Goog-Api-Key': self.api_key, 'Content-Type': 'application/json', } - url = self.url_template.format(model=self.model_name) - r = await self.http_client.post(url, content=request_json, headers=headers) - if r.status_code != 200: - raise exceptions.UnexpectedModelBehaviour(f'Unexpected response from gemini {r.status_code}', r.text) - return _gemini_response_ta.validate_json(r.content) + url = self.url_template.format( + model=self.model_name, function='streamGenerateContent' if streamed else 'generateContent' + ) + + async with self.http_client.stream('POST', url, content=request_json, headers=headers) as r: + if r.status_code != 200: + raise exceptions.UnexpectedModelBehaviour(f'Unexpected response from gemini {r.status_code}', r.text) + yield r @staticmethod - def process_response(response: _GeminiResponse) -> LLMMessage: + def _process_response(response: _GeminiResponse) -> LLMMessage: assert len(response.candidates) == 1, 'Expected exactly one candidate' parts = response.candidates[0].content.parts if all(isinstance(part, _GeminiFunctionCallPart) for part in parts): @@ -158,7 +168,7 @@ def process_response(response: _GeminiResponse) -> LLMMessage: ) @staticmethod - def message_to_gemini(m: Message) -> _utils.Either[_GeminiTextPart, _GeminiContent]: + def _message_to_gemini(m: Message) -> _utils.Either[_GeminiTextPart, _GeminiContent]: """Convert a message to a _GeminiTextPart for "system_instructions" or _GeminiContent for "contents".""" if m.role == 'system': # SystemPrompt -> diff --git a/pydantic_ai/models/openai.py b/pydantic_ai/models/openai.py index 3515fcc6d..0216402c0 100644 --- a/pydantic_ai/models/openai.py +++ b/pydantic_ai/models/openai.py @@ -247,6 +247,9 @@ async def __anext__(self) -> str: def cost(self) -> Cost: return self._cost + async def close(self) -> None: + await self._response.close() + @dataclass class OpenAIStreamToolCallResponse(StreamToolCallResponse): @@ -291,6 +294,9 @@ def get(self) -> LLMToolCalls: def cost(self) -> Cost: return self._cost + async def close(self) -> None: + await self._response.close() + def _guard_tool_id(t: ToolCall | ToolReturn | RetryPrompt) -> str: """Type guard that checks a `tool_id` is not None both for static typing and runtime.""" diff --git a/pydantic_ai/result.py b/pydantic_ai/result.py index e0bb1cbe1..b93a8321d 100644 --- a/pydantic_ai/result.py +++ b/pydantic_ai/result.py @@ -15,6 +15,7 @@ 'StreamedRunResult', ) + ResultData = TypeVar('ResultData') @@ -165,8 +166,10 @@ async def stream_structured(self, *, debounce_by: float | None = 0.1) -> AsyncIt if isinstance(self._stream_response, models.StreamTextResponse): raise 
exceptions.UserError('stream_messages() can only be used with structured responses') else: - # we should already have a message at this point, yield that first - yield self._stream_response.get() + # we should already have a message at this point, yield that first if it has any content + initial_msg = self._stream_response.get() + if any(call.has_content() for call in initial_msg.calls): + yield initial_msg async for _ in _utils.group_by_temporal(self._stream_response, debounce_by): yield self._stream_response.get() @@ -197,16 +200,6 @@ def cost(self) -> Cost: """ return self.cost_so_far + self._stream_response.cost() - async def _validate_text_result(self, text: str) -> str: - for validator in self._result_validators: - text = await validator.validate( # pyright: ignore[reportAssignmentType] - text, # pyright: ignore[reportArgumentType] - self._deps, - 0, - None, - ) - return text - async def validate_structured_result( self, message: messages.LLMToolCalls, *, allow_partial: bool = False ) -> ResultData: @@ -223,3 +216,13 @@ async def validate_structured_result( for validator in self._result_validators: result_data = await validator.validate(result_data, self._deps, 0, call) return result_data + + async def _validate_text_result(self, text: str) -> str: + for validator in self._result_validators: + text = await validator.validate( # pyright: ignore[reportAssignmentType] + text, # pyright: ignore[reportArgumentType] + self._deps, + 0, + None, + ) + return text diff --git a/pydantic_ai_examples/whales.py b/pydantic_ai_examples/whales.py index b9aeabeda..66f79294c 100644 --- a/pydantic_ai_examples/whales.py +++ b/pydantic_ai_examples/whales.py @@ -39,41 +39,40 @@ async def main(): console = Console() with Live('\n' * 36, console=console) as live: console.print('Requesting data...', style='cyan') - result = await agent.run_stream('Generate me details of 20 species of Whale.') - - console.print('Response:', style='green') - - async for message in result.stream_structured(debounce_by=0.01): - try: - whales = await result.validate_structured_result(message, allow_partial=True) - except ValidationError as exc: - if all(e['type'] == 'missing' and e['loc'] == ('response',) for e in exc.errors()): - continue - else: - raise - - table = Table( - title='Species of Whale', - caption='Streaming Structured responses from GPT-4', - width=120, - ) - table.add_column('ID', justify='right') - table.add_column('Name') - table.add_column('Avg. Length (m)', justify='right') - table.add_column('Avg. Weight (kg)', justify='right') - table.add_column('Ocean') - table.add_column('Description', justify='right') - - for wid, whale in enumerate(whales, start=1): - table.add_row( - str(wid), - whale['name'], - f'{whale['length']:0.0f}', - f'{w:0.0f}' if (w := whale.get('weight')) else '…', - whale.get('ocean') or '…', - whale.get('description') or '…', + async with agent.run_stream('Generate me details of 20 species of Whale.') as result: + console.print('Response:', style='green') + + async for message in result.stream_structured(debounce_by=0.01): + try: + whales = await result.validate_structured_result(message, allow_partial=True) + except ValidationError as exc: + if all(e['type'] == 'missing' and e['loc'] == ('response',) for e in exc.errors()): + continue + else: + raise + + table = Table( + title='Species of Whale', + caption='Streaming Structured responses from GPT-4', + width=120, ) - live.update(table) + table.add_column('ID', justify='right') + table.add_column('Name') + table.add_column('Avg. 
Length (m)', justify='right') + table.add_column('Avg. Weight (kg)', justify='right') + table.add_column('Ocean') + table.add_column('Description', justify='right') + + for wid, whale in enumerate(whales, start=1): + table.add_row( + str(wid), + whale['name'], + f'{whale['length']:0.0f}', + f'{w:0.0f}' if (w := whale.get('weight')) else '…', + whale.get('ocean') or '…', + whale.get('description') or '…', + ) + live.update(table) if __name__ == '__main__': diff --git a/tests/models/test_model_function.py b/tests/models/test_model_function.py index ec02457ab..be370691a 100644 --- a/tests/models/test_model_function.py +++ b/tests/models/test_model_function.py @@ -406,10 +406,10 @@ def stream_text_function(_messages: list[Message], _: AgentInfo) -> Iterable[str async def test_stream_text(): agent = Agent(FunctionModel(stream_function=stream_text_function), deps=None) - result = await agent.run_stream('') - assert await result.get_response() == snapshot('hello world') - assert result.all_messages() == snapshot([UserPrompt(content='', timestamp=IsNow(tz=timezone.utc))]) - assert result.cost() == snapshot(Cost()) + async with agent.run_stream('') as result: + assert await result.get_response() == snapshot('hello world') + assert result.all_messages() == snapshot([UserPrompt(content='', timestamp=IsNow(tz=timezone.utc))]) + assert result.cost() == snapshot(Cost()) class Foo(BaseModel): @@ -426,9 +426,9 @@ def stream_structured_function(_messages: list[Message], agent_info: AgentInfo) yield {0: DeltaToolCall(args='1}')} agent = Agent(FunctionModel(stream_function=stream_structured_function), deps=None, result_type=Foo) - result = await agent.run_stream('') - assert await result.get_response() == snapshot(Foo(x=1)) - assert result.cost() == snapshot(Cost()) + async with agent.run_stream('') as result: + assert await result.get_response() == snapshot(Foo(x=1)) + assert result.cost() == snapshot(Cost()) async def test_pass_neither(): @@ -439,4 +439,5 @@ async def test_pass_neither(): async def test_return_empty(): agent = Agent(FunctionModel(stream_function=lambda _, __: []), deps=None) with pytest.raises(ValueError, match='Stream function must return at least one item'): - await agent.run_stream('') + async with agent.run_stream(''): + pass diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py index 358714bf0..2239b53e0 100644 --- a/tests/models/test_openai.py +++ b/tests/models/test_openai.py @@ -54,6 +54,9 @@ class MockAsyncStream: async def __anext__(self) -> chat.ChatCompletionChunk: return _utils.sync_anext(self._iter) + async def close(self): + pass + @dataclass class MockOpenAI: @@ -313,13 +316,12 @@ async def test_stream_text(): m = OpenAIModel('gpt-4', openai_client=mock_client) agent = Agent(m, deps=None) - result = await agent.run_stream('') - - assert not result.is_structured() - assert not result.is_complete - assert [c async for c in result.stream(debounce_by=None)] == snapshot(['hello ', 'hello world']) - assert result.is_complete - assert result.cost() == snapshot(Cost(request_tokens=6, response_tokens=3, total_tokens=9)) + async with agent.run_stream('') as result: + assert not result.is_structured() + assert not result.is_complete + assert [c async for c in result.stream(debounce_by=None)] == snapshot(['hello ', 'hello world']) + assert result.is_complete + assert result.cost() == snapshot(Cost(request_tokens=6, response_tokens=3, total_tokens=9)) async def test_stream_text_finish_reason(): @@ -328,12 +330,11 @@ async def test_stream_text_finish_reason(): m = 
OpenAIModel('gpt-4', openai_client=mock_client) agent = Agent(m, deps=None) - result = await agent.run_stream('') - - assert not result.is_structured() - assert not result.is_complete - assert [c async for c in result.stream(debounce_by=None)] == snapshot(['hello ', 'hello world']) - assert result.is_complete + async with agent.run_stream('') as result: + assert not result.is_structured() + assert not result.is_complete + assert [c async for c in result.stream(debounce_by=None)] == snapshot(['hello ', 'hello world']) + assert result.is_complete def struc_chunk( @@ -375,17 +376,16 @@ async def test_stream_structured(): m = OpenAIModel('gpt-4', openai_client=mock_client) agent = Agent(m, deps=None, result_type=MyTypedDict) - result = await agent.run_stream('') - - assert result.is_structured() - assert not result.is_complete - assert [dict(c) async for c in result.stream(debounce_by=None)] == snapshot( - [{'first': 'One'}, {'first': 'One', 'second': 'Two'}, {'first': 'One', 'second': 'Two'}] - ) - assert result.is_complete - assert result.cost() == snapshot(Cost(request_tokens=20, response_tokens=10, total_tokens=30)) - # double check cost matches stream count - assert result.cost().response_tokens == len(stream) + async with agent.run_stream('') as result: + assert result.is_structured() + assert not result.is_complete + assert [dict(c) async for c in result.stream(debounce_by=None)] == snapshot( + [{'first': 'One'}, {'first': 'One', 'second': 'Two'}, {'first': 'One', 'second': 'Two'}] + ) + assert result.is_complete + assert result.cost() == snapshot(Cost(request_tokens=20, response_tokens=10, total_tokens=30)) + # double check cost matches stream count + assert result.cost().response_tokens == len(stream) async def test_stream_structured_finish_reason(): @@ -400,14 +400,13 @@ async def test_stream_structured_finish_reason(): m = OpenAIModel('gpt-4', openai_client=mock_client) agent = Agent(m, deps=None, result_type=MyTypedDict) - result = await agent.run_stream('') - - assert result.is_structured() - assert not result.is_complete - assert [dict(c) async for c in result.stream(debounce_by=None)] == snapshot( - [{'first': 'One'}, {'first': 'One', 'second': 'Two'}, {'first': 'One', 'second': 'Two'}] - ) - assert result.is_complete + async with agent.run_stream('') as result: + assert result.is_structured() + assert not result.is_complete + assert [dict(c) async for c in result.stream(debounce_by=None)] == snapshot( + [{'first': 'One'}, {'first': 'One', 'second': 'Two'}, {'first': 'One', 'second': 'Two'}] + ) + assert result.is_complete async def test_no_content(): @@ -417,4 +416,5 @@ async def test_no_content(): agent = Agent(m, deps=None, result_type=MyTypedDict) with pytest.raises(AgentError, match='caused by unexpected model behavior: Streamed response ended without con'): - await agent.run_stream('') + async with agent.run_stream(''): + pass diff --git a/tests/test_streaming.py b/tests/test_streaming.py index c87ee2062..014fa601d 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -2,14 +2,16 @@ import json from collections.abc import Iterable +from datetime import timezone import pytest from inline_snapshot import snapshot from pydantic_ai import Agent, AgentError -from pydantic_ai.messages import Message, ToolReturn, UserPrompt +from pydantic_ai.messages import ArgsJson, LLMToolCalls, Message, ToolCall, ToolReturn, UserPrompt from pydantic_ai.models.function import AgentInfo, DeltaToolCall, DeltaToolCalls, FunctionModel from pydantic_ai.models.test import TestModel 
+from dirty_equals import IsNow
 
 pytestmark = pytest.mark.anyio
 
@@ -23,13 +25,12 @@ async def test_streamed_text_response():
     async def ret_a(x: str) -> str:
         return f'{x}-apple'
 
-    result = await agent.run_stream('Hello')
-
-    assert not result.is_structured()
-    assert not result.is_complete
-    response = await result.get_response()
-    assert response == snapshot('{"ret_a":"a-apple"}')
-    assert result.is_complete
+    async with agent.run_stream('Hello') as result:
+        assert not result.is_structured()
+        assert not result.is_complete
+        response = await result.get_response()
+        assert response == snapshot('{"ret_a":"a-apple"}')
+        assert result.is_complete
 
 
 async def test_streamed_structured_response():
@@ -37,13 +38,12 @@ async def test_streamed_structured_response():
 
     agent = Agent(m, deps=None, result_type=tuple[str, str])
 
-    result = await agent.run_stream('')
-
-    assert result.is_structured()
-    assert not result.is_complete
-    response = await result.get_response()
-    assert response == snapshot(('a', 'a'))
-    assert result.is_complete
+    async with agent.run_stream('') as result:
+        assert result.is_structured()
+        assert not result.is_complete
+        response = await result.get_response()
+        assert response == snapshot(('a', 'a'))
+        assert result.is_complete
 
 
 async def test_streamed_text_stream():
@@ -51,30 +51,30 @@ async def test_streamed_text_stream():
 
     agent = Agent(m, deps=None)
 
-    result = await agent.run_stream('Hello')
-    assert not result.is_structured()
-    # typehint to test (via static typing) that the stream type is correctly inferred
-    chunks: list[str] = [c async for c in result.stream()]
-    # one chunk due to group_by_temporal
-    assert chunks == snapshot(['The cat sat on the mat.'])
-    assert result.is_complete
-
-    result = await agent.run_stream('Hello')
-    assert [c async for c in result.stream(debounce_by=None)] == snapshot(
-        [
-            'The ',
-            'The cat ',
-            'The cat sat ',
-            'The cat sat on ',
-            'The cat sat on the ',
-            'The cat sat on the mat.',
-        ]
-    )
-
-    result = await agent.run_stream('Hello')
-    assert [c async for c in result.stream(text_delta=True, debounce_by=None)] == snapshot(
-        ['The ', 'cat ', 'sat ', 'on ', 'the ', 'mat.']
-    )
+    async with agent.run_stream('Hello') as result:
+        assert not result.is_structured()
+        # typehint to test (via static typing) that the stream type is correctly inferred
+        chunks: list[str] = [c async for c in result.stream()]
+        # one chunk due to group_by_temporal
+        assert chunks == snapshot(['The cat sat on the mat.'])
+        assert result.is_complete
+
+    async with agent.run_stream('Hello') as result:
+        assert [c async for c in result.stream(debounce_by=None)] == snapshot(
+            [
+                'The ',
+                'The cat ',
+                'The cat sat ',
+                'The cat sat on ',
+                'The cat sat on the ',
+                'The cat sat on the mat.',
+            ]
+        )
+
+    async with agent.run_stream('Hello') as result:
+        assert [c async for c in result.stream(text_delta=True, debounce_by=None)] == snapshot(
+            ['The ', 'cat ', 'sat ', 'on ', 'the ', 'mat.']
+        )
 
 
 async def test_plain_response():
@@ -89,7 +89,8 @@ def text_stream(_messages: list[Message], _: AgentInfo) -> list[str]:
     agent = Agent(FunctionModel(stream_function=text_stream), deps=None, result_type=tuple[str, str])
 
     with pytest.raises(AgentError) as exc_info:
-        await agent.run_stream('')
+        async with agent.run_stream(''):
+            pass
 
     assert str(exc_info.value) == snapshot(
        'Error while running model function:stream-text_stream after 2 messages\n'
@@ -129,8 +130,18 @@ async def ret_a(x: str) -> str:
         assert x == 'hello'
         return f'{x} world'
 
-    result = await 
agent.run_stream('hello') - assert await result.get_response() == snapshot(('hello world', 2)) + async with agent.run_stream('hello') as result: + assert await result.get_response() == snapshot(('hello world', 2)) + assert result.all_messages() == snapshot( + [ + UserPrompt(content='hello', timestamp=IsNow(tz=timezone.utc)), + LLMToolCalls( + calls=[ToolCall(tool_name='ret_a', args=ArgsJson(args_json='{"x": "hello"}'))], + timestamp=IsNow(tz=timezone.utc), + ), + ToolReturn(tool_name='ret_a', content='hello world', timestamp=IsNow(tz=timezone.utc)), + ] + ) async def test_call_retriever_empty(): @@ -140,7 +151,8 @@ def stream_structured_function(_messages: list[Message], _: AgentInfo) -> Iterab agent = Agent(FunctionModel(stream_function=stream_structured_function), deps=None, result_type=tuple[str, int]) with pytest.raises(AgentError, match='caused by unexpected model behavior: Received empty tool call message'): - await agent.run_stream('hello') + async with agent.run_stream('hello'): + pass async def test_call_retriever_wrong_name(): @@ -150,4 +162,5 @@ def stream_structured_function(_messages: list[Message], _: AgentInfo) -> Iterab agent = Agent(FunctionModel(stream_function=stream_structured_function), deps=None, result_type=tuple[str, int]) with pytest.raises(AgentError, match="caused by unexpected model behavior: Unknown function name: 'foobar'"): - await agent.run_stream('hello') + async with agent.run_stream('hello'): + pass diff --git a/tests/typed_agent.py b/tests/typed_agent.py index 1535eed15..28ddf41cf 100644 --- a/tests/typed_agent.py +++ b/tests/typed_agent.py @@ -67,8 +67,8 @@ def run_sync() -> None: async def run_stream() -> None: - streamed_result = await typed_agent1.run_stream('testing') - _: list[str] = [chunk async for chunk in streamed_result.stream()] + async with typed_agent1.run_stream('testing') as streamed_result: + _: list[str] = [chunk async for chunk in streamed_result.stream()] typed_agent2: Agent[MyDeps, str] = Agent()
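
---

Usage sketch (reviewer note, not part of the patch): with this change, Agent.run_stream returns an async context manager instead of an awaitable, so callers use `async with` to obtain the StreamedRunResult, and the underlying model response stream is closed automatically on exit (via the new close() hooks). A minimal sketch of the new calling convention, assuming the early API shown in the tests above (model passed positionally, deps=None); the 'openai:gpt-4' model name and the prompt are illustrative only:

import asyncio

from pydantic_ai import Agent

agent = Agent('openai:gpt-4', deps=None)


async def main() -> None:
    # run_stream is now an async context manager; leaving the block
    # closes the underlying response stream in the finally clause added above.
    async with agent.run_stream('Tell me about whales.') as result:
        # text_delta=True yields only the new text of each chunk,
        # mirroring the behaviour exercised in test_streamed_text_stream.
        async for text in result.stream(text_delta=True, debounce_by=None):
            print(text, end='', flush=True)


if __name__ == '__main__':
    asyncio.run(main())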