diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py
index bfeb9850a..3465decb0 100644
--- a/pydantic_ai_slim/pydantic_ai/agent.py
+++ b/pydantic_ai_slim/pydantic_ai/agent.py
@@ -3,15 +3,16 @@
 import asyncio
 import dataclasses
 import inspect
-from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
+from collections.abc import AsyncIterator, Awaitable, Generator, Iterator, Sequence
 from contextlib import AbstractAsyncContextManager, asynccontextmanager, contextmanager
+from copy import deepcopy
 from types import FrameType
 from typing import Any, Callable, Generic, cast, final, overload
 
 import logfire_api
-from typing_extensions import TypeVar, deprecated
+from typing_extensions import Self, TypeVar, deprecated
 
-from pydantic_graph import Graph, GraphRunContext, HistoryStep
+from pydantic_graph import BaseNode, Graph, GraphRun, GraphRunContext, HistoryStep
 from pydantic_graph.nodes import End
 
 from . import (
@@ -25,7 +26,7 @@
     result,
     usage as _usage,
 )
-from ._agent_graph import EndStrategy, capture_run_messages  # imported for re-export
+from ._agent_graph import EndStrategy, MarkFinalResult, capture_run_messages  # imported for re-export
 from .result import ResultDataT
 from .settings import ModelSettings, merge_model_settings
 from .tools import (
@@ -214,7 +215,7 @@ async def run(
         usage_limits: _usage.UsageLimits | None = None,
         usage: _usage.Usage | None = None,
         infer_name: bool = True,
-    ) -> result.RunResult[ResultDataT]: ...
+    ) -> AgentRun[AgentDepsT, ResultDataT]: ...
 
     @overload
     async def run(
@@ -229,7 +230,7 @@
         usage_limits: _usage.UsageLimits | None = None,
         usage: _usage.Usage | None = None,
         infer_name: bool = True,
-    ) -> result.RunResult[RunResultDataT]: ...
+    ) -> AgentRun[AgentDepsT, ResultDataT]: ...
 
     async def run(
         self,
@@ -243,7 +244,7 @@
         usage: _usage.Usage | None = None,
         result_type: type[RunResultDataT] | None = None,
         infer_name: bool = True,
-    ) -> result.RunResult[Any]:
+    ) -> AgentRun[AgentDepsT, ResultDataT]:
         """Run the agent with a user prompt in async mode.
 
         Example:
@@ -305,54 +306,47 @@ async def main():
         model_settings = merge_model_settings(self.model_settings, model_settings)
         usage_limits = usage_limits or _usage.UsageLimits()
 
-        with _logfire.span(
+        # Build the deps object for the graph
+        run_span = _logfire.span(
             '{agent_name} run {prompt=}',
             prompt=user_prompt,
             agent=self,
             model_name=model_used.name() if model_used else 'no-model',
             agent_name=self.name or 'agent',
-        ) as run_span:
-            # Build the deps object for the graph
-            graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT, RunResultDataT](
-                user_deps=deps,
-                prompt=user_prompt,
-                new_message_index=new_message_index,
-                model=model_used,
-                model_settings=model_settings,
-                usage_limits=usage_limits,
-                max_result_retries=self._max_result_retries,
-                end_strategy=self.end_strategy,
-                result_schema=result_schema,
-                result_tools=self._result_schema.tool_defs() if self._result_schema else [],
-                result_validators=result_validators,
-                function_tools=self._function_tools,
-                run_span=run_span,
-            )
-
-            start_node = _agent_graph.UserPromptNode[AgentDepsT](
-                user_prompt=user_prompt,
-                system_prompts=self._system_prompts,
-                system_prompt_functions=self._system_prompt_functions,
-                system_prompt_dynamic_functions=self._system_prompt_dynamic_functions,
-            )
+        )
+        graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT, RunResultDataT](
+            user_deps=deps,
+            prompt=user_prompt,
+            new_message_index=new_message_index,
+            model=model_used,
+            model_settings=model_settings,
+            usage_limits=usage_limits,
+            max_result_retries=self._max_result_retries,
+            end_strategy=self.end_strategy,
+            result_schema=result_schema,
+            result_tools=self._result_schema.tool_defs() if self._result_schema else [],
+            result_validators=result_validators,
+            function_tools=self._function_tools,
+            run_span=run_span,
+        )
+        start_node = _agent_graph.UserPromptNode[AgentDepsT](
+            user_prompt=user_prompt,
+            system_prompts=self._system_prompts,
+            system_prompt_functions=self._system_prompt_functions,
+            system_prompt_dynamic_functions=self._system_prompt_dynamic_functions,
+        )
 
-            # Actually run
-            graph_run = await graph.run(
+        # Actually run
+        # TODO: Make this method non-async and remove the next await
+        # That way, users can decide whether to "await" the run, or iterate over it
+        return await AgentRun(
+            graph.run(
                 start_node,
                 state=state,
                 deps=graph_deps,
                 infer_name=False,
+                span=run_span,
             )
-            end_result = graph_run.result
-
-            # Build final run result
-            # We don't do any advanced checking if the data is actually from a final result or not
-            return result.RunResult(
-                state.message_history,
-                new_message_index,
-                end_result.data,
-                end_result.tool_name,
-                state.usage,
             )
 
     @overload
@@ -367,7 +361,7 @@ def run_sync(
         usage_limits: _usage.UsageLimits | None = None,
         usage: _usage.Usage | None = None,
         infer_name: bool = True,
-    ) -> result.RunResult[ResultDataT]: ...
+    ) -> AgentRun[AgentDepsT, ResultDataT]: ...
 
     @overload
     def run_sync(
@@ -382,7 +376,7 @@ def run_sync(
         usage_limits: _usage.UsageLimits | None = None,
         usage: _usage.Usage | None = None,
         infer_name: bool = True,
-    ) -> result.RunResult[RunResultDataT]: ...
+    ) -> AgentRun[AgentDepsT, ResultDataT]: ...
 
     def run_sync(
         self,
@@ -396,7 +390,7 @@ def run_sync(
         usage_limits: _usage.UsageLimits | None = None,
         usage: _usage.Usage | None = None,
        infer_name: bool = True,
-    ) -> result.RunResult[Any]:
+    ) -> AgentRun[AgentDepsT, ResultDataT]:
         """Run the agent with a user prompt synchronously.
 
         This is a convenience method that wraps [`self.run`][pydantic_ai.Agent.run] with `loop.run_until_complete(...)`.
@@ -1040,7 +1034,7 @@ def last_run_messages(self) -> list[_messages.ModelMessage]:
 
     def _build_graph(
         self, result_type: type[RunResultDataT] | None
-    ) -> Graph[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], Any]:
+    ) -> Graph[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], MarkFinalResult[Any]]:
         return _agent_graph.build_agent_graph(self.name, self._deps_type, result_type or self.result_type)
 
     def _build_stream_graph(
@@ -1059,3 +1053,140 @@ def _prepare_result_schema(
             )
         else:
             return self._result_schema  # pyright: ignore[reportReturnType]
+
+
+@dataclasses.dataclass
+class AgentRun(Generic[AgentDepsT, ResultDataT]):
+    # TODO: Should this go into the `result` module? And replace `result.RunResult`?
+    graph_run: GraphRun[
+        _agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], MarkFinalResult[ResultDataT]
+    ]
+
+    def all_messages(self, *, result_tool_return_content: str | None = None) -> list[_messages.ModelMessage]:
+        if result_tool_return_content is not None:
+            return self._set_result_tool_return(result_tool_return_content)
+        else:
+            return self.graph_run.state.message_history
+
+    def all_messages_json(self, *, result_tool_return_content: str | None = None) -> bytes:
+        """Return all messages from [`all_messages`][pydantic_ai.agent.AgentRun.all_messages] as JSON bytes.
+
+        Args:
+            result_tool_return_content: The return content of the tool call to set in the last message.
+                This provides a convenient way to modify the content of the result tool call if you want to continue
+                the conversation and want to set the response to the result tool call. If `None`, the last message will
+                not be modified.
+
+        Returns:
+            JSON bytes representing the messages.
+        """
+        return _messages.ModelMessagesTypeAdapter.dump_json(
+            self.all_messages(result_tool_return_content=result_tool_return_content)
+        )
+
+    @property
+    def _new_message_index(self) -> int:
+        return self.graph_run.deps.new_message_index
+
+    def new_messages(self, *, result_tool_return_content: str | None = None) -> list[_messages.ModelMessage]:
+        return self.all_messages(result_tool_return_content=result_tool_return_content)[self._new_message_index :]
+
+    def new_messages_json(self, *, result_tool_return_content: str | None = None) -> bytes:
+        """Return new messages from [`new_messages`][pydantic_ai.agent.AgentRun.new_messages] as JSON bytes.
+
+        Args:
+            result_tool_return_content: The return content of the tool call to set in the last message.
+                This provides a convenient way to modify the content of the result tool call if you want to continue
+                the conversation and want to set the response to the result tool call. If `None`, the last message will
+                not be modified.
+
+        Returns:
+            JSON bytes representing the new messages.
+        """
+        return _messages.ModelMessagesTypeAdapter.dump_json(
+            self.new_messages(result_tool_return_content=result_tool_return_content)
+        )
+
+    def usage(self) -> _usage.Usage:
+        return self.graph_run.state.usage
+
+    def _set_result_tool_return(self, return_content: str) -> list[_messages.ModelMessage]:
+        """Set return content for the result tool.
+
+        Useful if you want to continue the conversation and want to set the response to the result tool call.
+        """
+        if not self.result.tool_name:
+            raise ValueError('Cannot set result tool return content when the return type is `str`.')
+        messages = deepcopy(self.graph_run.state.message_history)
+        last_message = messages[-1]
+        for part in last_message.parts:
+            if isinstance(part, _messages.ToolReturnPart) and part.tool_name == self.result.tool_name:
+                part.content = return_content
+                return messages
+        raise LookupError(f'No tool call found with tool name {self.result.tool_name!r}.')
+
+    @property
+    def is_ended(self):
+        return self.graph_run.is_ended
+
+    @property
+    def result(self) -> MarkFinalResult[ResultDataT]:
+        return self.graph_run.result
+
+    @property
+    def _result_tool_name(self) -> str | None:
+        return self.graph_run.result.tool_name
+
+    @property
+    def data(self) -> ResultDataT:
+        return self.result.data
+
+    async def next(
+        self,
+        node: BaseNode[
+            _agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], MarkFinalResult[ResultDataT]
+        ],
+    ) -> (
+        BaseNode[
+            _agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], MarkFinalResult[ResultDataT]
+        ]
+        | End[MarkFinalResult[ResultDataT]]
+    ):
+        return await self.graph_run.next(node)
+
+    def __await__(self) -> Generator[Any, Any, Self]:
+        """Run the graph until it ends, then return this `AgentRun` so the final result can be read from it."""
+
+        async def _run():
+            await self.graph_run
+            return self
+
+        return _run().__await__()
+
+    def __enter__(self) -> Self:
+        self.graph_run.__enter__()
+        return self
+
+    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
+        self.graph_run.__exit__(exc_type, exc_val, exc_tb)
+
+    def __aiter__(
+        self,
+    ) -> AsyncIterator[
+        BaseNode[
+            _agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], MarkFinalResult[ResultDataT]
+        ]
+        | End[MarkFinalResult[ResultDataT]]
+    ]:
+        return self
+
+    async def __anext__(
+        self,
+    ) -> (
+        BaseNode[
+            _agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], MarkFinalResult[ResultDataT]
+        ]
+        | End[MarkFinalResult[ResultDataT]]
+    ):
+        """Use the last returned node as the input to `Graph.next`."""
+        return await self.graph_run.__anext__()
diff --git a/pydantic_graph/pydantic_graph/graph.py b/pydantic_graph/pydantic_graph/graph.py
index 9f6b37222..4b64a5249 100644
--- a/pydantic_graph/pydantic_graph/graph.py
+++ b/pydantic_graph/pydantic_graph/graph.py
@@ -133,6 +133,7 @@ def run(
         state: StateT = None,
         deps: DepsT = None,
         infer_name: bool = True,
+        span: LogfireSpan | None = None,
     ) -> GraphRun[StateT, DepsT, T]:
         """Run the graph from a starting node until it ends.
 
@@ -142,6 +143,7 @@ def run(
             state: The initial state of the graph.
             deps: The dependencies of the graph.
             infer_name: Whether to infer the graph name from the calling frame.
+            span: The span to use for the graph run. If not provided, a new span will be created.
 
         Returns:
             The result type from ending the run and the history of the run.
@@ -171,7 +173,13 @@ async def main():
             self._infer_name(inspect.currentframe())
 
         return GraphRun[StateT, DepsT, T](
-            self, start_node, history=[], state=state, deps=deps, auto_instrument=self._auto_instrument
+            self,
+            start_node,
+            history=[],
+            state=state,
+            deps=deps,
+            auto_instrument=self._auto_instrument,
+            span=span,
         )
 
     def run_sync(
@@ -514,17 +522,18 @@ def __init__(
         state: StateT,
         deps: DepsT,
         auto_instrument: bool,
+        span: LogfireSpan | None = None,
     ):
         self.graph = graph
         self.history = history
         self.state = state
         self.deps = deps
         self._auto_instrument = auto_instrument
+        self._span = span
 
         self._next_node = first_node
         self._started: bool = False
         self._result: End[RunEndT] | None = None
-        self._span: LogfireSpan | None = None
 
     @property
     def is_ended(self):
@@ -568,8 +577,10 @@ def __enter__(self) -> typing_extensions.Self:
         if self._started:
             raise exceptions.GraphRuntimeError('A GraphRun can only be started once.')
 
-        if self._auto_instrument:
+        if self._auto_instrument and self._span is None:
             self._span = logfire_api.span('run graph {graph.name}', graph=self.graph)
+
+        if self._span is not None:
             self._span.__enter__()
 
         self._started = True
diff --git a/tests/test_agent.py b/tests/test_agent.py
index fad13b0d9..5a0b4a868 100644
--- a/tests/test_agent.py
+++ b/tests/test_agent.py
@@ -27,7 +27,7 @@
 from pydantic_ai.models import cached_async_http_client
 from pydantic_ai.models.function import AgentInfo, FunctionModel
 from pydantic_ai.models.test import TestModel
-from pydantic_ai.result import RunResult, Usage
+from pydantic_ai.result import Usage
 from pydantic_ai.tools import ToolDefinition
 
 from .conftest import IsNow, TestEnv
@@ -279,7 +279,7 @@ def return_tuple(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
             ),
         ]
     )
-    assert result._result_tool_name == 'final_result'  # pyright: ignore[reportPrivateUsage]
+    assert result.graph_run.result.tool_name == 'final_result'
    assert result.all_messages(result_tool_return_content='foobar')[-1] == snapshot(
         ModelRequest(
             parts=[ToolReturnPart(tool_name='final_result', content='foobar', timestamp=IsNow(tz=timezone.utc))]
@@ -312,7 +312,7 @@ def test_result_tool_return_content_no_tool():
 
     result = agent.run_sync('Hello')
     assert result.data == 0
-    result._result_tool_name = 'wrong'  # pyright: ignore[reportPrivateUsage]
+    result.graph_run.result.tool_name = 'wrong'
     with pytest.raises(LookupError, match=re.escape("No tool call found with tool name 'wrong'.")):
         result.all_messages(result_tool_return_content='foobar')
 
@@ -532,37 +532,38 @@ async def ret_a(x: str) -> str:
 
     # if we pass new_messages, system prompt is inserted before the message_history messages
     result2 = agent.run_sync('Hello again', message_history=result1.new_messages())
-    assert result2 == snapshot(
-        RunResult(
-            _all_messages=[
-                ModelRequest(
-                    parts=[
-                        SystemPromptPart(content='Foobar'),
-                        UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)),
-                    ]
-                ),
-                ModelResponse(
-                    parts=[ToolCallPart(tool_name='ret_a', args={'x': 'a'})],
-                    model_name='test',
-                    timestamp=IsNow(tz=timezone.utc),
-                ),
-                ModelRequest(
-                    parts=[ToolReturnPart(tool_name='ret_a', content='a-apple', timestamp=IsNow(tz=timezone.utc))]
-                ),
-                ModelResponse(
-                    parts=[TextPart(content='{"ret_a":"a-apple"}')], model_name='test', timestamp=IsNow(tz=timezone.utc)
-                ),
-                ModelRequest(parts=[UserPromptPart(content='Hello again', timestamp=IsNow(tz=timezone.utc))]),
-                ModelResponse(
-                    parts=[TextPart(content='{"ret_a":"a-apple"}')], model_name='test', timestamp=IsNow(tz=timezone.utc)
-                ),
-            ],
-            _new_message_index=4,
-            data='{"ret_a":"a-apple"}',
-            _result_tool_name=None,
-            _usage=Usage(requests=1, request_tokens=55, response_tokens=13, total_tokens=68, details=None),
-        )
+    assert result2.all_messages() == snapshot(
+        [
+            ModelRequest(
+                parts=[
+                    SystemPromptPart(content='Foobar'),
+                    UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)),
+                ]
+            ),
+            ModelResponse(
+                parts=[ToolCallPart(tool_name='ret_a', args={'x': 'a'})],
+                model_name='test',
+                timestamp=IsNow(tz=timezone.utc),
+            ),
+            ModelRequest(
+                parts=[ToolReturnPart(tool_name='ret_a', content='a-apple', timestamp=IsNow(tz=timezone.utc))]
+            ),
+            ModelResponse(
+                parts=[TextPart(content='{"ret_a":"a-apple"}')], model_name='test', timestamp=IsNow(tz=timezone.utc)
+            ),
+            ModelRequest(parts=[UserPromptPart(content='Hello again', timestamp=IsNow(tz=timezone.utc))]),
+            ModelResponse(
+                parts=[TextPart(content='{"ret_a":"a-apple"}')], model_name='test', timestamp=IsNow(tz=timezone.utc)
+            ),
+        ]
+    )
+    assert result2._new_message_index == snapshot(4)  # pyright: ignore[reportPrivateUsage]
+    assert result2.data == snapshot('{"ret_a":"a-apple"}')
+    assert result2._result_tool_name == snapshot(None)  # pyright: ignore[reportPrivateUsage]
+    assert result2.usage() == snapshot(
+        Usage(requests=1, request_tokens=55, response_tokens=13, total_tokens=68, details=None)
     )
+
     new_msg_part_kinds = [(m.kind, [p.part_kind for p in m.parts]) for m in result2.all_messages()]
     assert new_msg_part_kinds == snapshot(
         [
@@ -580,36 +581,36 @@ async def ret_a(x: str) -> str:
     # so only one system prompt
     result3 = agent.run_sync('Hello again', message_history=result1.all_messages())
     # same as result2 except for datetimes
-    assert result3 == snapshot(
-        RunResult(
-            _all_messages=[
-                ModelRequest(
-                    parts=[
-                        SystemPromptPart(content='Foobar'),
-                        UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)),
-                    ]
-                ),
-                ModelResponse(
-                    parts=[ToolCallPart(tool_name='ret_a', args={'x': 'a'})],
-                    model_name='test',
-                    timestamp=IsNow(tz=timezone.utc),
-                ),
-                ModelRequest(
-                    parts=[ToolReturnPart(tool_name='ret_a', content='a-apple', timestamp=IsNow(tz=timezone.utc))]
-                ),
-                ModelResponse(
-                    parts=[TextPart(content='{"ret_a":"a-apple"}')], model_name='test', timestamp=IsNow(tz=timezone.utc)
-                ),
-                ModelRequest(parts=[UserPromptPart(content='Hello again', timestamp=IsNow(tz=timezone.utc))]),
-                ModelResponse(
-                    parts=[TextPart(content='{"ret_a":"a-apple"}')], model_name='test', timestamp=IsNow(tz=timezone.utc)
-                ),
-            ],
-            _new_message_index=4,
-            data='{"ret_a":"a-apple"}',
-            _result_tool_name=None,
-            _usage=Usage(requests=1, request_tokens=55, response_tokens=13, total_tokens=68, details=None),
-        )
+    assert result3.all_messages() == snapshot(
+        [
+            ModelRequest(
+                parts=[
+                    SystemPromptPart(content='Foobar'),
+                    UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)),
+                ]
+            ),
+            ModelResponse(
+                parts=[ToolCallPart(tool_name='ret_a', args={'x': 'a'})],
+                model_name='test',
+                timestamp=IsNow(tz=timezone.utc),
+            ),
+            ModelRequest(
+                parts=[ToolReturnPart(tool_name='ret_a', content='a-apple', timestamp=IsNow(tz=timezone.utc))]
+            ),
+            ModelResponse(
+                parts=[TextPart(content='{"ret_a":"a-apple"}')], model_name='test', timestamp=IsNow(tz=timezone.utc)
+            ),
+            ModelRequest(parts=[UserPromptPart(content='Hello again', timestamp=IsNow(tz=timezone.utc))]),
+            ModelResponse(
+                parts=[TextPart(content='{"ret_a":"a-apple"}')], model_name='test', timestamp=IsNow(tz=timezone.utc)
+            ),
+        ]
+    )
+    assert result3._new_message_index == snapshot(4)  # pyright: ignore[reportPrivateUsage]
+    assert result3.data == snapshot('{"ret_a":"a-apple"}')
+    assert result3._result_tool_name == snapshot(None)  # pyright: ignore[reportPrivateUsage]
+    assert result3.usage() == snapshot(
+        Usage(requests=1, request_tokens=55, response_tokens=13, total_tokens=68, details=None)
     )
 
 
@@ -664,63 +665,63 @@ async def ret_a(x: str) -> str:
     )
 
     result2 = agent.run_sync('Hello again', message_history=result1.new_messages())
-    assert result2 == snapshot(
-        RunResult(
-            data=Response(a=0),
-            _all_messages=[
-                ModelRequest(
-                    parts=[
-                        SystemPromptPart(content='Foobar'),
-                        UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)),
-                    ],
-                ),
-                ModelResponse(
-                    parts=[ToolCallPart(tool_name='ret_a', args={'x': 'a'})],
-                    model_name='test',
-                    timestamp=IsNow(tz=timezone.utc),
-                ),
-                ModelRequest(
-                    parts=[ToolReturnPart(tool_name='ret_a', content='a-apple', timestamp=IsNow(tz=timezone.utc))],
-                ),
-                ModelResponse(
-                    parts=[ToolCallPart(tool_name='final_result', args={'a': 0})],
-                    model_name='test',
-                    timestamp=IsNow(tz=timezone.utc),
-                ),
-                ModelRequest(
-                    parts=[
-                        ToolReturnPart(
-                            tool_name='final_result',
-                            content='Final result processed.',
-                            timestamp=IsNow(tz=timezone.utc),
-                        ),
-                    ],
-                ),
-                # second call, notice no repeated system prompt
-                ModelRequest(
-                    parts=[
-                        UserPromptPart(content='Hello again', timestamp=IsNow(tz=timezone.utc)),
-                    ],
-                ),
-                ModelResponse(
-                    parts=[ToolCallPart(tool_name='final_result', args={'a': 0})],
-                    model_name='test',
-                    timestamp=IsNow(tz=timezone.utc),
-                ),
-                ModelRequest(
-                    parts=[
-                        ToolReturnPart(
-                            tool_name='final_result',
-                            content='Final result processed.',
-                            timestamp=IsNow(tz=timezone.utc),
-                        ),
-                    ]
-                ),
-            ],
-            _new_message_index=5,
-            _result_tool_name='final_result',
-            _usage=Usage(requests=1, request_tokens=59, response_tokens=13, total_tokens=72, details=None),
-        )
+    assert result2.all_messages() == snapshot(
+        [
+            ModelRequest(
+                parts=[
+                    SystemPromptPart(content='Foobar'),
+                    UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)),
+                ],
+            ),
+            ModelResponse(
+                parts=[ToolCallPart(tool_name='ret_a', args={'x': 'a'})],
+                model_name='test',
+                timestamp=IsNow(tz=timezone.utc),
+            ),
+            ModelRequest(
+                parts=[ToolReturnPart(tool_name='ret_a', content='a-apple', timestamp=IsNow(tz=timezone.utc))],
+            ),
+            ModelResponse(
+                parts=[ToolCallPart(tool_name='final_result', args={'a': 0})],
+                model_name='test',
+                timestamp=IsNow(tz=timezone.utc),
+            ),
+            ModelRequest(
+                parts=[
+                    ToolReturnPart(
+                        tool_name='final_result',
+                        content='Final result processed.',
+                        timestamp=IsNow(tz=timezone.utc),
+                    ),
+                ],
+            ),
+            # second call, notice no repeated system prompt
+            ModelRequest(
+                parts=[
+                    UserPromptPart(content='Hello again', timestamp=IsNow(tz=timezone.utc)),
+                ],
+            ),
+            ModelResponse(
+                parts=[ToolCallPart(tool_name='final_result', args={'a': 0})],
+                model_name='test',
+                timestamp=IsNow(tz=timezone.utc),
+            ),
+            ModelRequest(
+                parts=[
+                    ToolReturnPart(
+                        tool_name='final_result',
+                        content='Final result processed.',
+                        timestamp=IsNow(tz=timezone.utc),
+                    ),
+                ]
+            ),
+        ]
+    )
+    assert result2.data == snapshot(Response(a=0))
+    assert result2._new_message_index == snapshot(5)  # pyright: ignore[reportPrivateUsage]
+    assert result2._result_tool_name == snapshot('final_result')  # pyright: ignore[reportPrivateUsage]
+    assert result2.usage() == snapshot(
+        Usage(requests=1, request_tokens=59, response_tokens=13, total_tokens=72, details=None)
     )
     new_msg_part_kinds = [(m.kind, [p.part_kind for p in m.parts]) for m in result2.all_messages()]
     assert new_msg_part_kinds == snapshot(
diff --git a/tests/typed_agent.py b/tests/typed_agent.py
index fdf9f1a25..dbbe04411 100644
--- a/tests/typed_agent.py
+++ b/tests/typed_agent.py
@@ -8,7 +8,7 @@
 from typing_extensions import assert_type
 
 from pydantic_ai import Agent, ModelRetry, RunContext, Tool
-from pydantic_ai.result import RunResult
+from pydantic_ai.agent import AgentRun
 from pydantic_ai.tools import ToolDefinition
 
 
@@ -139,7 +139,7 @@ async def result_validator_wrong(ctx: RunContext[int], result: str) -> str:
 
 def run_sync() -> None:
     result = typed_agent.run_sync('testing', deps=MyDeps(foo=1, bar=2))
-    assert_type(result, RunResult[str])
+    assert_type(result, AgentRun[MyDeps, str])
     assert_type(result.data, str)
 
 
@@ -176,7 +176,7 @@ class Bar:
 
 def run_sync3() -> None:
     result = union_agent.run_sync('testing')
-    assert_type(result, RunResult[Union[Foo, Bar]])
+    assert_type(result, AgentRun[None, Union[Foo, Bar]])
     assert_type(result.data, Union[Foo, Bar])
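
Reviewer note (not part of the patch): a rough sketch of how the new `AgentRun` return value could be consumed once this lands. It relies only on the surface added above (`Agent.run` returning an awaited `AgentRun`, plus `.data`, `.usage()` and `.new_messages()`); the `'test'` model name and the prompts are placeholders, and node-by-node iteration is left out because `Agent.run` still awaits the run to completion (see the TODO in `Agent.run`).

    import asyncio

    from pydantic_ai import Agent

    agent = Agent('test')  # placeholder model name, not part of this diff


    async def main() -> None:
        # Agent.run is still async in this diff, so awaiting it yields an
        # already-completed AgentRun; the old RunResult accessors keep working.
        run = await agent.run('Hello')
        print(run.data)
        print(run.usage())

        # Continue the conversation by feeding the new messages back in,
        # mirroring what the updated tests in tests/test_agent.py do with run_sync.
        follow_up = await agent.run('Hello again', message_history=run.new_messages())
        print(follow_up.data)


    asyncio.run(main())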
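
A second reviewer note, also outside the patch: a sketch of the `result_tool_return_content` path that `AgentRun.all_messages` and `_set_result_tool_return` carry over from `RunResult`, following the updated assertions in tests/test_agent.py. The `Answer` model is hypothetical and used only for illustration.

    from pydantic import BaseModel

    from pydantic_ai import Agent


    class Answer(BaseModel):  # hypothetical result model, for illustration only
        a: int


    agent = Agent('test', result_type=Answer)

    result = agent.run_sync('Hello')
    # With a structured result type there is a result tool, so the final
    # ToolReturnPart can be rewritten in the deep-copied list returned here;
    # this is the same path the tests exercise with result_tool_return_content='foobar'.
    messages = result.all_messages(result_tool_return_content='Final result noted.')
    print(messages[-1])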