From 57905c0b9e93044b966516b5ab83221946989482 Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Tue, 11 Feb 2025 13:17:47 -0700 Subject: [PATCH 01/28] Add GraphRun class --- docs/graph.md | 11 +- pydantic_ai_slim/pydantic_ai/_agent_graph.py | 650 ++++++++++-------- pydantic_ai_slim/pydantic_ai/agent.py | 420 +++++++---- pydantic_ai_slim/pydantic_ai/messages.py | 28 + .../pydantic_ai/models/__init__.py | 106 ++- pydantic_ai_slim/pydantic_ai/result.py | 185 ++--- pydantic_graph/README.md | 6 +- pydantic_graph/pydantic_graph/__init__.py | 3 +- pydantic_graph/pydantic_graph/graph.py | 180 ++++- pydantic_graph/pydantic_graph/nodes.py | 2 + pyproject.toml | 2 +- tests/graph/test_graph.py | 20 +- tests/graph/test_history.py | 10 +- tests/graph/test_mermaid.py | 6 +- tests/graph/test_state.py | 6 +- tests/models/test_gemini.py | 4 + tests/models/test_mistral.py | 2 +- tests/test_agent.py | 241 +++---- tests/test_streaming.py | 16 +- tests/test_usage_limits.py | 58 +- tests/typed_agent.py | 6 +- tests/typed_graph.py | 6 +- 22 files changed, 1151 insertions(+), 817 deletions(-) diff --git a/docs/graph.md b/docs/graph.md index fa1b87343..9db4cbffe 100644 --- a/docs/graph.md +++ b/docs/graph.md @@ -156,11 +156,11 @@ class Increment(BaseNode): # (2)! fives_graph = Graph(nodes=[DivisibleBy5, Increment]) # (3)! -result, history = fives_graph.run_sync(DivisibleBy5(4)) # (4)! -print(result) +graph_run = fives_graph.run_sync(DivisibleBy5(4)) # (4)! +print(graph_run.result) #> 5 # the full history is quite verbose (see below), so we'll just print the summary -print([item.data_snapshot() for item in history]) +print([item.data_snapshot() for item in graph_run.history]) #> [DivisibleBy5(foo=4), Increment(foo=4), DivisibleBy5(foo=5), End(data=5)] ``` @@ -464,8 +464,8 @@ async def main(): ) state = State(user) feedback_graph = Graph(nodes=(WriteEmail, Feedback)) - email, _ = await feedback_graph.run(WriteEmail(), state=state) - print(email) + graph_run = await feedback_graph.run(WriteEmail(), state=state) + print(graph_run.result) """ Email( subject='Welcome to our tech blog!', @@ -606,6 +606,7 @@ async def main(): Ask(), Answer(question='what is 1 + 1?', answer='2'), Evaluate(answer='2'), + End(data='Well done, 1 + 1 = 2'), ] """ return diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index 60a5b3f97..ea56822dd 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -2,8 +2,9 @@ import asyncio import dataclasses +import uuid from abc import ABC -from collections.abc import AsyncIterator, Iterator, Sequence +from collections.abc import AsyncIterator, Iterator from contextlib import asynccontextmanager, contextmanager from contextvars import ContextVar from dataclasses import field @@ -21,9 +22,9 @@ exceptions, messages as _messages, models, - result, usage as _usage, ) +from .models import MarkFinalResult, ModelRequestParameters, StreamedResponse from .result import ResultDataT from .settings import ModelSettings, merge_model_settings from .tools import ( @@ -56,21 +57,6 @@ ResultT = TypeVar('ResultT') -@dataclasses.dataclass -class MarkFinalResult(Generic[ResultDataT]): - """Marker class to indicate that the result is the final result. - - This allows us to use `isinstance`, which wouldn't be possible if we were returning `ResultDataT` directly. - - It also avoids problems in the case where the result type is itself `None`, but is set. - """ - - data: ResultDataT - """The final result data.""" - tool_name: str | None - """Name of the final result tool, None if the result is a string.""" - - @dataclasses.dataclass class GraphAgentState: """State kept across the execution of the agent graph.""" @@ -123,7 +109,7 @@ class BaseUserPromptNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], N async def _get_first_message( self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]] ) -> _messages.ModelRequest: - run_context = _build_run_context(ctx) + run_context = build_run_context(ctx) history, next_message = await self._prepare_messages(self.user_prompt, ctx.state.message_history, run_context) ctx.state.message_history = history run_context.messages = history @@ -196,12 +182,12 @@ async def run( return ModelRequestNode[DepsT, NodeRunEndT](request=await self._get_first_message(ctx)) -@dataclasses.dataclass -class StreamUserPromptNode(BaseUserPromptNode[DepsT, NodeRunEndT]): - async def run( - self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]] - ) -> StreamModelRequestNode[DepsT, NodeRunEndT]: - return StreamModelRequestNode[DepsT, NodeRunEndT](request=await self._get_first_message(ctx)) +# @dataclasses.dataclass +# class StreamUserPromptNode(BaseUserPromptNode[DepsT, NodeRunEndT]): +# async def run( +# self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]] +# ) -> StreamModelRequestNode[DepsT, NodeRunEndT]: +# return StreamModelRequestNode[DepsT, NodeRunEndT](request=await self._get_first_message(ctx)) async def _prepare_request_parameters( @@ -210,7 +196,7 @@ async def _prepare_request_parameters( """Build tools and create an agent model.""" function_tool_defs: list[ToolDefinition] = [] - run_context = _build_run_context(ctx) + run_context = build_run_context(ctx) async def add_tool(tool: Tool[DepsT]) -> None: ctx = run_context.replace_with(retry=tool.current_retry, tool_name=tool.name) @@ -222,7 +208,7 @@ async def add_tool(tool: Tool[DepsT]) -> None: result_schema = ctx.deps.result_schema return models.ModelRequestParameters( function_tools=function_tool_defs, - allow_text_result=_allow_text_result(result_schema), + allow_text_result=allow_text_result(result_schema), result_tools=result_schema.tool_defs() if result_schema is not None else [], ) @@ -233,9 +219,69 @@ class ModelRequestNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], Nod request: _messages.ModelRequest + _result: HandleResponseNode[DepsT, NodeRunEndT] | None = field(default=None, repr=False) + _did_stream: bool = field(default=False, repr=False) + async def run( self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]] ) -> HandleResponseNode[DepsT, NodeRunEndT]: + if self._result is not None: + return self._result + + if self._did_stream: + # `self._result` gets set when exiting the `stream` contextmanager, so hitting this + # means that the stream was started but not finished before `run()` was called + raise exceptions.AgentRunError('You must finish streaming before calling run()') + + return await self._make_request(ctx) + + @asynccontextmanager + async def stream( + self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]] + ) -> AsyncIterator[StreamedResponse]: + if self._did_stream: + raise exceptions.AgentRunError('stream() can only be called once') + + model_settings, model_request_parameters = await self._prepare_request(ctx) + with _logfire.span('model request', run_step=ctx.state.run_step) as span: + async with ctx.deps.model.request_stream( + ctx.state.message_history, model_settings, model_request_parameters + ) as streamed_response: + self._did_stream = True + ctx.state.usage.incr(_usage.Usage(), requests=1) + yield streamed_response + # In case the user didn't manually consume the full stream, ensure it is fully consumed here, + # otherwise usage won't be properly counted: + async for _ in streamed_response: + pass + model_response = streamed_response.get() + request_usage = streamed_response.usage() + span.set_attribute('response', model_response) + span.set_attribute('usage', request_usage) + + self._finish_handling(ctx, model_response, request_usage) + assert self._result is not None # this should be set by the previous line + + async def _make_request( + self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]] + ) -> HandleResponseNode[DepsT, NodeRunEndT]: + if self._result is not None: + return self._result + + model_settings, model_request_parameters = await self._prepare_request(ctx) + with _logfire.span('model request', run_step=ctx.state.run_step) as span: + model_response, request_usage = await ctx.deps.model.request( + ctx.state.message_history, model_settings, model_request_parameters + ) + ctx.state.usage.incr(_usage.Usage(), requests=1) + span.set_attribute('response', model_response) + span.set_attribute('usage', request_usage) + + return self._finish_handling(ctx, model_response, request_usage) + + async def _prepare_request( + self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]] + ) -> tuple[ModelSettings | None, ModelRequestParameters]: ctx.state.message_history.append(self.request) # Check usage @@ -245,67 +291,120 @@ async def run( # Increment run_step ctx.state.run_step += 1 + model_settings = merge_model_settings(ctx.deps.model_settings, None) with _logfire.span('preparing model request params {run_step=}', run_step=ctx.state.run_step): model_request_parameters = await _prepare_request_parameters(ctx) + return model_settings, model_request_parameters - # Actually make the model request - model_settings = merge_model_settings(ctx.deps.model_settings, None) - with _logfire.span('model request') as span: - model_response, request_usage = await ctx.deps.model.request( - ctx.state.message_history, model_settings, model_request_parameters - ) - span.set_attribute('response', model_response) - span.set_attribute('usage', request_usage) - + def _finish_handling( + self, + ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], + response: _messages.ModelResponse, + usage: _usage.Usage, + ) -> HandleResponseNode[DepsT, NodeRunEndT]: # Update usage - ctx.state.usage.incr(request_usage, requests=1) + ctx.state.usage.incr(usage, requests=0) if ctx.deps.usage_limits: ctx.deps.usage_limits.check_tokens(ctx.state.usage) # Append the model response to state.message_history - ctx.state.message_history.append(model_response) - return HandleResponseNode(model_response) + ctx.state.message_history.append(response) + + # Set the `_result` attribute since we can't use `return` in an async iterator + self._result = HandleResponseNode(response) + + return self._result @dataclasses.dataclass class HandleResponseNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT]): - """Process e response from a model, decide whether to end the run or make a new request.""" + """Process the response from a model, decide whether to end the run or make a new request.""" model_response: _messages.ModelResponse + _stream: AsyncIterator[_messages.HandleResponseEvent] | None = field(default=None, repr=False) + _next_node: ModelRequestNode[DepsT, NodeRunEndT] | FinalResultNode[DepsT, NodeRunEndT] | None = field( + default=None, repr=False + ) + _tool_responses: list[_messages.ModelRequestPart] = field(default_factory=list, repr=False) + async def run( self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]] ) -> Union[ModelRequestNode[DepsT, NodeRunEndT], FinalResultNode[DepsT, NodeRunEndT]]: # noqa UP007 + async with self.stream(ctx): + pass + + # the stream should set `self._next_node` before it ends: + assert (next_node := self._next_node) is not None + return next_node + + @asynccontextmanager + async def stream( + self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]] + ) -> AsyncIterator[AsyncIterator[_messages.HandleResponseEvent]]: with _logfire.span('handle model response', run_step=ctx.state.run_step) as handle_span: - texts: list[str] = [] - tool_calls: list[_messages.ToolCallPart] = [] - for part in self.model_response.parts: - if isinstance(part, _messages.TextPart): - # ignore empty content for text parts, see #437 - if part.content: - texts.append(part.content) - elif isinstance(part, _messages.ToolCallPart): - tool_calls.append(part) + stream = self._run_stream(ctx) + yield stream + + # Run the stream to completion if it was not finished: + async for _event in stream: + pass + + # Set the next node based on the final state of the stream + next_node = self._next_node + if isinstance(next_node, FinalResultNode): + handle_span.set_attribute('result', next_node.data) + handle_span.message = 'handle model response -> final result' + elif tool_responses := self._tool_responses: + # TODO: We could drop `self._tool_responses` if we drop this set_attribute + # I'm thinking it might be better to just create a span for the handling of each tool + # than to set an attribute here. + handle_span.set_attribute('tool_responses', tool_responses) + tool_responses_str = ' '.join(r.part_kind for r in tool_responses) + handle_span.message = f'handle model response -> {tool_responses_str}' + + async def _run_stream( + self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]] + ) -> AsyncIterator[_messages.HandleResponseEvent]: + if self._stream is None: + # Ensure that the stream is only run once + + async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]: + texts: list[str] = [] + tool_calls: list[_messages.ToolCallPart] = [] + for part in self.model_response.parts: + if isinstance(part, _messages.TextPart): + # ignore empty content for text parts, see #437 + if part.content: + texts.append(part.content) + elif isinstance(part, _messages.ToolCallPart): + tool_calls.append(part) + else: + assert_never(part) + + # At the moment, we prioritize at least executing tool calls if they are present. + # In the future, we'd consider making this configurable at the agent or run level. + # This accounts for cases like anthropic returns that might contain a text response + # and a tool call response, where the text response just indicates the tool call will happen. + if tool_calls: + async for event in self._handle_tool_calls(ctx, tool_calls): + yield event + elif texts: + # No events are emitted during the handling of text responses, so we don't need to yield anything + self._next_node = await self._handle_text_response(ctx, texts) else: - assert_never(part) - - # At the moment, we prioritize at least executing tool calls if they are present. - # In the future, we'd consider making this configurable at the agent or run level. - # This accounts for cases like anthropic returns that might contain a text response - # and a tool call response, where the text response just indicates the tool call will happen. - if tool_calls: - return await self._handle_tool_calls_response(ctx, tool_calls, handle_span) - elif texts: - return await self._handle_text_response(ctx, texts, handle_span) - else: - raise exceptions.UnexpectedModelBehavior('Received empty model response') + raise exceptions.UnexpectedModelBehavior('Received empty model response') - async def _handle_tool_calls_response( + self._stream = _run_stream() + + async for event in self._stream: + yield event + + async def _handle_tool_calls( self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], tool_calls: list[_messages.ToolCallPart], - handle_span: logfire_api.LogfireSpan, - ): + ) -> AsyncIterator[_messages.HandleResponseEvent]: result_schema = ctx.deps.result_schema # first look for the result tool call @@ -326,30 +425,28 @@ async def _handle_tool_calls_response( final_result = MarkFinalResult(result_data, call.tool_name) # Then build the other request parts based on end strategy - tool_responses = await _process_function_tools(tool_calls, final_result and final_result.tool_name, ctx) + tool_responses: list[_messages.ModelRequestPart] = self._tool_responses + async for event in process_function_tools( + tool_calls, final_result and final_result.tool_name, ctx, tool_responses + ): + yield event if final_result: - handle_span.set_attribute('result', final_result.data) - handle_span.message = 'handle model response -> final result' - return FinalResultNode[DepsT, NodeRunEndT](final_result, tool_responses) + self._next_node = FinalResultNode[DepsT, NodeRunEndT](final_result, tool_responses) else: if tool_responses: - handle_span.set_attribute('tool_responses', tool_responses) - tool_responses_str = ' '.join(r.part_kind for r in tool_responses) - handle_span.message = f'handle model response -> {tool_responses_str}' parts.extend(tool_responses) - return ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=parts)) + self._next_node = ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=parts)) async def _handle_text_response( self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], texts: list[str], - handle_span: logfire_api.LogfireSpan, - ): + ) -> ModelRequestNode[DepsT, NodeRunEndT] | FinalResultNode[DepsT, NodeRunEndT]: result_schema = ctx.deps.result_schema text = '\n\n'.join(texts) - if _allow_text_result(result_schema): + if allow_text_result(result_schema): result_data_input = cast(NodeRunEndT, text) try: result_data = await _validate_result(result_data_input, ctx, None) @@ -357,8 +454,6 @@ async def _handle_text_response( ctx.state.increment_retries(ctx.deps.max_result_retries) return ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[e.tool_retry])) else: - handle_span.set_attribute('result', result_data) - handle_span.message = 'handle model response -> final result' return FinalResultNode[DepsT, NodeRunEndT](MarkFinalResult(result_data, None)) else: ctx.state.increment_retries(ctx.deps.max_result_retries) @@ -373,134 +468,134 @@ async def _handle_text_response( ) -@dataclasses.dataclass -class StreamModelRequestNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT]): - """Make a request to the model using the last message in state.message_history (or a specified request).""" - - request: _messages.ModelRequest - _result: StreamModelRequestNode[DepsT, NodeRunEndT] | End[result.StreamedRunResult[DepsT, NodeRunEndT]] | None = ( - field(default=None, repr=False) - ) - - async def run( - self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]] - ) -> Union[StreamModelRequestNode[DepsT, NodeRunEndT], End[result.StreamedRunResult[DepsT, NodeRunEndT]]]: # noqa UP007 - if self._result is not None: - return self._result - - async with self.run_to_result(ctx) as final_node: - return final_node - - @asynccontextmanager - async def run_to_result( - self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]] - ) -> AsyncIterator[StreamModelRequestNode[DepsT, NodeRunEndT] | End[result.StreamedRunResult[DepsT, NodeRunEndT]]]: - result_schema = ctx.deps.result_schema - - ctx.state.message_history.append(self.request) - - # Check usage - if ctx.deps.usage_limits: - ctx.deps.usage_limits.check_before_request(ctx.state.usage) - - # Increment run_step - ctx.state.run_step += 1 - - with _logfire.span('preparing model and tools {run_step=}', run_step=ctx.state.run_step): - model_request_parameters = await _prepare_request_parameters(ctx) - - # Actually make the model request - model_settings = merge_model_settings(ctx.deps.model_settings, None) - with _logfire.span('model request {run_step=}', run_step=ctx.state.run_step) as model_req_span: - async with ctx.deps.model.request_stream( - ctx.state.message_history, model_settings, model_request_parameters - ) as streamed_response: - ctx.state.usage.requests += 1 - model_req_span.set_attribute('response_type', streamed_response.__class__.__name__) - # We want to end the "model request" span here, but we can't exit the context manager - # in the traditional way - model_req_span.__exit__(None, None, None) - - with _logfire.span('handle model response') as handle_span: - received_text = False - - async for maybe_part_event in streamed_response: - if isinstance(maybe_part_event, _messages.PartStartEvent): - new_part = maybe_part_event.part - if isinstance(new_part, _messages.TextPart): - received_text = True - if _allow_text_result(result_schema): - handle_span.message = 'handle model response -> final result' - streamed_run_result = _build_streamed_run_result(streamed_response, None, ctx) - self._result = End(streamed_run_result) - yield self._result - return - elif isinstance(new_part, _messages.ToolCallPart): - if result_schema is not None and (match := result_schema.find_tool([new_part])): - call, _ = match - handle_span.message = 'handle model response -> final result' - streamed_run_result = _build_streamed_run_result( - streamed_response, call.tool_name, ctx - ) - self._result = End(streamed_run_result) - yield self._result - return - else: - assert_never(new_part) - - tasks: list[asyncio.Task[_messages.ModelRequestPart]] = [] - parts: list[_messages.ModelRequestPart] = [] - model_response = streamed_response.get() - if not model_response.parts: - raise exceptions.UnexpectedModelBehavior('Received empty model response') - ctx.state.message_history.append(model_response) - - run_context = _build_run_context(ctx) - for p in model_response.parts: - if isinstance(p, _messages.ToolCallPart): - if tool := ctx.deps.function_tools.get(p.tool_name): - tasks.append(asyncio.create_task(tool.run(p, run_context), name=p.tool_name)) - else: - parts.append(_unknown_tool(p.tool_name, ctx)) - - if received_text and not tasks and not parts: - # Can only get here if self._allow_text_result returns `False` for the provided result_schema - ctx.state.increment_retries(ctx.deps.max_result_retries) - self._result = StreamModelRequestNode[DepsT, NodeRunEndT]( - _messages.ModelRequest( - parts=[ - _messages.RetryPromptPart( - content='Plain text responses are not permitted, please call one of the functions instead.', - ) - ] - ) - ) - yield self._result - return - - with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]): - task_results: Sequence[_messages.ModelRequestPart] = await asyncio.gather(*tasks) - parts.extend(task_results) - - next_request = _messages.ModelRequest(parts=parts) - if any(isinstance(part, _messages.RetryPromptPart) for part in parts): - try: - ctx.state.increment_retries(ctx.deps.max_result_retries) - except: - # TODO: This is janky, so I think we should probably change it, but how? - ctx.state.message_history.append(next_request) - raise - - handle_span.set_attribute('tool_responses', parts) - tool_responses_str = ' '.join(r.part_kind for r in parts) - handle_span.message = f'handle model response -> {tool_responses_str}' - # the model_response should have been fully streamed by now, we can add its usage - streamed_response_usage = streamed_response.usage() - run_context.usage.incr(streamed_response_usage) - ctx.deps.usage_limits.check_tokens(run_context.usage) - self._result = StreamModelRequestNode[DepsT, NodeRunEndT](next_request) - yield self._result - return +# @dataclasses.dataclass +# class StreamModelRequestNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT]): +# """Make a request to the model using the last message in state.message_history (or a specified request).""" +# +# request: _messages.ModelRequest +# _result: StreamModelRequestNode[DepsT, NodeRunEndT] | End[result.StreamedRunResult[DepsT, NodeRunEndT]] | None = ( +# field(default=None, repr=False) +# ) +# +# async def run( +# self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]] +# ) -> Union[StreamModelRequestNode[DepsT, NodeRunEndT], End[result.StreamedRunResult[DepsT, NodeRunEndT]]]: +# if self._result is not None: +# return self._result +# +# async with self.run_to_result(ctx) as final_node: +# return final_node +# +# @asynccontextmanager +# async def run_to_result( +# self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]] +# ) -> AsyncIterator[StreamModelRequestNode[DepsT, NodeRunEndT] | End[result.StreamedRunResult[DepsT, NodeRunEndT]]]: +# result_schema = ctx.deps.result_schema +# +# ctx.state.message_history.append(self.request) +# +# # Check usage +# if ctx.deps.usage_limits: +# ctx.deps.usage_limits.check_before_request(ctx.state.usage) +# +# # Increment run_step +# ctx.state.run_step += 1 +# +# with _logfire.span('preparing model and tools {run_step=}', run_step=ctx.state.run_step): +# model_request_parameters = await _prepare_request_parameters(ctx) +# +# # Actually make the model request +# model_settings = merge_model_settings(ctx.deps.model_settings, None) +# with _logfire.span('model request {run_step=}', run_step=ctx.state.run_step) as model_req_span: +# async with ctx.deps.model.request_stream( +# ctx.state.message_history, model_settings, model_request_parameters +# ) as streamed_response: +# ctx.state.usage.requests += 1 +# model_req_span.set_attribute('response_type', streamed_response.__class__.__name__) +# # We want to end the "model request" span here, but we can't exit the context manager +# # in the traditional way +# model_req_span.__exit__(None, None, None) +# +# with _logfire.span('handle model response') as handle_span: +# received_text = False +# +# async for maybe_part_event in streamed_response: +# if isinstance(maybe_part_event, _messages.PartStartEvent): +# new_part = maybe_part_event.part +# if isinstance(new_part, _messages.TextPart): +# received_text = True +# if _allow_text_result(result_schema): +# handle_span.message = 'handle model response -> final result' +# streamed_run_result = _build_streamed_run_result(streamed_response, None, ctx) +# self._result = End(streamed_run_result) +# yield self._result +# return +# elif isinstance(new_part, _messages.ToolCallPart): +# if result_schema is not None and (match := result_schema.find_tool([new_part])): +# call, _ = match +# handle_span.message = 'handle model response -> final result' +# streamed_run_result = _build_streamed_run_result( +# streamed_response, call.tool_name, ctx +# ) +# self._result = End(streamed_run_result) +# yield self._result +# return +# else: +# assert_never(new_part) +# +# tasks: list[asyncio.Task[_messages.ModelRequestPart]] = [] +# parts: list[_messages.ModelRequestPart] = [] +# model_response = streamed_response.get() +# if not model_response.parts: +# raise exceptions.UnexpectedModelBehavior('Received empty model response') +# ctx.state.message_history.append(model_response) +# +# run_context = _build_run_context(ctx) +# for p in model_response.parts: +# if isinstance(p, _messages.ToolCallPart): +# if tool := ctx.deps.function_tools.get(p.tool_name): +# tasks.append(asyncio.create_task(tool.run(p, run_context), name=p.tool_name)) +# else: +# parts.append(_unknown_tool(p.tool_name, ctx)) +# +# if received_text and not tasks and not parts: +# # Can only get here if self._allow_text_result returns `False` for the provided result_schema +# ctx.state.increment_retries(ctx.deps.max_result_retries) +# self._result = StreamModelRequestNode[DepsT, NodeRunEndT]( +# _messages.ModelRequest( +# parts=[ +# _messages.RetryPromptPart( +# content='Plain text responses are not permitted, please call one of the functions instead.', +# ) +# ] +# ) +# ) +# yield self._result +# return +# +# with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]): +# task_results: Sequence[_messages.ModelRequestPart] = await asyncio.gather(*tasks) +# parts.extend(task_results) +# +# next_request = _messages.ModelRequest(parts=parts) +# if any(isinstance(part, _messages.RetryPromptPart) for part in parts): +# try: +# ctx.state.increment_retries(ctx.deps.max_result_retries) +# except: +# # TODO: This is janky, so I think we should probably change it, but how? +# ctx.state.message_history.append(next_request) +# raise +# +# handle_span.set_attribute('tool_responses', parts) +# tool_responses_str = ' '.join(r.part_kind for r in parts) +# handle_span.message = f'handle model response -> {tool_responses_str}' +# # the model_response should have been fully streamed by now, we can add its usage +# streamed_response_usage = streamed_response.usage() +# run_context.usage.incr(streamed_response_usage) +# ctx.deps.usage_limits.check_tokens(run_context.usage) +# self._result = StreamModelRequestNode[DepsT, NodeRunEndT](next_request) +# yield self._result +# return @dataclasses.dataclass @@ -532,7 +627,7 @@ async def run( return End(self.data) -def _build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]) -> RunContext[DepsT]: +def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]) -> RunContext[DepsT]: return RunContext[DepsT]( deps=ctx.deps.user_deps, model=ctx.deps.model, @@ -543,76 +638,31 @@ def _build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[Deps ) -def _build_streamed_run_result( - result_stream: models.StreamedResponse, - result_tool_name: str | None, - ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], -) -> result.StreamedRunResult[DepsT, NodeRunEndT]: - new_message_index = ctx.deps.new_message_index - result_schema = ctx.deps.result_schema - run_span = ctx.deps.run_span - usage_limits = ctx.deps.usage_limits - messages = ctx.state.message_history - run_context = _build_run_context(ctx) - - async def on_complete(): - """Called when the stream has completed. - - The model response will have been added to messages by now - by `StreamedRunResult._marked_completed`. - """ - last_message = messages[-1] - assert isinstance(last_message, _messages.ModelResponse) - tool_calls = [part for part in last_message.parts if isinstance(part, _messages.ToolCallPart)] - parts = await _process_function_tools( - tool_calls, - result_tool_name, - ctx, - ) - # TODO: Should we do something here related to the retry count? - # Maybe we should move the incrementing of the retry count to where we actually make a request? - # if any(isinstance(part, _messages.RetryPromptPart) for part in parts): - # ctx.state.increment_retries(ctx.deps.max_result_retries) - if parts: - messages.append(_messages.ModelRequest(parts)) - run_span.set_attribute('all_messages', messages) - - return result.StreamedRunResult[DepsT, NodeRunEndT]( - messages, - new_message_index, - usage_limits, - result_stream, - result_schema, - run_context, - ctx.deps.result_validators, - result_tool_name, - on_complete, - ) - - -async def _process_function_tools( +async def process_function_tools( tool_calls: list[_messages.ToolCallPart], result_tool_name: str | None, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], -) -> list[_messages.ModelRequestPart]: + output_parts: list[_messages.ModelRequestPart], +) -> AsyncIterator[_messages.HandleResponseEvent]: """Process function (non-result) tool calls in parallel. Also add stub return parts for any other tools that need it. - """ - parts: list[_messages.ModelRequestPart] = [] - tasks: list[asyncio.Task[_messages.ToolReturnPart | _messages.RetryPromptPart]] = [] + Because async iterators can't have return values, we use `parts` as an output argument. + """ stub_function_tools = bool(result_tool_name) and ctx.deps.end_strategy == 'early' result_schema = ctx.deps.result_schema # we rely on the fact that if we found a result, it's the first result tool in the last found_used_result_tool = False - run_context = _build_run_context(ctx) + run_context = build_run_context(ctx) + calls_to_run: list[tuple[Tool[DepsT], _messages.ToolCallPart]] = [] + call_index_to_event_id: dict[int, uuid.UUID] = {} for call in tool_calls: if call.tool_name == result_tool_name and not found_used_result_tool: found_used_result_tool = True - parts.append( + output_parts.append( _messages.ToolReturnPart( tool_name=call.tool_name, content='Final result processed.', @@ -621,7 +671,7 @@ async def _process_function_tools( ) elif tool := ctx.deps.function_tools.get(call.tool_name): if stub_function_tools: - parts.append( + output_parts.append( _messages.ToolReturnPart( tool_name=call.tool_name, content='Tool not executed - a final result was already processed.', @@ -629,33 +679,47 @@ async def _process_function_tools( ) ) else: - tasks.append(asyncio.create_task(tool.run(call, run_context), name=call.tool_name)) + event = _messages.FunctionToolCallEvent(call) + yield event + call_index_to_event_id[len(calls_to_run)] = event.call_id + calls_to_run.append((tool, call)) elif result_schema is not None and call.tool_name in result_schema.tools: # if tool_name is in _result_schema, it means we found a result tool but an error occurred in # validation, we don't add another part here if result_tool_name is not None: - parts.append( - _messages.ToolReturnPart( - tool_name=call.tool_name, - content='Result tool not used - a final result was already processed.', - tool_call_id=call.tool_call_id, - ) + part = _messages.ToolReturnPart( + tool_name=call.tool_name, + content='Result tool not used - a final result was already processed.', + tool_call_id=call.tool_call_id, ) + output_parts.append(part) else: - parts.append(_unknown_tool(call.tool_name, ctx)) + output_parts.append(_unknown_tool(call.tool_name, ctx)) + + if not calls_to_run: + return # Run all tool tasks in parallel - if tasks: - with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]): - task_results: Sequence[_messages.ToolReturnPart | _messages.RetryPromptPart] = await asyncio.gather(*tasks) - for result in task_results: - if isinstance(result, _messages.ToolReturnPart): - parts.append(result) - elif isinstance(result, _messages.RetryPromptPart): - parts.append(result) + results_by_index: dict[int, _messages.ModelRequestPart] = {} + with _logfire.span('running {tools=}', tools=[call.tool_name for _, call in calls_to_run]): + # TODO: Should we wrap each individual tool call in a dedicated span? + tasks = [asyncio.create_task(tool.run(call, run_context), name=call.tool_name) for tool, call in calls_to_run] + pending = tasks + while pending: + done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED) + for task in done: + index = tasks.index(task) + result = task.result() + yield _messages.FunctionToolResultEvent(result, call_id=call_index_to_event_id[index]) + if isinstance(result, (_messages.ToolReturnPart, _messages.RetryPromptPart)): + results_by_index[index] = result else: assert_never(result) - return parts + + # We append the results at the end, rather than as they are received, to retain a consistent ordering + # This is mostly just to simplify testing + for k in sorted(results_by_index): + output_parts.append(results_by_index[k]) def _unknown_tool( @@ -681,12 +745,12 @@ async def _validate_result( tool_call: _messages.ToolCallPart | None, ) -> T: for validator in ctx.deps.result_validators: - run_context = _build_run_context(ctx) + run_context = build_run_context(ctx) result_data = await validator.validate(result_data, tool_call, run_context) return result_data -def _allow_text_result(result_schema: _result.ResultSchema[Any] | None) -> bool: +def allow_text_result(result_schema: _result.ResultSchema[Any] | None) -> bool: return result_schema is None or result_schema.allow_text_result @@ -758,17 +822,17 @@ def build_agent_graph( return graph -def build_agent_stream_graph( - name: str | None, deps_type: type[DepsT], result_type: type[ResultT] | None -) -> Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], result.StreamedRunResult[DepsT, Any]]: - nodes = [ - StreamUserPromptNode[DepsT, result.StreamedRunResult[DepsT, ResultT]], - StreamModelRequestNode[DepsT, result.StreamedRunResult[DepsT, ResultT]], - ] - graph = Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], result.StreamedRunResult[DepsT, Any]]( - nodes=nodes, - name=name or 'Agent', - state_type=GraphAgentState, - run_end_type=result.StreamedRunResult[DepsT, result_type], - ) - return graph +# def build_agent_stream_graph( +# name: str | None, deps_type: type[DepsT], result_type: type[ResultT] | None +# ) -> Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], result.StreamedRunResult[DepsT, Any]]: +# nodes = [ +# StreamUserPromptNode[DepsT, result.StreamedRunResult[DepsT, ResultT]], +# StreamModelRequestNode[DepsT, result.StreamedRunResult[DepsT, ResultT]], +# ] +# graph = Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], result.StreamedRunResult[DepsT, Any]]( +# nodes=nodes, +# name=name or 'Agent', +# state_type=GraphAgentState, +# run_end_type=result.StreamedRunResult[DepsT, result_type], +# ) +# return graph diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 3501833d2..709e48a8f 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -3,15 +3,16 @@ import asyncio import dataclasses import inspect -from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence +from collections.abc import AsyncIterator, Awaitable, Generator, Iterator, Sequence from contextlib import AbstractAsyncContextManager, asynccontextmanager, contextmanager +from copy import deepcopy from types import FrameType from typing import Any, Callable, Generic, cast, final, overload import logfire_api -from typing_extensions import TypeVar, deprecated +from typing_extensions import Self, TypeVar, deprecated -from pydantic_graph import Graph, GraphRunContext, HistoryStep +from pydantic_graph import BaseNode, Graph, GraphRun, GraphRunContext from pydantic_graph.nodes import End from . import ( @@ -26,7 +27,8 @@ usage as _usage, ) from ._agent_graph import EndStrategy, capture_run_messages # imported for re-export -from .result import ResultDataT +from .models import MarkFinalResult +from .result import ResultDataT, StreamedRunResult from .settings import ModelSettings, merge_model_settings from .tools import ( AgentDepsT, @@ -202,7 +204,7 @@ def __init__( self._register_tool(Tool(tool)) @overload - async def run( + def run( self, user_prompt: str, *, @@ -214,10 +216,10 @@ async def run( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, - ) -> result.RunResult[ResultDataT]: ... + ) -> AgentRun[AgentDepsT, ResultDataT]: ... @overload - async def run( + def run( self, user_prompt: str, *, @@ -229,21 +231,21 @@ async def run( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, - ) -> result.RunResult[RunResultDataT]: ... + ) -> AgentRun[AgentDepsT, ResultDataT]: ... - async def run( + def run( self, user_prompt: str, *, + result_type: type[RunResultDataT] | None = None, message_history: list[_messages.ModelMessage] | None = None, model: models.Model | models.KnownModelName | None = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, - result_type: type[RunResultDataT] | None = None, infer_name: bool = True, - ) -> result.RunResult[Any]: + ) -> AgentRun[AgentDepsT, ResultDataT]: """Run the agent with a user prompt in async mode. Example: @@ -305,53 +307,45 @@ async def main(): model_settings = merge_model_settings(self.model_settings, model_settings) usage_limits = usage_limits or _usage.UsageLimits() - with _logfire.span( + # Build the deps object for the graph + run_span = _logfire.span( '{agent_name} run {prompt=}', prompt=user_prompt, agent=self, model_name=model_used.model_name if model_used else 'no-model', agent_name=self.name or 'agent', - ) as run_span: - # Build the deps object for the graph - graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT, RunResultDataT]( - user_deps=deps, - prompt=user_prompt, - new_message_index=new_message_index, - model=model_used, - model_settings=model_settings, - usage_limits=usage_limits, - max_result_retries=self._max_result_retries, - end_strategy=self.end_strategy, - result_schema=result_schema, - result_tools=self._result_schema.tool_defs() if self._result_schema else [], - result_validators=result_validators, - function_tools=self._function_tools, - run_span=run_span, - ) - - start_node = _agent_graph.UserPromptNode[AgentDepsT]( - user_prompt=user_prompt, - system_prompts=self._system_prompts, - system_prompt_functions=self._system_prompt_functions, - system_prompt_dynamic_functions=self._system_prompt_dynamic_functions, - ) + ) + graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT, RunResultDataT]( + user_deps=deps, + prompt=user_prompt, + new_message_index=new_message_index, + model=model_used, + model_settings=model_settings, + usage_limits=usage_limits, + max_result_retries=self._max_result_retries, + end_strategy=self.end_strategy, + result_schema=result_schema, + result_tools=self._result_schema.tool_defs() if self._result_schema else [], + result_validators=result_validators, + function_tools=self._function_tools, + run_span=run_span, + ) + start_node = _agent_graph.UserPromptNode[AgentDepsT]( + user_prompt=user_prompt, + system_prompts=self._system_prompts, + system_prompt_functions=self._system_prompt_functions, + system_prompt_dynamic_functions=self._system_prompt_dynamic_functions, + ) - # Actually run - end_result, _ = await graph.run( + # Actually run + return AgentRun( + graph.run( start_node, state=state, deps=graph_deps, infer_name=False, + span=run_span, ) - - # Build final run result - # We don't do any advanced checking if the data is actually from a final result or not - return result.RunResult( - state.message_history, - new_message_index, - end_result.data, - end_result.tool_name, - state.usage, ) @overload @@ -366,7 +360,7 @@ def run_sync( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, - ) -> result.RunResult[ResultDataT]: ... + ) -> AgentRun[AgentDepsT, ResultDataT]: ... @overload def run_sync( @@ -381,7 +375,7 @@ def run_sync( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, - ) -> result.RunResult[RunResultDataT]: ... + ) -> AgentRun[AgentDepsT, ResultDataT]: ... def run_sync( self, @@ -395,7 +389,7 @@ def run_sync( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, - ) -> result.RunResult[Any]: + ) -> AgentRun[AgentDepsT, ResultDataT]: """Run the agent with a user prompt synchronously. This is a convenience method that wraps [`self.run`][pydantic_ai.Agent.run] with `loop.run_until_complete(...)`. @@ -474,7 +468,7 @@ def run_stream( ) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, RunResultDataT]]: ... @asynccontextmanager - async def run_stream( + async def run_stream( # noqa self, user_prompt: str, *, @@ -520,90 +514,98 @@ async def main(): # f_back because `asynccontextmanager` adds one frame if frame := inspect.currentframe(): # pragma: no branch self._infer_name(frame.f_back) - model_used = self._get_model(model) - - deps = self._get_deps(deps) - new_message_index = len(message_history) if message_history else 0 - result_schema: _result.ResultSchema[RunResultDataT] | None = self._prepare_result_schema(result_type) - - # Build the graph - graph = self._build_stream_graph(result_type) - - # Build the initial state - graph_state = _agent_graph.GraphAgentState( - message_history=message_history[:] if message_history else [], - usage=usage or _usage.Usage(), - retries=0, - run_step=0, - ) - # We consider it a user error if a user tries to restrict the result type while having a result validator that - # may change the result type from the restricted type to something else. Therefore, we consider the following - # typecast reasonable, even though it is possible to violate it with otherwise-type-checked code. - result_validators = cast(list[_result.ResultValidator[AgentDepsT, RunResultDataT]], self._result_validators) - - # TODO: Instead of this, copy the function tools to ensure they don't share current_retry state between agent - # runs. Requires some changes to `Tool` to make them copyable though. - for v in self._function_tools.values(): - v.current_retry = 0 - - model_settings = merge_model_settings(self.model_settings, model_settings) - usage_limits = usage_limits or _usage.UsageLimits() - - with _logfire.span( - '{agent_name} run stream {prompt=}', - prompt=user_prompt, - agent=self, - model_name=model_used.model_name if model_used else 'no-model', - agent_name=self.name or 'agent', - ) as run_span: - # Build the deps object for the graph - graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT, RunResultDataT]( - user_deps=deps, - prompt=user_prompt, - new_message_index=new_message_index, - model=model_used, - model_settings=model_settings, - usage_limits=usage_limits, - max_result_retries=self._max_result_retries, - end_strategy=self.end_strategy, - result_schema=result_schema, - result_tools=self._result_schema.tool_defs() if self._result_schema else [], - result_validators=result_validators, - function_tools=self._function_tools, - run_span=run_span, - ) - - start_node = _agent_graph.StreamUserPromptNode[AgentDepsT]( - user_prompt=user_prompt, - system_prompts=self._system_prompts, - system_prompt_functions=self._system_prompt_functions, - system_prompt_dynamic_functions=self._system_prompt_dynamic_functions, - ) - - # Actually run - node = start_node - history: list[HistoryStep[_agent_graph.GraphAgentState, RunResultDataT]] = [] + yielded = False + with self.run( + user_prompt, + result_type=result_type, + message_history=message_history, + model=model, + deps=deps, + model_settings=model_settings, + usage_limits=usage_limits, + usage=usage, + infer_name=False, + ) as agent_run: + first_node = await agent_run.__anext__() + assert isinstance(first_node, _agent_graph.ModelRequestNode) # the first node should be a request node + node: BaseNode[Any, Any, Any] = cast(BaseNode[Any, Any, Any], first_node) while True: - if isinstance(node, _agent_graph.StreamModelRequestNode): - node = cast( - _agent_graph.StreamModelRequestNode[ - AgentDepsT, result.StreamedRunResult[AgentDepsT, RunResultDataT] - ], - node, - ) - async with node.run_to_result(GraphRunContext(graph_state, graph_deps)) as r: - if isinstance(r, End): - yield r.data + if isinstance(node, _agent_graph.ModelRequestNode): + node = cast(_agent_graph.ModelRequestNode[AgentDepsT, Any], node) + graph_ctx = agent_run.graph_ctx() + async with node.stream(graph_ctx) as streamed_response: + + async def stream_to_final( + s: models.StreamedResponse, + ) -> MarkFinalResult[models.StreamedResponse] | None: + result_schema = graph_ctx.deps.result_schema + async for maybe_part_event in streamed_response: + if isinstance(maybe_part_event, _messages.PartStartEvent): + new_part = maybe_part_event.part + if isinstance(new_part, _messages.TextPart): + if _agent_graph.allow_text_result(result_schema): + return MarkFinalResult(s, None) + elif isinstance(new_part, _messages.ToolCallPart): + if result_schema is not None and (match := result_schema.find_tool([new_part])): + call, _ = match + return MarkFinalResult(s, call.tool_name) + return None + + final_result_details = await stream_to_final(streamed_response) + if final_result_details is not None: + if yielded: + raise exceptions.AgentRunError('Agent run produced final results') + yielded = True + + messages = graph_ctx.state.message_history.copy() + + async def on_complete() -> None: + """Called when the stream has completed. + + The model response will have been added to messages by now + by `StreamedRunResult._marked_completed`. + """ + last_message = messages[-1] + assert isinstance(last_message, _messages.ModelResponse) + tool_calls = [ + part for part in last_message.parts if isinstance(part, _messages.ToolCallPart) + ] + + parts: list[_messages.ModelRequestPart] = [] + async for _event in _agent_graph.process_function_tools( + tool_calls, + final_result_details.tool_name, + graph_ctx, + parts, + ): + pass + # TODO: Should we do something here related to the retry count? + # Maybe we should move the incrementing of the retry count to where we actually make a request? + # if any(isinstance(part, _messages.RetryPromptPart) for part in parts): + # ctx.state.increment_retries(ctx.deps.max_result_retries) + if parts: + messages.append(_messages.ModelRequest(parts)) + + yield StreamedRunResult( + messages, + graph_ctx.deps.new_message_index, + graph_ctx.deps.usage_limits, + streamed_response, + graph_ctx.deps.result_schema, + _agent_graph.build_run_context(graph_ctx), + graph_ctx.deps.result_validators, + final_result_details.tool_name, + on_complete, + ) break - assert not isinstance(node, End) # the previous line should be hit first - node = await graph.next( - node, - history, - state=graph_state, - deps=graph_deps, - infer_name=False, - ) + next_node = await agent_run.next(node) + if not isinstance(next_node, BaseNode): + raise exceptions.AgentRunError('Should have produced a StreamedRunResult before getting here') + node = cast(BaseNode[Any, Any, Any], next_node) + + if not yielded: + raise exceptions.AgentRunError('Agent run finished without producing a final result') @contextmanager def override( @@ -1039,14 +1041,9 @@ def last_run_messages(self) -> list[_messages.ModelMessage]: def _build_graph( self, result_type: type[RunResultDataT] | None - ) -> Graph[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], Any]: + ) -> Graph[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], models.MarkFinalResult[Any]]: return _agent_graph.build_agent_graph(self.name, self._deps_type, result_type or self.result_type) - def _build_stream_graph( - self, result_type: type[RunResultDataT] | None - ) -> Graph[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], Any]: - return _agent_graph.build_agent_stream_graph(self.name, self._deps_type, result_type or self.result_type) - def _prepare_result_schema( self, result_type: type[RunResultDataT] | None ) -> _result.ResultSchema[RunResultDataT] | None: @@ -1058,3 +1055,150 @@ def _prepare_result_schema( ) else: return self._result_schema # pyright: ignore[reportReturnType] + + +@dataclasses.dataclass +class AgentRun(Generic[AgentDepsT, ResultDataT]): + graph_run: GraphRun[ + _agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], models.MarkFinalResult[ResultDataT] + ] + + def all_messages(self, *, result_tool_return_content: str | None = None) -> list[_messages.ModelMessage]: + if result_tool_return_content is not None: + return self._set_result_tool_return(result_tool_return_content) + else: + return self.graph_run.state.message_history + + def all_messages_json(self, *, result_tool_return_content: str | None = None) -> bytes: + """Return all messages from [`all_messages`][pydantic_ai.result._BaseRunResult.all_messages] as JSON bytes. + + Args: + result_tool_return_content: The return content of the tool call to set in the last message. + This provides a convenient way to modify the content of the result tool call if you want to continue + the conversation and want to set the response to the result tool call. If `None`, the last message will + not be modified. + + Returns: + JSON bytes representing the messages. + """ + return _messages.ModelMessagesTypeAdapter.dump_json( + self.all_messages(result_tool_return_content=result_tool_return_content) + ) + + @property + def _new_message_index(self) -> int: + return self.graph_run.deps.new_message_index + + def new_messages(self, *, result_tool_return_content: str | None = None) -> list[_messages.ModelMessage]: + return self.all_messages(result_tool_return_content=result_tool_return_content)[self._new_message_index :] + + def new_messages_json(self, *, result_tool_return_content: str | None = None) -> bytes: + """Return new messages from [`new_messages`][pydantic_ai.result._BaseRunResult.new_messages] as JSON bytes. + + Args: + result_tool_return_content: The return content of the tool call to set in the last message. + This provides a convenient way to modify the content of the result tool call if you want to continue + the conversation and want to set the response to the result tool call. If `None`, the last message will + not be modified. + + Returns: + JSON bytes representing the new messages. + """ + return _messages.ModelMessagesTypeAdapter.dump_json( + self.new_messages(result_tool_return_content=result_tool_return_content) + ) + + def usage(self) -> _usage.Usage: + return self.graph_run.state.usage + + def _set_result_tool_return(self, return_content: str) -> list[_messages.ModelMessage]: + """Set return content for the result tool. + + Useful if you want to continue the conversation and want to set the response to the result tool call. + """ + if not self.result.tool_name: + raise ValueError('Cannot set result tool return content when the return type is `str`.') + messages = deepcopy(self.graph_run.state.message_history) + last_message = messages[-1] + for part in last_message.parts: + if isinstance(part, _messages.ToolReturnPart) and part.tool_name == self.result.tool_name: + part.content = return_content + return messages + raise LookupError(f'No tool call found with tool name {self.result.tool_name!r}.') + + @property + def is_ended(self) -> bool: + return self.graph_run.is_ended + + @property + def result(self) -> models.MarkFinalResult[ResultDataT]: + return self.graph_run.result + + @property + def _result_tool_name(self) -> str | None: + return self.graph_run.result.tool_name + + @property + def data(self) -> ResultDataT: + return self.result.data + + async def next( + self, + node: BaseNode[ + _agent_graph.GraphAgentState, + _agent_graph.GraphAgentDeps[AgentDepsT, Any], + models.MarkFinalResult[ResultDataT], + ], + ) -> ( + BaseNode[ + _agent_graph.GraphAgentState, + _agent_graph.GraphAgentDeps[AgentDepsT, Any], + models.MarkFinalResult[ResultDataT], + ] + | End[models.MarkFinalResult[ResultDataT]] + ): + return await self.graph_run.next(node) + + def graph_ctx(self) -> GraphRunContext[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any]]: + return GraphRunContext(self.graph_run.state, self.graph_run.deps) + + def __await__(self) -> Generator[Any, Any, Self]: + """Run the graph until it ends, and return the final result.""" + + async def _run(): + await self.graph_run + return self + + return _run().__await__() + + def __enter__(self) -> Self: + self.graph_run.__enter__() + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + self.graph_run.__exit__(exc_type, exc_val, exc_tb) + + def __aiter__( + self, + ) -> AsyncIterator[ + BaseNode[ + _agent_graph.GraphAgentState, + _agent_graph.GraphAgentDeps[AgentDepsT, Any], + models.MarkFinalResult[ResultDataT], + ] + | End[models.MarkFinalResult[ResultDataT]] + ]: + return self + + async def __anext__( + self, + ) -> ( + BaseNode[ + _agent_graph.GraphAgentState, + _agent_graph.GraphAgentDeps[AgentDepsT, Any], + models.MarkFinalResult[ResultDataT], + ] + | End[models.MarkFinalResult[ResultDataT]] + ): + """Use the last returned node as the input to `Graph.next`.""" + return await self.graph_run.__anext__() diff --git a/pydantic_ai_slim/pydantic_ai/messages.py b/pydantic_ai_slim/pydantic_ai/messages.py index d3001bf52..9f39e288b 100644 --- a/pydantic_ai_slim/pydantic_ai/messages.py +++ b/pydantic_ai_slim/pydantic_ai/messages.py @@ -1,5 +1,6 @@ from __future__ import annotations as _annotations +import uuid from dataclasses import dataclass, field, replace from datetime import datetime from typing import Annotated, Any, Literal, Union, cast, overload @@ -445,3 +446,30 @@ class PartDeltaEvent: ModelResponseStreamEvent = Annotated[Union[PartStartEvent, PartDeltaEvent], pydantic.Discriminator('event_kind')] """An event in the model response stream, either starting a new part or applying a delta to an existing one.""" + + +@dataclass +class FunctionToolCallEvent: + """An event indicating the start to a call to a function tool.""" + + part: ToolCallPart + """The (function) tool call to make.""" + call_id: uuid.UUID = field(default_factory=uuid.uuid4, repr=False) + """An ID used to match the call to its result.""" + event_kind: Literal['function_tool_call'] = field(default='function_tool_call', repr=False) + """Event type identifier, used as a discriminator.""" + + +@dataclass +class FunctionToolResultEvent: + """An event indicating the result of a function tool call.""" + + result: ToolReturnPart | RetryPromptPart + """The result of the call to the function tool.""" + call_id: uuid.UUID = field(default_factory=uuid.uuid4, repr=False) + """An ID used to match the result to its original call.""" + event_kind: Literal['function_tool_result'] = field(default='function_tool_result', repr=False) + """Event type identifier, used as a discriminator.""" + + +HandleResponseEvent = Annotated[Union[FunctionToolCallEvent, FunctionToolResultEvent], pydantic.Discriminator('kind')] diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py index bba9898a7..e66591574 100644 --- a/pydantic_ai_slim/pydantic_ai/models/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py @@ -12,11 +12,13 @@ from dataclasses import dataclass, field from datetime import datetime from functools import cache -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Generic import httpx -from typing_extensions import Literal +import logfire_api +from typing_extensions import Literal, TypeVar +from .. import _utils, messages as _messages from .._parts_manager import ModelResponsePartsManager from ..exceptions import UserError from ..messages import ModelMessage, ModelResponse, ModelResponseStreamEvent @@ -26,6 +28,7 @@ if TYPE_CHECKING: from ..tools import ToolDefinition +_logfire = logfire_api.Logfire(otel_scope='pydantic-ai') KnownModelName = Literal[ 'anthropic:claude-3-5-haiku-latest', @@ -164,6 +167,23 @@ `KnownModelName` is provided as a concise way to specify a model. """ +ResultDataT = TypeVar('ResultDataT', covariant=True) + + +@dataclass +class MarkFinalResult(Generic[ResultDataT]): + """Marker class to indicate that the result is the final result. + + This allows us to use `isinstance`, which wouldn't be possible if we were returning `ResultDataT` directly. + + It also avoids problems in the case where the result type is itself `None`, but is set. + """ + + data: ResultDataT + """The final result data.""" + tool_name: str | None + """Name of the final result tool, None if the result is a string.""" + @dataclass class ModelRequestParameters: @@ -241,11 +261,10 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: # noinspection PyUnreachableCode yield - def get(self) -> ModelResponse: - """Build a [`ModelResponse`][pydantic_ai.messages.ModelResponse] from the data received from the stream so far.""" - return ModelResponse( - parts=self._parts_manager.get_parts(), model_name=self._model_name, timestamp=self.timestamp() - ) + @abstractmethod + def timestamp(self) -> datetime: + """Get the timestamp of the response.""" + raise NotImplementedError() def model_name(self) -> str: """Get the model name of the response.""" @@ -255,10 +274,75 @@ def usage(self) -> Usage: """Get the usage of the response so far. This will not be the final usage until the stream is exhausted.""" return self._usage - @abstractmethod - def timestamp(self) -> datetime: - """Get the timestamp of the response.""" - raise NotImplementedError() + def get(self) -> ModelResponse: + """Build a [`ModelResponse`][pydantic_ai.messages.ModelResponse] from the data received from the stream so far.""" + return ModelResponse( + parts=self._parts_manager.get_parts(), model_name=self._model_name, timestamp=self.timestamp() + ) + + async def stream_events(self) -> AsyncIterator[ModelResponseStreamEvent]: + return self.__aiter__() + + async def stream_debounced_events( + self, *, debounce_by: float | None = 0.1 + ) -> AsyncIterator[list[ModelResponseStreamEvent]]: + async with _utils.group_by_temporal(self, debounce_by) as group_iter: + async for items in group_iter: + yield items + + async def stream_structured(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[_messages.ModelResponse]: + async def _stream_structured_ungrouped() -> AsyncIterator[None]: + # yield None # TODO: Might want to yield right away to ensure we can eagerly emit a ModelResponse even if we are waiting + async for _event in self: + yield None + + async with _utils.group_by_temporal(_stream_structured_ungrouped(), debounce_by) as group_iter: + async for _items in group_iter: + yield self.get() # current state of the response + + async def stream_text(self, *, delta: bool = False, debounce_by: float | None = 0.1) -> AsyncIterator[str]: + # Define a "merged" version of the iterator that will yield items that have already been retrieved + # and items that we receive while streaming. We define a dedicated async iterator for this so we can + # pass the combined stream to the group_by_temporal function within `_stream_text_deltas` below. + async def _stream_text_deltas_ungrouped() -> AsyncIterator[tuple[str, int]]: + # yields tuples of (text_content, part_index) + # we don't currently make use of the part_index, but in principle this may be useful + # so we retain it here for now to make possible future refactors simpler + msg = self.get() + for i, part in enumerate(msg.parts): + if isinstance(part, _messages.TextPart) and part.content: + yield part.content, i + + async for event in self: + if ( + isinstance(event, _messages.PartStartEvent) + and isinstance(event.part, _messages.TextPart) + and event.part.content + ): + yield event.part.content, event.index + elif ( + isinstance(event, _messages.PartDeltaEvent) + and isinstance(event.delta, _messages.TextPartDelta) + and event.delta.content_delta + ): + yield event.delta.content_delta, event.index + + async def _stream_text_deltas() -> AsyncIterator[str]: + async with _utils.group_by_temporal(_stream_text_deltas_ungrouped(), debounce_by) as group_iter: + async for items in group_iter: + # Note: we are currently just dropping the part index on the group here + yield ''.join([content for content, _ in items]) + + if delta: + async for text in _stream_text_deltas(): + yield text + else: + # a quick benchmark shows it's faster to build up a string with concat when we're + # yielding at each step + deltas: list[str] = [] + async for text in _stream_text_deltas(): + deltas.append(text) + yield ''.join(deltas) ALLOW_MODEL_REQUESTS = True diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py index 8fb3fb974..bd33d3488 100644 --- a/pydantic_ai_slim/pydantic_ai/result.py +++ b/pydantic_ai_slim/pydantic_ai/result.py @@ -1,8 +1,7 @@ from __future__ import annotations as _annotations -from abc import ABC, abstractmethod from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable -from copy import deepcopy +from copy import copy from dataclasses import dataclass, field from datetime import datetime from typing import Generic, Union, cast @@ -10,11 +9,11 @@ import logfire_api from typing_extensions import TypeVar -from . import _result, _utils, exceptions, messages as _messages, models +from . import _result, exceptions, messages as _messages, models from .tools import AgentDepsT, RunContext from .usage import Usage, UsageLimits -__all__ = 'ResultDataT', 'ResultDataT_inv', 'ResultValidatorFunc', 'RunResult', 'StreamedRunResult' +__all__ = 'ResultDataT', 'ResultDataT_inv', 'ResultValidatorFunc' T = TypeVar('T') @@ -53,15 +52,34 @@ @dataclass -class _BaseRunResult(ABC, Generic[ResultDataT]): - """Base type for results. - - You should not import or use this type directly, instead use its subclasses `RunResult` and `StreamedRunResult`. - """ +class StreamedRunResult(Generic[AgentDepsT, ResultDataT]): + """Result of a streamed run that returns structured data via a tool call.""" _all_messages: list[_messages.ModelMessage] _new_message_index: int + _usage_limits: UsageLimits | None + _stream_response: models.StreamedResponse + _result_schema: _result.ResultSchema[ResultDataT] | None + _run_ctx: RunContext[AgentDepsT] + _result_validators: list[_result.ResultValidator[AgentDepsT, ResultDataT]] + _result_tool_name: str | None + _on_complete: Callable[[], Awaitable[None]] + + _initial_run_ctx_usage: Usage = field(init=False) + is_complete: bool = field(default=False, init=False) + """Whether the stream has all been received. + + This is set to `True` when one of + [`stream`][pydantic_ai.result.StreamedRunResult.stream], + [`stream_text`][pydantic_ai.result.StreamedRunResult.stream_text], + [`stream_structured`][pydantic_ai.result.StreamedRunResult.stream_structured] or + [`get_data`][pydantic_ai.result.StreamedRunResult.get_data] completes. + """ + + def __post_init__(self): + self._initial_run_ctx_usage = copy(self._run_ctx.usage) + def all_messages(self, *, result_tool_return_content: str | None = None) -> list[_messages.ModelMessage]: """Return the history of _messages. @@ -127,78 +145,6 @@ def new_messages_json(self, *, result_tool_return_content: str | None = None) -> self.new_messages(result_tool_return_content=result_tool_return_content) ) - @abstractmethod - def usage(self) -> Usage: - raise NotImplementedError() - - -@dataclass -class RunResult(_BaseRunResult[ResultDataT]): - """Result of a non-streamed run.""" - - data: ResultDataT - """Data from the final response in the run.""" - _result_tool_name: str | None - _usage: Usage - - def usage(self) -> Usage: - """Return the usage of the whole run.""" - return self._usage - - def all_messages(self, *, result_tool_return_content: str | None = None) -> list[_messages.ModelMessage]: - """Return the history of _messages. - - Args: - result_tool_return_content: The return content of the tool call to set in the last message. - This provides a convenient way to modify the content of the result tool call if you want to continue - the conversation and want to set the response to the result tool call. If `None`, the last message will - not be modified. - - Returns: - List of messages. - """ - if result_tool_return_content is not None: - return self._set_result_tool_return(result_tool_return_content) - else: - return self._all_messages - - def _set_result_tool_return(self, return_content: str) -> list[_messages.ModelMessage]: - """Set return content for the result tool. - - Useful if you want to continue the conversation and want to set the response to the result tool call. - """ - if not self._result_tool_name: - raise ValueError('Cannot set result tool return content when the return type is `str`.') - messages = deepcopy(self._all_messages) - last_message = messages[-1] - for part in last_message.parts: - if isinstance(part, _messages.ToolReturnPart) and part.tool_name == self._result_tool_name: - part.content = return_content - return messages - raise LookupError(f'No tool call found with tool name {self._result_tool_name!r}.') - - -@dataclass -class StreamedRunResult(_BaseRunResult[ResultDataT], Generic[AgentDepsT, ResultDataT]): - """Result of a streamed run that returns structured data via a tool call.""" - - _usage_limits: UsageLimits | None - _stream_response: models.StreamedResponse - _result_schema: _result.ResultSchema[ResultDataT] | None - _run_ctx: RunContext[AgentDepsT] - _result_validators: list[_result.ResultValidator[AgentDepsT, ResultDataT]] - _result_tool_name: str | None - _on_complete: Callable[[], Awaitable[None]] - is_complete: bool = field(default=False, init=False) - """Whether the stream has all been received. - - This is set to `True` when one of - [`stream`][pydantic_ai.result.StreamedRunResult.stream], - [`stream_text`][pydantic_ai.result.StreamedRunResult.stream_text], - [`stream_structured`][pydantic_ai.result.StreamedRunResult.stream_structured] or - [`get_data`][pydantic_ai.result.StreamedRunResult.get_data] completes. - """ - async def stream(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[ResultDataT]: """Stream the response as an async iterable. @@ -214,6 +160,7 @@ async def stream(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[Resu Returns: An async iterable of the response data. """ + self._stream_response.stream_structured(debounce_by=debounce_by) async for structured_message, is_last in self.stream_structured(debounce_by=debounce_by): result = await self.validate_structured_result(structured_message, allow_partial=not is_last) yield result @@ -234,61 +181,17 @@ async def stream_text(self, *, delta: bool = False, debounce_by: float | None = if self._result_schema and not self._result_schema.allow_text_result: raise exceptions.UserError('stream_text() can only be used with text responses') - usage_checking_stream = _get_usage_checking_stream_response( - self._stream_response, self._usage_limits, self.usage - ) - - # Define a "merged" version of the iterator that will yield items that have already been retrieved - # and items that we receive while streaming. We define a dedicated async iterator for this so we can - # pass the combined stream to the group_by_temporal function within `_stream_text_deltas` below. - async def _stream_text_deltas_ungrouped() -> AsyncIterator[tuple[str, int]]: - # if the response currently has any parts with content, yield those before streaming - msg = self._stream_response.get() - for i, part in enumerate(msg.parts): - if isinstance(part, _messages.TextPart) and part.content: - yield part.content, i - - async for event in usage_checking_stream: - if ( - isinstance(event, _messages.PartStartEvent) - and isinstance(event.part, _messages.TextPart) - and event.part.content - ): - yield event.part.content, event.index - elif ( - isinstance(event, _messages.PartDeltaEvent) - and isinstance(event.delta, _messages.TextPartDelta) - and event.delta.content_delta - ): - yield event.delta.content_delta, event.index - - async def _stream_text_deltas() -> AsyncIterator[str]: - async with _utils.group_by_temporal(_stream_text_deltas_ungrouped(), debounce_by) as group_iter: - async for items in group_iter: - yield ''.join([content for content, _ in items]) - with _logfire.span('response stream text') as lf_span: if delta: - async for text in _stream_text_deltas(): + async for text in self._stream_response.stream_text(delta=delta, debounce_by=debounce_by): yield text else: - # a quick benchmark shows it's faster to build up a string with concat when we're - # yielding at each step - deltas: list[str] = [] combined_validated_text = '' - async for text in _stream_text_deltas(): - deltas.append(text) - combined_text = ''.join(deltas) - combined_validated_text = await self._validate_text_result(combined_text) + async for text in self._stream_response.stream_text(delta=delta, debounce_by=debounce_by): + combined_validated_text = await self._validate_text_result(text) yield combined_validated_text - lf_span.set_attribute('combined_text', combined_validated_text) - await self._marked_completed( - _messages.ModelResponse( - parts=[_messages.TextPart(combined_validated_text)], - model_name=self._stream_response.model_name(), - ) - ) + await self._marked_completed(self._stream_response.get()) async def stream_structured( self, *, debounce_by: float | None = 0.1 @@ -303,10 +206,6 @@ async def stream_structured( Returns: An async iterable of the structured response message and whether that is the last message. """ - usage_checking_stream = _get_usage_checking_stream_response( - self._stream_response, self._usage_limits, self.usage - ) - with _logfire.span('response stream structured') as lf_span: # if the message currently has any parts with content, yield before streaming msg = self._stream_response.get() @@ -315,15 +214,14 @@ async def stream_structured( yield msg, False break - async with _utils.group_by_temporal(usage_checking_stream, debounce_by) as group_iter: - async for _events in group_iter: - msg = self._stream_response.get() - yield msg, False - msg = self._stream_response.get() - yield msg, True - # TODO: Should this now be `final_response` instead of `structured_response`? - lf_span.set_attribute('structured_response', msg) - await self._marked_completed(msg) + async for msg in self._stream_response.stream_structured(debounce_by=debounce_by): + yield msg, False + + msg = self._stream_response.get() + yield msg, True + # TODO: Should this now be `final_response` instead of `structured_response`? + lf_span.set_attribute('structured_response', msg) + await self._marked_completed(msg) async def get_data(self) -> ResultDataT: """Stream the whole response, validate and return it.""" @@ -333,6 +231,7 @@ async def get_data(self) -> ResultDataT: async for _ in usage_checking_stream: pass + message = self._stream_response.get() await self._marked_completed(message) return await self.validate_structured_result(message) @@ -343,7 +242,7 @@ def usage(self) -> Usage: !!! note This won't return the full usage until the stream is finished. """ - return self._run_ctx.usage + self._stream_response.usage() + return self._initial_run_ctx_usage + self._stream_response.usage() def timestamp(self) -> datetime: """Get the timestamp of the response.""" diff --git a/pydantic_graph/README.md b/pydantic_graph/README.md index 15a4062e0..3e4ffb24d 100644 --- a/pydantic_graph/README.md +++ b/pydantic_graph/README.md @@ -50,10 +50,10 @@ class Increment(BaseNode): fives_graph = Graph(nodes=[DivisibleBy5, Increment]) -result, history = fives_graph.run_sync(DivisibleBy5(4)) -print(result) +graph_run = fives_graph.run_sync(DivisibleBy5(4)) +print(graph_run.result) #> 5 # the full history is quite verbose (see below), so we'll just print the summary -print([item.data_snapshot() for item in history]) +print([item.data_snapshot() for item in graph_run.history]) #> [DivisibleBy5(foo=4), Increment(foo=4), DivisibleBy5(foo=5), End(data=5)] ``` diff --git a/pydantic_graph/pydantic_graph/__init__.py b/pydantic_graph/pydantic_graph/__init__.py index d4c6074e1..f5f2a01c0 100644 --- a/pydantic_graph/pydantic_graph/__init__.py +++ b/pydantic_graph/pydantic_graph/__init__.py @@ -1,10 +1,11 @@ from .exceptions import GraphRuntimeError, GraphSetupError -from .graph import Graph +from .graph import Graph, GraphRun from .nodes import BaseNode, Edge, End, GraphRunContext from .state import EndStep, HistoryStep, NodeStep __all__ = ( 'Graph', + 'GraphRun', 'BaseNode', 'End', 'GraphRunContext', diff --git a/pydantic_graph/pydantic_graph/graph.py b/pydantic_graph/pydantic_graph/graph.py index a670c3d39..710cfedee 100644 --- a/pydantic_graph/pydantic_graph/graph.py +++ b/pydantic_graph/pydantic_graph/graph.py @@ -3,17 +3,17 @@ import asyncio import inspect import types -from collections.abc import Sequence +from collections.abc import AsyncIterator, Generator, Sequence from contextlib import ExitStack from dataclasses import dataclass, field from functools import cached_property -from pathlib import Path from time import perf_counter from typing import TYPE_CHECKING, Annotated, Any, Callable, Generic, TypeVar import logfire_api import pydantic import typing_extensions +from logfire_api import LogfireSpan from . import _utils, exceptions, mermaid from .nodes import BaseNode, DepsT, End, GraphRunContext, NodeDef, RunEndT @@ -30,7 +30,7 @@ logfire._internal.stack_info.NON_USER_CODE_PREFIXES += (str(Path(__file__).parent.absolute()),) -__all__ = ('Graph',) +__all__ = ('Graph', 'GraphRun') _logfire = logfire_api.Logfire(otel_scope='pydantic-graph') @@ -126,14 +126,15 @@ def __init__( self._validate_edges() - async def run( + def run( self: Graph[StateT, DepsT, T], start_node: BaseNode[StateT, DepsT, T], *, state: StateT = None, deps: DepsT = None, infer_name: bool = True, - ) -> tuple[T, list[HistoryStep[StateT, T]]]: + span: LogfireSpan | None = None, + ) -> GraphRun[StateT, DepsT, T]: """Run the graph from a starting node until it ends. Args: @@ -142,6 +143,7 @@ async def run( state: The initial state of the graph. deps: The dependencies of the graph. infer_name: Whether to infer the graph name from the calling frame. + span: The span to use for the graph run. If not provided, a new span will be created. Returns: The result type from ending the run and the history of the run. @@ -153,50 +155,32 @@ async def run( async def main(): state = MyState(1) - _, history = await never_42_graph.run(Increment(), state=state) + graph_run = await never_42_graph.run(Increment(), state=state) print(state) #> MyState(number=2) - print(len(history)) + print(len(graph_run.history)) #> 3 state = MyState(41) - _, history = await never_42_graph.run(Increment(), state=state) + graph_run = await never_42_graph.run(Increment(), state=state) print(state) #> MyState(number=43) - print(len(history)) + print(len(graph_run.history)) #> 5 ``` """ if infer_name and self.name is None: self._infer_name(inspect.currentframe()) - history: list[HistoryStep[StateT, T]] = [] - with ExitStack() as stack: - run_span: logfire_api.LogfireSpan | None = None - if self._auto_instrument: - run_span = stack.enter_context( - _logfire.span( - '{graph_name} run {start=}', - graph_name=self.name or 'graph', - start=start_node, - ) - ) - - next_node = start_node - while True: - next_node = await self.next(next_node, history, state=state, deps=deps, infer_name=False) - if isinstance(next_node, End): - history.append(EndStep(result=next_node)) - if run_span is not None: - run_span.set_attribute('history', history) - return next_node.data, history - elif not isinstance(next_node, BaseNode): - if TYPE_CHECKING: - typing_extensions.assert_never(next_node) - else: - raise exceptions.GraphRuntimeError( - f'Invalid node return type: `{type(next_node).__name__}`. Expected `BaseNode` or `End`.' - ) + return GraphRun[StateT, DepsT, T]( + self, + start_node, + history=[], + state=state, + deps=deps, + auto_instrument=self._auto_instrument, + span=span, + ) def run_sync( self: Graph[StateT, DepsT, T], @@ -205,7 +189,7 @@ def run_sync( state: StateT = None, deps: DepsT = None, infer_name: bool = True, - ) -> tuple[T, list[HistoryStep[StateT, T]]]: + ) -> GraphRun[StateT, DepsT, T]: """Run the graph synchronously. This is a convenience method that wraps [`self.run`][pydantic_graph.Graph.run] with `loop.run_until_complete(...)`. @@ -266,6 +250,17 @@ async def next( history.append( NodeStep(state=state, node=node, start_ts=start_ts, duration=duration, snapshot_state=self.snapshot_state) ) + + if isinstance(next_node, End): + history.append(EndStep(result=next_node)) + elif not isinstance(next_node, BaseNode): + if TYPE_CHECKING: + typing_extensions.assert_never(next_node) + else: + raise exceptions.GraphRuntimeError( + f'Invalid node return type: `{type(next_node).__name__}`. Expected `BaseNode` or `End`.' + ) + return next_node def dump_history( @@ -510,3 +505,114 @@ def _infer_name(self, function_frame: types.FrameType | None) -> None: if item is self: self.name = name return + + +class GraphRun(Generic[StateT, DepsT, RunEndT]): + """A stateful run of a graph. + + After being entered, can be used like an async generator to listen to / modify nodes as the run is executed. + """ + + def __init__( + self, + graph: Graph[StateT, DepsT, RunEndT], + first_node: BaseNode[StateT, DepsT, RunEndT], + *, + history: list[HistoryStep[StateT, RunEndT]], + state: StateT, + deps: DepsT, + auto_instrument: bool, + span: LogfireSpan | None = None, + ): + self.graph = graph + self.history = history + self.state = state + self.deps = deps + self._auto_instrument = auto_instrument + self._span = span + + self._next_node = first_node + self._started: bool = False + self._result: End[RunEndT] | None = None + + @property + def is_ended(self) -> bool: + return self._result is not None + + @property + def result(self) -> RunEndT: + if self._result is None: + if self._started: + raise exceptions.GraphRuntimeError( + 'This GraphRun has not yet ended. Continue iterating with `async for` or `GraphRun.next`' + ' to complete the run before accessing the result.' + ) + else: + raise exceptions.GraphRuntimeError( + 'This GraphRun has not been started. Did you forget to `await` the run?' + ) + return self._result.data + + async def next( + self: GraphRun[StateT, DepsT, T], node: BaseNode[StateT, DepsT, T] + ) -> BaseNode[StateT, DepsT, T] | End[T]: + """Note: this method behaves very similarly to an async generator's `asend` method.""" + if not self._started: + raise exceptions.GraphRuntimeError( + 'You must enter the GraphRun as a contextmanager before you can call `next` on it.' + ) + + history = self.history + state = self.state + deps = self.deps + + next_node = await self.graph.next(node, history, state=state, deps=deps, infer_name=False) + + if isinstance(next_node, End): + self._result = next_node + else: + self._next_node = next_node + return next_node + + def __await__(self) -> Generator[Any, Any, typing_extensions.Self]: + """Run the graph until it ends, and return the final result.""" + + async def _run() -> typing_extensions.Self: + with self: + async for _next_node in self: + pass + + return self + + return _run().__await__() + + def __enter__(self) -> typing_extensions.Self: + if self._started: + raise exceptions.GraphRuntimeError('A GraphRun can only be started once.') + + if self._auto_instrument and self._span is None: + self._span = logfire_api.span('run graph {graph.name}', graph=self.graph) + + if self._span is not None: + self._span.__enter__() + + self._started = True + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + if self._span is not None: + self._span.__exit__(exc_type, exc_val, exc_tb) + self._span = None # make it more obvious if you try to use it after exiting + + def __aiter__(self) -> AsyncIterator[BaseNode[StateT, DepsT, RunEndT] | End[RunEndT]]: + return self + + async def __anext__(self) -> BaseNode[StateT, DepsT, RunEndT] | End[RunEndT]: + """Use the last returned node as the input to `Graph.next`.""" + if self._result: + raise StopAsyncIteration + if not self._started: + raise exceptions.GraphRuntimeError( + 'You must enter the GraphRun as a contextmanager before you can iterate over it.' + ) + return await self.next(self._next_node) diff --git a/pydantic_graph/pydantic_graph/nodes.py b/pydantic_graph/pydantic_graph/nodes.py index b43391ffe..c50a63c21 100644 --- a/pydantic_graph/pydantic_graph/nodes.py +++ b/pydantic_graph/pydantic_graph/nodes.py @@ -28,6 +28,8 @@ class GraphRunContext(Generic[StateT, DepsT]): """Context for a graph.""" + # TODO: It would be nice to get rid of this struct and just pass both these things around... + state: StateT """The state of the graph.""" deps: DepsT diff --git a/pyproject.toml b/pyproject.toml index f8eaa9b0e..7d6957a03 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -193,4 +193,4 @@ skip = '.git*,*.svg,*.lock,*.css' check-hidden = true # Ignore "formatting" like **L**anguage ignore-regex = '\*\*[A-Z]\*\*[a-z]+\b' -# ignore-words-list = '' + ignore-words-list = 'asend' diff --git a/tests/graph/test_graph.py b/tests/graph/test_graph.py index ebd254a37..4668bc2c6 100644 --- a/tests/graph/test_graph.py +++ b/tests/graph/test_graph.py @@ -57,11 +57,11 @@ async def run(self, ctx: GraphRunContext) -> Union[String2Length, End[int]]: # assert my_graph.name is None assert my_graph._get_state_type() is type(None) assert my_graph._get_run_end_type() is int - result, history = await my_graph.run(Float2String(3.14)) + graph_run = await my_graph.run(Float2String(3.14)) # len('3.14') * 2 == 8 - assert result == 8 + assert graph_run.result == 8 assert my_graph.name == 'my_graph' - assert history == snapshot( + assert graph_run.history == snapshot( [ NodeStep( state=None, @@ -84,10 +84,10 @@ async def run(self, ctx: GraphRunContext) -> Union[String2Length, End[int]]: # EndStep(result=End(data=8), ts=IsNow(tz=timezone.utc)), ] ) - result, history = await my_graph.run(Float2String(3.14159)) + graph_run = await my_graph.run(Float2String(3.14159)) # len('3.14159') == 7, 21 * 2 == 42 - assert result == 42 - assert history == snapshot( + assert graph_run.result == 42 + assert graph_run.history == snapshot( [ NodeStep( state=None, @@ -122,7 +122,7 @@ async def run(self, ctx: GraphRunContext) -> Union[String2Length, End[int]]: # EndStep(result=End(data=42), ts=IsNow(tz=timezone.utc)), ] ) - assert [e.data_snapshot() for e in history] == snapshot( + assert [e.data_snapshot() for e in graph_run.history] == snapshot( [ Float2String(input_data=3.14159), String2Length(input_data='3.14159'), @@ -320,10 +320,10 @@ async def run(self, ctx: GraphRunContext[None, Deps]) -> End[int]: return End(123) g = Graph(nodes=(Foo, Bar)) - result, history = await g.run(Foo(), deps=Deps(1, 2)) + graph_run = await g.run(Foo(), deps=Deps(1, 2)) - assert result == 123 - assert history == snapshot( + assert graph_run.result == 123 + assert graph_run.history == snapshot( [ NodeStep(state=None, node=Foo(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()), NodeStep(state=None, node=Bar(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()), diff --git a/tests/graph/test_history.py b/tests/graph/test_history.py index 2508a5347..bcd8dca19 100644 --- a/tests/graph/test_history.py +++ b/tests/graph/test_history.py @@ -46,16 +46,16 @@ async def run(self, ctx: GraphRunContext[MyState]) -> End[int]: ], ) async def test_dump_load_history(graph: Graph[MyState, None, int]): - result, history = await graph.run(Foo(), state=MyState(1, '')) - assert result == snapshot(4) - assert history == snapshot( + graph_run = await graph.run(Foo(), state=MyState(1, '')) + assert graph_run.result == snapshot(4) + assert graph_run.history == snapshot( [ NodeStep(state=MyState(x=2, y=''), node=Foo(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()), NodeStep(state=MyState(x=2, y='y'), node=Bar(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()), EndStep(result=End(4), ts=IsNow(tz=timezone.utc)), ] ) - history_json = graph.dump_history(history) + history_json = graph.dump_history(graph_run.history) assert json.loads(history_json) == snapshot( [ { @@ -76,7 +76,7 @@ async def test_dump_load_history(graph: Graph[MyState, None, int]): ] ) history_loaded = graph.load_history(history_json) - assert history == history_loaded + assert graph_run.history == history_loaded custom_history = [ { diff --git a/tests/graph/test_mermaid.py b/tests/graph/test_mermaid.py index 9f76d93cd..041fe6027 100644 --- a/tests/graph/test_mermaid.py +++ b/tests/graph/test_mermaid.py @@ -58,9 +58,9 @@ async def run(self, ctx: GraphRunContext) -> Annotated[End[None], Edge(label='eg async def test_run_graph(): - result, history = await graph1.run(Foo()) - assert result is None - assert history == snapshot( + graph_run = await graph1.run(Foo()) + assert graph_run.result is None + assert graph_run.history == snapshot( [ NodeStep( state=None, diff --git a/tests/graph/test_state.py b/tests/graph/test_state.py index fbb570cf0..8c59667ae 100644 --- a/tests/graph/test_state.py +++ b/tests/graph/test_state.py @@ -36,9 +36,9 @@ async def run(self, ctx: GraphRunContext[MyState]) -> End[str]: assert graph._get_state_type() is MyState assert graph._get_run_end_type() is str state = MyState(1, '') - result, history = await graph.run(Foo(), state=state) - assert result == snapshot('x=2 y=y') - assert history == snapshot( + graph_run = await graph.run(Foo(), state=state) + assert graph_run.result == snapshot('x=2 y=y') + assert graph_run.history == snapshot( [ NodeStep( state=MyState(x=2, y=''), diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py index e34d3f8ef..e335efe9b 100644 --- a/tests/models/test_gemini.py +++ b/tests/models/test_gemini.py @@ -790,6 +790,10 @@ async def test_stream_text_heterogeneous(get_gemini_client: GetGeminiClient): m = GeminiModel('gemini-1.5-flash', http_client=gemini_client) agent = Agent(m) + @agent.tool_plain() + def get_location(loc_name: str) -> str: + return f'Location for {loc_name}' + async with agent.run_stream('Hello') as result: data = await result.get_data() diff --git a/tests/models/test_mistral.py b/tests/models/test_mistral.py index ae5eb7b30..5f5e46eeb 100644 --- a/tests/models/test_mistral.py +++ b/tests/models/test_mistral.py @@ -1633,7 +1633,7 @@ async def get_location(loc_name: str) -> str: ModelResponse( parts=[TextPart(content='final response')], model_name='mistral-large-latest', - timestamp=IsNow(tz=timezone.utc), + timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), ), ] ) diff --git a/tests/test_agent.py b/tests/test_agent.py index 1a1091959..44de33e91 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -27,7 +27,7 @@ from pydantic_ai.models import cached_async_http_client from pydantic_ai.models.function import AgentInfo, FunctionModel from pydantic_ai.models.test import TestModel -from pydantic_ai.result import RunResult, Usage +from pydantic_ai.result import Usage from pydantic_ai.tools import ToolDefinition from .conftest import IsNow, TestEnv @@ -279,7 +279,7 @@ def return_tuple(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: ), ] ) - assert result._result_tool_name == 'final_result' # pyright: ignore[reportPrivateUsage] + assert result.graph_run.result.tool_name == 'final_result' assert result.all_messages(result_tool_return_content='foobar')[-1] == snapshot( ModelRequest( parts=[ToolReturnPart(tool_name='final_result', content='foobar', timestamp=IsNow(tz=timezone.utc))] @@ -312,7 +312,7 @@ def test_result_tool_return_content_no_tool(): result = agent.run_sync('Hello') assert result.data == 0 - result._result_tool_name = 'wrong' # pyright: ignore[reportPrivateUsage] + result.graph_run.result.tool_name = 'wrong' with pytest.raises(LookupError, match=re.escape("No tool call found with tool name 'wrong'.")): result.all_messages(result_tool_return_content='foobar') @@ -534,37 +534,38 @@ async def ret_a(x: str) -> str: # if we pass new_messages, system prompt is inserted before the message_history messages result2 = agent.run_sync('Hello again', message_history=result1.new_messages()) - assert result2 == snapshot( - RunResult( - _all_messages=[ - ModelRequest( - parts=[ - SystemPromptPart(content='Foobar'), - UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)), - ] - ), - ModelResponse( - parts=[ToolCallPart(tool_name='ret_a', args={'x': 'a'})], - model_name='test', - timestamp=IsNow(tz=timezone.utc), - ), - ModelRequest( - parts=[ToolReturnPart(tool_name='ret_a', content='a-apple', timestamp=IsNow(tz=timezone.utc))] - ), - ModelResponse( - parts=[TextPart(content='{"ret_a":"a-apple"}')], model_name='test', timestamp=IsNow(tz=timezone.utc) - ), - ModelRequest(parts=[UserPromptPart(content='Hello again', timestamp=IsNow(tz=timezone.utc))]), - ModelResponse( - parts=[TextPart(content='{"ret_a":"a-apple"}')], model_name='test', timestamp=IsNow(tz=timezone.utc) - ), - ], - _new_message_index=4, - data='{"ret_a":"a-apple"}', - _result_tool_name=None, - _usage=Usage(requests=1, request_tokens=55, response_tokens=13, total_tokens=68, details=None), - ) + assert result2.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + SystemPromptPart(content='Foobar'), + UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)), + ] + ), + ModelResponse( + parts=[ToolCallPart(tool_name='ret_a', args={'x': 'a'})], + model_name='test', + timestamp=IsNow(tz=timezone.utc), + ), + ModelRequest( + parts=[ToolReturnPart(tool_name='ret_a', content='a-apple', timestamp=IsNow(tz=timezone.utc))] + ), + ModelResponse( + parts=[TextPart(content='{"ret_a":"a-apple"}')], model_name='test', timestamp=IsNow(tz=timezone.utc) + ), + ModelRequest(parts=[UserPromptPart(content='Hello again', timestamp=IsNow(tz=timezone.utc))]), + ModelResponse( + parts=[TextPart(content='{"ret_a":"a-apple"}')], model_name='test', timestamp=IsNow(tz=timezone.utc) + ), + ] + ) + assert result2._new_message_index == snapshot(4) # pyright: ignore[reportPrivateUsage] + assert result2.data == snapshot('{"ret_a":"a-apple"}') + assert result2._result_tool_name == snapshot(None) # pyright: ignore[reportPrivateUsage] + assert result2.usage() == snapshot( + Usage(requests=1, request_tokens=55, response_tokens=13, total_tokens=68, details=None) ) + new_msg_part_kinds = [(m.kind, [p.part_kind for p in m.parts]) for m in result2.all_messages()] assert new_msg_part_kinds == snapshot( [ @@ -582,36 +583,36 @@ async def ret_a(x: str) -> str: # so only one system prompt result3 = agent.run_sync('Hello again', message_history=result1.all_messages()) # same as result2 except for datetimes - assert result3 == snapshot( - RunResult( - _all_messages=[ - ModelRequest( - parts=[ - SystemPromptPart(content='Foobar'), - UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)), - ] - ), - ModelResponse( - parts=[ToolCallPart(tool_name='ret_a', args={'x': 'a'})], - model_name='test', - timestamp=IsNow(tz=timezone.utc), - ), - ModelRequest( - parts=[ToolReturnPart(tool_name='ret_a', content='a-apple', timestamp=IsNow(tz=timezone.utc))] - ), - ModelResponse( - parts=[TextPart(content='{"ret_a":"a-apple"}')], model_name='test', timestamp=IsNow(tz=timezone.utc) - ), - ModelRequest(parts=[UserPromptPart(content='Hello again', timestamp=IsNow(tz=timezone.utc))]), - ModelResponse( - parts=[TextPart(content='{"ret_a":"a-apple"}')], model_name='test', timestamp=IsNow(tz=timezone.utc) - ), - ], - _new_message_index=4, - data='{"ret_a":"a-apple"}', - _result_tool_name=None, - _usage=Usage(requests=1, request_tokens=55, response_tokens=13, total_tokens=68, details=None), - ) + assert result3.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + SystemPromptPart(content='Foobar'), + UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)), + ] + ), + ModelResponse( + parts=[ToolCallPart(tool_name='ret_a', args={'x': 'a'})], + model_name='test', + timestamp=IsNow(tz=timezone.utc), + ), + ModelRequest( + parts=[ToolReturnPart(tool_name='ret_a', content='a-apple', timestamp=IsNow(tz=timezone.utc))] + ), + ModelResponse( + parts=[TextPart(content='{"ret_a":"a-apple"}')], model_name='test', timestamp=IsNow(tz=timezone.utc) + ), + ModelRequest(parts=[UserPromptPart(content='Hello again', timestamp=IsNow(tz=timezone.utc))]), + ModelResponse( + parts=[TextPart(content='{"ret_a":"a-apple"}')], model_name='test', timestamp=IsNow(tz=timezone.utc) + ), + ] + ) + assert result3._new_message_index == snapshot(4) # pyright: ignore[reportPrivateUsage] + assert result3.data == snapshot('{"ret_a":"a-apple"}') + assert result3._result_tool_name == snapshot(None) # pyright: ignore[reportPrivateUsage] + assert result3.usage() == snapshot( + Usage(requests=1, request_tokens=55, response_tokens=13, total_tokens=68, details=None) ) @@ -666,63 +667,63 @@ async def ret_a(x: str) -> str: ) result2 = agent.run_sync('Hello again', message_history=result1.new_messages()) - assert result2 == snapshot( - RunResult( - data=Response(a=0), - _all_messages=[ - ModelRequest( - parts=[ - SystemPromptPart(content='Foobar'), - UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)), - ], - ), - ModelResponse( - parts=[ToolCallPart(tool_name='ret_a', args={'x': 'a'})], - model_name='test', - timestamp=IsNow(tz=timezone.utc), - ), - ModelRequest( - parts=[ToolReturnPart(tool_name='ret_a', content='a-apple', timestamp=IsNow(tz=timezone.utc))], - ), - ModelResponse( - parts=[ToolCallPart(tool_name='final_result', args={'a': 0})], - model_name='test', - timestamp=IsNow(tz=timezone.utc), - ), - ModelRequest( - parts=[ - ToolReturnPart( - tool_name='final_result', - content='Final result processed.', - timestamp=IsNow(tz=timezone.utc), - ), - ], - ), - # second call, notice no repeated system prompt - ModelRequest( - parts=[ - UserPromptPart(content='Hello again', timestamp=IsNow(tz=timezone.utc)), - ], - ), - ModelResponse( - parts=[ToolCallPart(tool_name='final_result', args={'a': 0})], - model_name='test', - timestamp=IsNow(tz=timezone.utc), - ), - ModelRequest( - parts=[ - ToolReturnPart( - tool_name='final_result', - content='Final result processed.', - timestamp=IsNow(tz=timezone.utc), - ), - ] - ), - ], - _new_message_index=5, - _result_tool_name='final_result', - _usage=Usage(requests=1, request_tokens=59, response_tokens=13, total_tokens=72, details=None), - ) + assert result2.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + SystemPromptPart(content='Foobar'), + UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)), + ], + ), + ModelResponse( + parts=[ToolCallPart(tool_name='ret_a', args={'x': 'a'})], + model_name='test', + timestamp=IsNow(tz=timezone.utc), + ), + ModelRequest( + parts=[ToolReturnPart(tool_name='ret_a', content='a-apple', timestamp=IsNow(tz=timezone.utc))], + ), + ModelResponse( + parts=[ToolCallPart(tool_name='final_result', args={'a': 0})], + model_name='test', + timestamp=IsNow(tz=timezone.utc), + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='final_result', + content='Final result processed.', + timestamp=IsNow(tz=timezone.utc), + ), + ], + ), + # second call, notice no repeated system prompt + ModelRequest( + parts=[ + UserPromptPart(content='Hello again', timestamp=IsNow(tz=timezone.utc)), + ], + ), + ModelResponse( + parts=[ToolCallPart(tool_name='final_result', args={'a': 0})], + model_name='test', + timestamp=IsNow(tz=timezone.utc), + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='final_result', + content='Final result processed.', + timestamp=IsNow(tz=timezone.utc), + ), + ] + ), + ] + ) + assert result2.data == snapshot(Response(a=0)) + assert result2._new_message_index == snapshot(5) # pyright: ignore[reportPrivateUsage] + assert result2._result_tool_name == snapshot('final_result') # pyright: ignore[reportPrivateUsage] + assert result2.usage() == snapshot( + Usage(requests=1, request_tokens=59, response_tokens=13, total_tokens=72, details=None) ) new_msg_part_kinds = [(m.kind, [p.part_kind for p in m.parts]) for m in result2.all_messages()] assert new_msg_part_kinds == snapshot( diff --git a/tests/test_streaming.py b/tests/test_streaming.py index 1725c4d36..f95be4c13 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -335,14 +335,18 @@ async def test_call_tool_wrong_name(): async def stream_structured_function(_messages: list[ModelMessage], _: AgentInfo) -> AsyncIterator[DeltaToolCalls]: yield {0: DeltaToolCall(name='foobar', json_args='{}')} - agent = Agent(FunctionModel(stream_function=stream_structured_function), result_type=tuple[str, int]) + agent = Agent( + FunctionModel(stream_function=stream_structured_function), + result_type=tuple[str, int], + retries=0, + ) @agent.tool_plain async def ret_a(x: str) -> str: # pragma: no cover return x with capture_run_messages() as messages: - with pytest.raises(UnexpectedModelBehavior, match=r'Exceeded maximum retries \(1\) for result validation'): + with pytest.raises(UnexpectedModelBehavior, match=r'Exceeded maximum retries \(0\) for result validation'): async with agent.run_stream('hello'): pass # pragma: no cover @@ -354,14 +358,6 @@ async def ret_a(x: str) -> str: # pragma: no cover model_name='function:stream_structured_function', timestamp=IsNow(tz=timezone.utc), ), - ModelRequest( - parts=[ - RetryPromptPart( - content="Unknown tool name: 'foobar'. Available tools: ret_a, final_result", - timestamp=IsNow(tz=timezone.utc), - ) - ] - ), ] ) diff --git a/tests/test_usage_limits.py b/tests/test_usage_limits.py index e1d0234e0..ba00a3f01 100644 --- a/tests/test_usage_limits.py +++ b/tests/test_usage_limits.py @@ -76,34 +76,38 @@ async def test_streamed_text_limits() -> None: async def ret_a(x: str) -> str: return f'{x}-apple' - async with test_agent.run_stream('Hello', usage_limits=UsageLimits(response_tokens_limit=10)) as result: - assert test_agent.name == 'test_agent' - assert not result.is_complete - assert result.all_messages() == snapshot( - [ - ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]), - ModelResponse( - parts=[ToolCallPart(tool_name='ret_a', args={'x': 'a'})], - model_name='test', - timestamp=IsNow(tz=timezone.utc), - ), - ModelRequest( - parts=[ToolReturnPart(tool_name='ret_a', content='a-apple', timestamp=IsNow(tz=timezone.utc))] - ), - ] - ) - assert result.usage() == snapshot( - Usage( - requests=2, - request_tokens=103, - response_tokens=5, - total_tokens=108, + succeeded = False + + with pytest.raises( + UsageLimitExceeded, match=re.escape('Exceeded the response_tokens_limit of 10 (response_tokens=11)') + ): + async with test_agent.run_stream('Hello', usage_limits=UsageLimits(response_tokens_limit=10)) as result: + assert test_agent.name == 'test_agent' + assert not result.is_complete + assert result.all_messages() == snapshot( + [ + ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]), + ModelResponse( + parts=[ToolCallPart(tool_name='ret_a', args={'x': 'a'})], + model_name='test', + timestamp=IsNow(tz=timezone.utc), + ), + ModelRequest( + parts=[ToolReturnPart(tool_name='ret_a', content='a-apple', timestamp=IsNow(tz=timezone.utc))] + ), + ] ) - ) - with pytest.raises( - UsageLimitExceeded, match=re.escape('Exceeded the response_tokens_limit of 10 (response_tokens=11)') - ): - await result.get_data() + assert result.usage() == snapshot( + Usage( + requests=2, + request_tokens=103, + response_tokens=5, + total_tokens=108, + ) + ) + succeeded = True + + assert succeeded def test_usage_so_far() -> None: diff --git a/tests/typed_agent.py b/tests/typed_agent.py index fdf9f1a25..dbbe04411 100644 --- a/tests/typed_agent.py +++ b/tests/typed_agent.py @@ -8,7 +8,7 @@ from typing_extensions import assert_type from pydantic_ai import Agent, ModelRetry, RunContext, Tool -from pydantic_ai.result import RunResult +from pydantic_ai.agent import AgentRun from pydantic_ai.tools import ToolDefinition @@ -139,7 +139,7 @@ async def result_validator_wrong(ctx: RunContext[int], result: str) -> str: def run_sync() -> None: result = typed_agent.run_sync('testing', deps=MyDeps(foo=1, bar=2)) - assert_type(result, RunResult[str]) + assert_type(result, AgentRun[MyDeps, str]) assert_type(result.data, str) @@ -176,7 +176,7 @@ class Bar: def run_sync3() -> None: result = union_agent.run_sync('testing') - assert_type(result, RunResult[Union[Foo, Bar]]) + assert_type(result, AgentRun[None, Union[Foo, Bar]]) assert_type(result.data, Union[Foo, Bar]) diff --git a/tests/typed_graph.py b/tests/typed_graph.py index d0b6a02b7..deba4dd45 100644 --- a/tests/typed_graph.py +++ b/tests/typed_graph.py @@ -109,6 +109,6 @@ def run_g5() -> None: g5.run_sync(A()) # pyright: ignore[reportArgumentType] g5.run_sync(A(), state=MyState(x=1)) # pyright: ignore[reportArgumentType] g5.run_sync(A(), deps=MyDeps(y='y')) # pyright: ignore[reportArgumentType] - answer, history = g5.run_sync(A(), state=MyState(x=1), deps=MyDeps(y='y')) - assert_type(answer, int) - assert_type(history, list[HistoryStep[MyState, int]]) + graph_run = g5.run_sync(A(), state=MyState(x=1), deps=MyDeps(y='y')) + assert_type(graph_run.result, int) + assert_type(graph_run.history, list[HistoryStep[MyState, int]]) From 1a093789faf3025e226ec48eaefdcf120be26d14 Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Tue, 11 Feb 2025 13:31:31 -0700 Subject: [PATCH 02/28] Minor improvements --- pydantic_ai_slim/pydantic_ai/_agent_graph.py | 154 ------------------ .../pydantic_ai/models/__init__.py | 18 +- pydantic_graph/pydantic_graph/graph.py | 7 +- 3 files changed, 11 insertions(+), 168 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index ea56822dd..5d1421915 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -182,14 +182,6 @@ async def run( return ModelRequestNode[DepsT, NodeRunEndT](request=await self._get_first_message(ctx)) -# @dataclasses.dataclass -# class StreamUserPromptNode(BaseUserPromptNode[DepsT, NodeRunEndT]): -# async def run( -# self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]] -# ) -> StreamModelRequestNode[DepsT, NodeRunEndT]: -# return StreamModelRequestNode[DepsT, NodeRunEndT](request=await self._get_first_message(ctx)) - - async def _prepare_request_parameters( ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], ) -> models.ModelRequestParameters: @@ -468,136 +460,6 @@ async def _handle_text_response( ) -# @dataclasses.dataclass -# class StreamModelRequestNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT]): -# """Make a request to the model using the last message in state.message_history (or a specified request).""" -# -# request: _messages.ModelRequest -# _result: StreamModelRequestNode[DepsT, NodeRunEndT] | End[result.StreamedRunResult[DepsT, NodeRunEndT]] | None = ( -# field(default=None, repr=False) -# ) -# -# async def run( -# self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]] -# ) -> Union[StreamModelRequestNode[DepsT, NodeRunEndT], End[result.StreamedRunResult[DepsT, NodeRunEndT]]]: -# if self._result is not None: -# return self._result -# -# async with self.run_to_result(ctx) as final_node: -# return final_node -# -# @asynccontextmanager -# async def run_to_result( -# self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]] -# ) -> AsyncIterator[StreamModelRequestNode[DepsT, NodeRunEndT] | End[result.StreamedRunResult[DepsT, NodeRunEndT]]]: -# result_schema = ctx.deps.result_schema -# -# ctx.state.message_history.append(self.request) -# -# # Check usage -# if ctx.deps.usage_limits: -# ctx.deps.usage_limits.check_before_request(ctx.state.usage) -# -# # Increment run_step -# ctx.state.run_step += 1 -# -# with _logfire.span('preparing model and tools {run_step=}', run_step=ctx.state.run_step): -# model_request_parameters = await _prepare_request_parameters(ctx) -# -# # Actually make the model request -# model_settings = merge_model_settings(ctx.deps.model_settings, None) -# with _logfire.span('model request {run_step=}', run_step=ctx.state.run_step) as model_req_span: -# async with ctx.deps.model.request_stream( -# ctx.state.message_history, model_settings, model_request_parameters -# ) as streamed_response: -# ctx.state.usage.requests += 1 -# model_req_span.set_attribute('response_type', streamed_response.__class__.__name__) -# # We want to end the "model request" span here, but we can't exit the context manager -# # in the traditional way -# model_req_span.__exit__(None, None, None) -# -# with _logfire.span('handle model response') as handle_span: -# received_text = False -# -# async for maybe_part_event in streamed_response: -# if isinstance(maybe_part_event, _messages.PartStartEvent): -# new_part = maybe_part_event.part -# if isinstance(new_part, _messages.TextPart): -# received_text = True -# if _allow_text_result(result_schema): -# handle_span.message = 'handle model response -> final result' -# streamed_run_result = _build_streamed_run_result(streamed_response, None, ctx) -# self._result = End(streamed_run_result) -# yield self._result -# return -# elif isinstance(new_part, _messages.ToolCallPart): -# if result_schema is not None and (match := result_schema.find_tool([new_part])): -# call, _ = match -# handle_span.message = 'handle model response -> final result' -# streamed_run_result = _build_streamed_run_result( -# streamed_response, call.tool_name, ctx -# ) -# self._result = End(streamed_run_result) -# yield self._result -# return -# else: -# assert_never(new_part) -# -# tasks: list[asyncio.Task[_messages.ModelRequestPart]] = [] -# parts: list[_messages.ModelRequestPart] = [] -# model_response = streamed_response.get() -# if not model_response.parts: -# raise exceptions.UnexpectedModelBehavior('Received empty model response') -# ctx.state.message_history.append(model_response) -# -# run_context = _build_run_context(ctx) -# for p in model_response.parts: -# if isinstance(p, _messages.ToolCallPart): -# if tool := ctx.deps.function_tools.get(p.tool_name): -# tasks.append(asyncio.create_task(tool.run(p, run_context), name=p.tool_name)) -# else: -# parts.append(_unknown_tool(p.tool_name, ctx)) -# -# if received_text and not tasks and not parts: -# # Can only get here if self._allow_text_result returns `False` for the provided result_schema -# ctx.state.increment_retries(ctx.deps.max_result_retries) -# self._result = StreamModelRequestNode[DepsT, NodeRunEndT]( -# _messages.ModelRequest( -# parts=[ -# _messages.RetryPromptPart( -# content='Plain text responses are not permitted, please call one of the functions instead.', -# ) -# ] -# ) -# ) -# yield self._result -# return -# -# with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]): -# task_results: Sequence[_messages.ModelRequestPart] = await asyncio.gather(*tasks) -# parts.extend(task_results) -# -# next_request = _messages.ModelRequest(parts=parts) -# if any(isinstance(part, _messages.RetryPromptPart) for part in parts): -# try: -# ctx.state.increment_retries(ctx.deps.max_result_retries) -# except: -# # TODO: This is janky, so I think we should probably change it, but how? -# ctx.state.message_history.append(next_request) -# raise -# -# handle_span.set_attribute('tool_responses', parts) -# tool_responses_str = ' '.join(r.part_kind for r in parts) -# handle_span.message = f'handle model response -> {tool_responses_str}' -# # the model_response should have been fully streamed by now, we can add its usage -# streamed_response_usage = streamed_response.usage() -# run_context.usage.incr(streamed_response_usage) -# ctx.deps.usage_limits.check_tokens(run_context.usage) -# self._result = StreamModelRequestNode[DepsT, NodeRunEndT](next_request) -# yield self._result -# return - - @dataclasses.dataclass class FinalResultNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], MarkFinalResult[NodeRunEndT]]): """Produce the final result of the run.""" @@ -820,19 +682,3 @@ def build_agent_graph( auto_instrument=False, ) return graph - - -# def build_agent_stream_graph( -# name: str | None, deps_type: type[DepsT], result_type: type[ResultT] | None -# ) -> Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], result.StreamedRunResult[DepsT, Any]]: -# nodes = [ -# StreamUserPromptNode[DepsT, result.StreamedRunResult[DepsT, ResultT]], -# StreamModelRequestNode[DepsT, result.StreamedRunResult[DepsT, ResultT]], -# ] -# graph = Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], result.StreamedRunResult[DepsT, Any]]( -# nodes=nodes, -# name=name or 'Agent', -# state_type=GraphAgentState, -# run_end_type=result.StreamedRunResult[DepsT, result_type], -# ) -# return graph diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py index e66591574..792d28879 100644 --- a/pydantic_ai_slim/pydantic_ai/models/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py @@ -261,10 +261,11 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: # noinspection PyUnreachableCode yield - @abstractmethod - def timestamp(self) -> datetime: - """Get the timestamp of the response.""" - raise NotImplementedError() + def get(self) -> ModelResponse: + """Build a [`ModelResponse`][pydantic_ai.messages.ModelResponse] from the data received from the stream so far.""" + return ModelResponse( + parts=self._parts_manager.get_parts(), model_name=self._model_name, timestamp=self.timestamp() + ) def model_name(self) -> str: """Get the model name of the response.""" @@ -274,11 +275,10 @@ def usage(self) -> Usage: """Get the usage of the response so far. This will not be the final usage until the stream is exhausted.""" return self._usage - def get(self) -> ModelResponse: - """Build a [`ModelResponse`][pydantic_ai.messages.ModelResponse] from the data received from the stream so far.""" - return ModelResponse( - parts=self._parts_manager.get_parts(), model_name=self._model_name, timestamp=self.timestamp() - ) + @abstractmethod + def timestamp(self) -> datetime: + """Get the timestamp of the response.""" + raise NotImplementedError() async def stream_events(self) -> AsyncIterator[ModelResponseStreamEvent]: return self.__aiter__() diff --git a/pydantic_graph/pydantic_graph/graph.py b/pydantic_graph/pydantic_graph/graph.py index 710cfedee..4f6706340 100644 --- a/pydantic_graph/pydantic_graph/graph.py +++ b/pydantic_graph/pydantic_graph/graph.py @@ -559,7 +559,8 @@ async def next( """Note: this method behaves very similarly to an async generator's `asend` method.""" if not self._started: raise exceptions.GraphRuntimeError( - 'You must enter the GraphRun as a contextmanager before you can call `next` on it.' + 'You must enter the GraphRun as a contextmanager (using `with ...`)' + ' before you can iterate over it or call `next` on it.' ) history = self.history @@ -611,8 +612,4 @@ async def __anext__(self) -> BaseNode[StateT, DepsT, RunEndT] | End[RunEndT]: """Use the last returned node as the input to `Graph.next`.""" if self._result: raise StopAsyncIteration - if not self._started: - raise exceptions.GraphRuntimeError( - 'You must enter the GraphRun as a contextmanager before you can iterate over it.' - ) return await self.next(self._next_node) From 0c3b48de884deaf12fd804e9ccccebafcffe445a Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Tue, 11 Feb 2025 13:48:58 -0700 Subject: [PATCH 03/28] Move definition of MarkFinalResult to result module --- pydantic_ai_slim/pydantic_ai/_agent_graph.py | 4 ++-- pydantic_ai_slim/pydantic_ai/agent.py | 23 +++++++++---------- .../pydantic_ai/models/__init__.py | 21 ++--------------- pydantic_ai_slim/pydantic_ai/result.py | 15 ++++++++++++ 4 files changed, 30 insertions(+), 33 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index 5d1421915..6899b8a80 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -24,8 +24,8 @@ models, usage as _usage, ) -from .models import MarkFinalResult, ModelRequestParameters, StreamedResponse -from .result import ResultDataT +from .models import ModelRequestParameters, StreamedResponse +from .result import MarkFinalResult, ResultDataT from .settings import ModelSettings, merge_model_settings from .tools import ( RunContext, diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 709e48a8f..117aa28cd 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -27,8 +27,7 @@ usage as _usage, ) from ._agent_graph import EndStrategy, capture_run_messages # imported for re-export -from .models import MarkFinalResult -from .result import ResultDataT, StreamedRunResult +from .result import MarkFinalResult, ResultDataT, StreamedRunResult from .settings import ModelSettings, merge_model_settings from .tools import ( AgentDepsT, @@ -1041,7 +1040,7 @@ def last_run_messages(self) -> list[_messages.ModelMessage]: def _build_graph( self, result_type: type[RunResultDataT] | None - ) -> Graph[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], models.MarkFinalResult[Any]]: + ) -> Graph[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], MarkFinalResult[Any]]: return _agent_graph.build_agent_graph(self.name, self._deps_type, result_type or self.result_type) def _prepare_result_schema( @@ -1060,7 +1059,7 @@ def _prepare_result_schema( @dataclasses.dataclass class AgentRun(Generic[AgentDepsT, ResultDataT]): graph_run: GraphRun[ - _agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], models.MarkFinalResult[ResultDataT] + _agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], MarkFinalResult[ResultDataT] ] def all_messages(self, *, result_tool_return_content: str | None = None) -> list[_messages.ModelMessage]: @@ -1131,7 +1130,7 @@ def is_ended(self) -> bool: return self.graph_run.is_ended @property - def result(self) -> models.MarkFinalResult[ResultDataT]: + def result(self) -> MarkFinalResult[ResultDataT]: return self.graph_run.result @property @@ -1147,15 +1146,15 @@ async def next( node: BaseNode[ _agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], - models.MarkFinalResult[ResultDataT], + MarkFinalResult[ResultDataT], ], ) -> ( BaseNode[ _agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], - models.MarkFinalResult[ResultDataT], + MarkFinalResult[ResultDataT], ] - | End[models.MarkFinalResult[ResultDataT]] + | End[MarkFinalResult[ResultDataT]] ): return await self.graph_run.next(node) @@ -1184,9 +1183,9 @@ def __aiter__( BaseNode[ _agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], - models.MarkFinalResult[ResultDataT], + MarkFinalResult[ResultDataT], ] - | End[models.MarkFinalResult[ResultDataT]] + | End[MarkFinalResult[ResultDataT]] ]: return self @@ -1196,9 +1195,9 @@ async def __anext__( BaseNode[ _agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], - models.MarkFinalResult[ResultDataT], + MarkFinalResult[ResultDataT], ] - | End[models.MarkFinalResult[ResultDataT]] + | End[MarkFinalResult[ResultDataT]] ): """Use the last returned node as the input to `Graph.next`.""" return await self.graph_run.__anext__() diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py index 792d28879..0d0d29a39 100644 --- a/pydantic_ai_slim/pydantic_ai/models/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py @@ -12,11 +12,11 @@ from dataclasses import dataclass, field from datetime import datetime from functools import cache -from typing import TYPE_CHECKING, Generic +from typing import TYPE_CHECKING import httpx import logfire_api -from typing_extensions import Literal, TypeVar +from typing_extensions import Literal from .. import _utils, messages as _messages from .._parts_manager import ModelResponsePartsManager @@ -167,23 +167,6 @@ `KnownModelName` is provided as a concise way to specify a model. """ -ResultDataT = TypeVar('ResultDataT', covariant=True) - - -@dataclass -class MarkFinalResult(Generic[ResultDataT]): - """Marker class to indicate that the result is the final result. - - This allows us to use `isinstance`, which wouldn't be possible if we were returning `ResultDataT` directly. - - It also avoids problems in the case where the result type is itself `None`, but is set. - """ - - data: ResultDataT - """The final result data.""" - tool_name: str | None - """Name of the final result tool, None if the result is a string.""" - @dataclass class ModelRequestParameters: diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py index bd33d3488..2ae72b164 100644 --- a/pydantic_ai_slim/pydantic_ai/result.py +++ b/pydantic_ai_slim/pydantic_ai/result.py @@ -291,6 +291,21 @@ async def _marked_completed(self, message: _messages.ModelResponse) -> None: await self._on_complete() +@dataclass +class MarkFinalResult(Generic[ResultDataT]): + """Marker class to indicate that the result is the final result. + + This allows us to use `isinstance`, which wouldn't be possible if we were returning `ResultDataT` directly. + + It also avoids problems in the case where the result type is itself `None`, but is set. + """ + + data: ResultDataT + """The final result data.""" + tool_name: str | None + """Name of the final result tool, None if the result is a string.""" + + def _get_usage_checking_stream_response( stream_response: AsyncIterable[_messages.ModelResponseStreamEvent], limits: UsageLimits | None, From 9280adf9d6d79b347cf6bec9137cbb8d000ac529 Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Tue, 11 Feb 2025 13:50:35 -0700 Subject: [PATCH 04/28] A bit more clean-up --- pydantic_ai_slim/pydantic_ai/models/__init__.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py index 0d0d29a39..7a75c084f 100644 --- a/pydantic_ai_slim/pydantic_ai/models/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py @@ -15,7 +15,6 @@ from typing import TYPE_CHECKING import httpx -import logfire_api from typing_extensions import Literal from .. import _utils, messages as _messages @@ -28,8 +27,6 @@ if TYPE_CHECKING: from ..tools import ToolDefinition -_logfire = logfire_api.Logfire(otel_scope='pydantic-ai') - KnownModelName = Literal[ 'anthropic:claude-3-5-haiku-latest', 'anthropic:claude-3-5-sonnet-latest', From 7d55afd0d3d8d798606530718c01ba31b2796461 Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Tue, 11 Feb 2025 14:14:56 -0700 Subject: [PATCH 05/28] Update call_id logic --- pydantic_ai_slim/pydantic_ai/_agent_graph.py | 3 +-- pydantic_ai_slim/pydantic_ai/messages.py | 11 +++++++---- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index 6899b8a80..76c3a3b63 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -2,7 +2,6 @@ import asyncio import dataclasses -import uuid from abc import ABC from collections.abc import AsyncIterator, Iterator from contextlib import asynccontextmanager, contextmanager @@ -520,7 +519,7 @@ async def process_function_tools( run_context = build_run_context(ctx) calls_to_run: list[tuple[Tool[DepsT], _messages.ToolCallPart]] = [] - call_index_to_event_id: dict[int, uuid.UUID] = {} + call_index_to_event_id: dict[int, str] = {} for call in tool_calls: if call.tool_name == result_tool_name and not found_used_result_tool: found_used_result_tool = True diff --git a/pydantic_ai_slim/pydantic_ai/messages.py b/pydantic_ai_slim/pydantic_ai/messages.py index 9f39e288b..111ce63c1 100644 --- a/pydantic_ai_slim/pydantic_ai/messages.py +++ b/pydantic_ai_slim/pydantic_ai/messages.py @@ -454,11 +454,14 @@ class FunctionToolCallEvent: part: ToolCallPart """The (function) tool call to make.""" - call_id: uuid.UUID = field(default_factory=uuid.uuid4, repr=False) - """An ID used to match the call to its result.""" - event_kind: Literal['function_tool_call'] = field(default='function_tool_call', repr=False) + call_id: str = field(init=False) + """An ID used for matching details about the call to its result. If present, defaults to the part's tool_call_id.""" + event_kind: Literal['function_tool_call'] = field(default='function_tool_call', init=False, repr=False) """Event type identifier, used as a discriminator.""" + def __post_init__(self): + self.call_id = self.part.tool_call_id or str(uuid.uuid4()) + @dataclass class FunctionToolResultEvent: @@ -466,7 +469,7 @@ class FunctionToolResultEvent: result: ToolReturnPart | RetryPromptPart """The result of the call to the function tool.""" - call_id: uuid.UUID = field(default_factory=uuid.uuid4, repr=False) + call_id: str """An ID used to match the result to its original call.""" event_kind: Literal['function_tool_result'] = field(default='function_tool_result', repr=False) """Event type identifier, used as a discriminator.""" From 63df8ba9b874a097c5c2737cabfbdc07d81fbe9d Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Tue, 11 Feb 2025 14:19:09 -0700 Subject: [PATCH 06/28] Minor fixes --- pydantic_ai_slim/pydantic_ai/agent.py | 2 +- pydantic_ai_slim/pydantic_ai/messages.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 117aa28cd..92fbe8828 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -1162,7 +1162,7 @@ def graph_ctx(self) -> GraphRunContext[_agent_graph.GraphAgentState, _agent_grap return GraphRunContext(self.graph_run.state, self.graph_run.deps) def __await__(self) -> Generator[Any, Any, Self]: - """Run the graph until it ends, and return the final result.""" + """Run the agent graph until it ends, and return the final result.""" async def _run(): await self.graph_run diff --git a/pydantic_ai_slim/pydantic_ai/messages.py b/pydantic_ai_slim/pydantic_ai/messages.py index 111ce63c1..ab159abd9 100644 --- a/pydantic_ai_slim/pydantic_ai/messages.py +++ b/pydantic_ai_slim/pydantic_ai/messages.py @@ -471,7 +471,7 @@ class FunctionToolResultEvent: """The result of the call to the function tool.""" call_id: str """An ID used to match the result to its original call.""" - event_kind: Literal['function_tool_result'] = field(default='function_tool_result', repr=False) + event_kind: Literal['function_tool_result'] = field(default='function_tool_result', init=False, repr=False) """Event type identifier, used as a discriminator.""" From 6c650955f4698a51f6edcc001ac2fcb4ed597ac9 Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Tue, 11 Feb 2025 16:08:48 -0700 Subject: [PATCH 07/28] Update some things --- docs/agents.md | 20 +-- docs/api/models/function.md | 4 +- docs/message-history.md | 78 +++-------- docs/tools.md | 34 +---- pydantic_ai_slim/pydantic_ai/_agent_graph.py | 63 ++++----- pydantic_ai_slim/pydantic_ai/agent.py | 57 +++++++- pydantic_ai_slim/pydantic_ai/messages.py | 24 ++-- pydantic_graph/pydantic_graph/graph.py | 10 ++ tests/test_agent.py | 49 +++---- tests/test_parts_manager.py | 133 ++++++------------- tests/test_streaming.py | 12 -- 11 files changed, 190 insertions(+), 294 deletions(-) diff --git a/docs/agents.md b/docs/agents.md index 19da9f3a2..39c54b327 100644 --- a/docs/agents.md +++ b/docs/agents.md @@ -472,23 +472,17 @@ with capture_run_messages() as messages: # (2)! UserPromptPart( content='Please get me the volume of a box with size 6.', timestamp=datetime.datetime(...), - part_kind='user-prompt', ) - ], - kind='request', + ] ), ModelResponse( parts=[ ToolCallPart( - tool_name='calc_volume', - args={'size': 6}, - tool_call_id=None, - part_kind='tool-call', + tool_name='calc_volume', args={'size': 6}, tool_call_id=None ) ], model_name='function:model_logic', timestamp=datetime.datetime(...), - kind='response', ), ModelRequest( parts=[ @@ -497,23 +491,17 @@ with capture_run_messages() as messages: # (2)! tool_name='calc_volume', tool_call_id=None, timestamp=datetime.datetime(...), - part_kind='retry-prompt', ) - ], - kind='request', + ] ), ModelResponse( parts=[ ToolCallPart( - tool_name='calc_volume', - args={'size': 6}, - tool_call_id=None, - part_kind='tool-call', + tool_name='calc_volume', args={'size': 6}, tool_call_id=None ) ], model_name='function:model_logic', timestamp=datetime.datetime(...), - kind='response', ), ] """ diff --git a/docs/api/models/function.md b/docs/api/models/function.md index d24c87c18..6049a1122 100644 --- a/docs/api/models/function.md +++ b/docs/api/models/function.md @@ -28,10 +28,8 @@ async def model_function( UserPromptPart( content='Testing my agent...', timestamp=datetime.datetime(...), - part_kind='user-prompt', ) - ], - kind='request', + ] ) ] """ diff --git a/docs/message-history.md b/docs/message-history.md index d538112f8..fe94481c3 100644 --- a/docs/message-history.md +++ b/docs/message-history.md @@ -42,29 +42,21 @@ print(result.all_messages()) [ ModelRequest( parts=[ - SystemPromptPart( - content='Be a helpful assistant.', - dynamic_ref=None, - part_kind='system-prompt', - ), + SystemPromptPart(content='Be a helpful assistant.', dynamic_ref=None), UserPromptPart( content='Tell me a joke.', timestamp=datetime.datetime(...), - part_kind='user-prompt', ), - ], - kind='request', + ] ), ModelResponse( parts=[ TextPart( - content='Did you hear about the toothpaste scandal? They called it Colgate.', - part_kind='text', + content='Did you hear about the toothpaste scandal? They called it Colgate.' ) ], model_name='function:model_logic', timestamp=datetime.datetime(...), - kind='response', ), ] """ @@ -88,17 +80,13 @@ async def main(): ModelRequest( parts=[ SystemPromptPart( - content='Be a helpful assistant.', - dynamic_ref=None, - part_kind='system-prompt', + content='Be a helpful assistant.', dynamic_ref=None ), UserPromptPart( content='Tell me a joke.', timestamp=datetime.datetime(...), - part_kind='user-prompt', ), - ], - kind='request', + ] ) ] """ @@ -117,28 +105,22 @@ async def main(): ModelRequest( parts=[ SystemPromptPart( - content='Be a helpful assistant.', - dynamic_ref=None, - part_kind='system-prompt', + content='Be a helpful assistant.', dynamic_ref=None ), UserPromptPart( content='Tell me a joke.', timestamp=datetime.datetime(...), - part_kind='user-prompt', ), - ], - kind='request', + ] ), ModelResponse( parts=[ TextPart( - content='Did you hear about the toothpaste scandal? They called it Colgate.', - part_kind='text', + content='Did you hear about the toothpaste scandal? They called it Colgate.' ) ], model_name='function:stream_model_logic', timestamp=datetime.datetime(...), - kind='response', ), ] """ @@ -173,50 +155,38 @@ print(result2.all_messages()) [ ModelRequest( parts=[ - SystemPromptPart( - content='Be a helpful assistant.', - dynamic_ref=None, - part_kind='system-prompt', - ), + SystemPromptPart(content='Be a helpful assistant.', dynamic_ref=None), UserPromptPart( content='Tell me a joke.', timestamp=datetime.datetime(...), - part_kind='user-prompt', ), - ], - kind='request', + ] ), ModelResponse( parts=[ TextPart( - content='Did you hear about the toothpaste scandal? They called it Colgate.', - part_kind='text', + content='Did you hear about the toothpaste scandal? They called it Colgate.' ) ], model_name='function:model_logic', timestamp=datetime.datetime(...), - kind='response', ), ModelRequest( parts=[ UserPromptPart( content='Explain?', timestamp=datetime.datetime(...), - part_kind='user-prompt', ) - ], - kind='request', + ] ), ModelResponse( parts=[ TextPart( - content='This is an excellent joke invented by Samuel Colvin, it needs no explanation.', - part_kind='text', + content='This is an excellent joke invented by Samuel Colvin, it needs no explanation.' ) ], model_name='function:model_logic', timestamp=datetime.datetime(...), - kind='response', ), ] """ @@ -253,50 +223,38 @@ print(result2.all_messages()) [ ModelRequest( parts=[ - SystemPromptPart( - content='Be a helpful assistant.', - dynamic_ref=None, - part_kind='system-prompt', - ), + SystemPromptPart(content='Be a helpful assistant.', dynamic_ref=None), UserPromptPart( content='Tell me a joke.', timestamp=datetime.datetime(...), - part_kind='user-prompt', ), - ], - kind='request', + ] ), ModelResponse( parts=[ TextPart( - content='Did you hear about the toothpaste scandal? They called it Colgate.', - part_kind='text', + content='Did you hear about the toothpaste scandal? They called it Colgate.' ) ], model_name='function:model_logic', timestamp=datetime.datetime(...), - kind='response', ), ModelRequest( parts=[ UserPromptPart( content='Explain?', timestamp=datetime.datetime(...), - part_kind='user-prompt', ) - ], - kind='request', + ] ), ModelResponse( parts=[ TextPart( - content='This is an excellent joke invented by Samuel Colvin, it needs no explanation.', - part_kind='text', + content='This is an excellent joke invented by Samuel Colvin, it needs no explanation.' ) ], model_name='function:model_logic', timestamp=datetime.datetime(...), - kind='response', ), ] """ diff --git a/docs/tools.md b/docs/tools.md index 2de55701c..0468c125e 100644 --- a/docs/tools.md +++ b/docs/tools.md @@ -73,25 +73,17 @@ print(dice_result.all_messages()) SystemPromptPart( content="You're a dice game, you should roll the die and see if the number you get back matches the user's guess. If so, tell them they're a winner. Use the player's name in the response.", dynamic_ref=None, - part_kind='system-prompt', ), UserPromptPart( content='My guess is 4', timestamp=datetime.datetime(...), - part_kind='user-prompt', ), - ], - kind='request', + ] ), ModelResponse( - parts=[ - ToolCallPart( - tool_name='roll_die', args={}, tool_call_id=None, part_kind='tool-call' - ) - ], + parts=[ToolCallPart(tool_name='roll_die', args={}, tool_call_id=None)], model_name='function:model_logic', timestamp=datetime.datetime(...), - kind='response', ), ModelRequest( parts=[ @@ -100,23 +92,13 @@ print(dice_result.all_messages()) content='4', tool_call_id=None, timestamp=datetime.datetime(...), - part_kind='tool-return', ) - ], - kind='request', + ] ), ModelResponse( - parts=[ - ToolCallPart( - tool_name='get_player_name', - args={}, - tool_call_id=None, - part_kind='tool-call', - ) - ], + parts=[ToolCallPart(tool_name='get_player_name', args={}, tool_call_id=None)], model_name='function:model_logic', timestamp=datetime.datetime(...), - kind='response', ), ModelRequest( parts=[ @@ -125,21 +107,17 @@ print(dice_result.all_messages()) content='Anne', tool_call_id=None, timestamp=datetime.datetime(...), - part_kind='tool-return', ) - ], - kind='request', + ] ), ModelResponse( parts=[ TextPart( - content="Congratulations Anne, you guessed correctly! You're a winner!", - part_kind='text', + content="Congratulations Anne, you guessed correctly! You're a winner!" ) ], model_name='function:model_logic', timestamp=datetime.datetime(...), - kind='response', ), ] """ diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index 76c3a3b63..a7ef0a1b0 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -314,14 +314,14 @@ class HandleResponseNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], N model_response: _messages.ModelResponse _stream: AsyncIterator[_messages.HandleResponseEvent] | None = field(default=None, repr=False) - _next_node: ModelRequestNode[DepsT, NodeRunEndT] | FinalResultNode[DepsT, NodeRunEndT] | None = field( + _next_node: ModelRequestNode[DepsT, NodeRunEndT] | End[MarkFinalResult[NodeRunEndT]] | None = field( default=None, repr=False ) _tool_responses: list[_messages.ModelRequestPart] = field(default_factory=list, repr=False) async def run( self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]] - ) -> Union[ModelRequestNode[DepsT, NodeRunEndT], FinalResultNode[DepsT, NodeRunEndT]]: # noqa UP007 + ) -> Union[ModelRequestNode[DepsT, NodeRunEndT], End[MarkFinalResult[NodeRunEndT]]]: # noqa UP007 async with self.stream(ctx): pass @@ -343,7 +343,7 @@ async def stream( # Set the next node based on the final state of the stream next_node = self._next_node - if isinstance(next_node, FinalResultNode): + if isinstance(next_node, End): handle_span.set_attribute('result', next_node.data) handle_span.message = 'handle model response -> final result' elif tool_responses := self._tool_responses: @@ -423,17 +423,37 @@ async def _handle_tool_calls( yield event if final_result: - self._next_node = FinalResultNode[DepsT, NodeRunEndT](final_result, tool_responses) + self._next_node = self._handle_final_result(ctx, final_result, tool_responses) else: if tool_responses: parts.extend(tool_responses) self._next_node = ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=parts)) + def _handle_final_result( + self, + ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], + final_result: MarkFinalResult[NodeRunEndT], + tool_responses: list[_messages.ModelRequestPart], + ) -> End[MarkFinalResult[NodeRunEndT]]: + run_span = ctx.deps.run_span + usage = ctx.state.usage + messages = ctx.state.message_history + + # For backwards compatibility, append a new ModelRequest using the tool returns and retries + if tool_responses: + messages.append(_messages.ModelRequest(parts=tool_responses)) + + run_span.set_attribute('usage', usage) + run_span.set_attribute('all_messages', messages) + + # End the run with self.data + return End(final_result) + async def _handle_text_response( self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], texts: list[str], - ) -> ModelRequestNode[DepsT, NodeRunEndT] | FinalResultNode[DepsT, NodeRunEndT]: + ) -> ModelRequestNode[DepsT, NodeRunEndT] | End[MarkFinalResult[NodeRunEndT]]: result_schema = ctx.deps.result_schema text = '\n\n'.join(texts) @@ -445,7 +465,8 @@ async def _handle_text_response( ctx.state.increment_retries(ctx.deps.max_result_retries) return ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[e.tool_retry])) else: - return FinalResultNode[DepsT, NodeRunEndT](MarkFinalResult(result_data, None)) + # The following cast is safe because we know `str` is an allowed result type + return self._handle_final_result(ctx, MarkFinalResult(result_data, tool_name=None), []) else: ctx.state.increment_retries(ctx.deps.max_result_retries) return ModelRequestNode[DepsT, NodeRunEndT]( @@ -459,35 +480,6 @@ async def _handle_text_response( ) -@dataclasses.dataclass -class FinalResultNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], MarkFinalResult[NodeRunEndT]]): - """Produce the final result of the run.""" - - data: MarkFinalResult[NodeRunEndT] - """The final result data.""" - extra_parts: list[_messages.ModelRequestPart] = dataclasses.field(default_factory=list) - - async def run( - self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]] - ) -> End[MarkFinalResult[NodeRunEndT]]: - run_span = ctx.deps.run_span - usage = ctx.state.usage - messages = ctx.state.message_history - - # TODO: For backwards compatibility, append a new ModelRequest using the tool returns and retries - if self.extra_parts: - messages.append(_messages.ModelRequest(parts=self.extra_parts)) - - # TODO: Set this attribute somewhere - # handle_span = self.handle_model_response_span - # handle_span.set_attribute('final_data', self.data) - run_span.set_attribute('usage', usage) - run_span.set_attribute('all_messages', messages) - - # End the run with self.data - return End(self.data) - - def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]) -> RunContext[DepsT]: return RunContext[DepsT]( deps=ctx.deps.user_deps, @@ -671,7 +663,6 @@ def build_agent_graph( UserPromptNode[DepsT], ModelRequestNode[DepsT], HandleResponseNode[DepsT], - FinalResultNode[DepsT, ResultT], ) graph = Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], MarkFinalResult[ResultT]]( nodes=nodes, diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 92fbe8828..23cc43b75 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -230,7 +230,7 @@ def run( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, - ) -> AgentRun[AgentDepsT, ResultDataT]: ... + ) -> AgentRun[AgentDepsT, RunResultDataT]: ... def run( self, @@ -244,9 +244,18 @@ def run( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, - ) -> AgentRun[AgentDepsT, ResultDataT]: + ) -> AgentRun[AgentDepsT, Any]: """Run the agent with a user prompt in async mode. + This method builds an internal agent graph (using system prompts, tools and result schemas) and then + returns an AgentRun object. The AgentRun functions as a handle that can be used to iterate over the graph and + obtain the final result. The AgentRun also provides methods to access the full message history, new messages, + and usage statistics. + + The AgentRun can be awaited to get the final result of the run, or entered as a context manager to + obtain an iterator over the graph nodes. You can even use the AgentRun as an async generator to override the + execution of the graph if desired. See the documentation of AgentRun for more details. + Example: ```python from pydantic_ai import Agent @@ -254,11 +263,47 @@ def run( agent = Agent('openai:gpt-4o') async def main(): - result = await agent.run('What is the capital of France?') - print(result.data) + agent_run = await agent.run('What is the capital of France?') + print(agent_run.data) #> Paris ``` + Example: + ```python + from pydantic_ai import Agent + + agent = Agent('openai:gpt-4o') + + async def main(): + with agent.run('What is the capital of France?') as agent_run: + async for node in agent_run: + print(node) + ''' + ModelRequestNode( + request=ModelRequest( + parts=[ + UserPromptPart( + content='What is the capital of France?', + timestamp=datetime.datetime(...), + ) + ] + ) + ) + ''' + ''' + HandleResponseNode( + model_response=ModelResponse( + parts=[TextPart(content='Paris')], + model_name='function:model_logic', + timestamp=datetime.datetime(...), + ) + ) + ''' + #> End(data=MarkFinalResult(data='Paris', tool_name=None)) + print(agent_run.data) + #> Paris + ``` + Args: result_type: Custom result type to use for this run, `result_type` may only be used if the agent has no result validators since result validators would expect an argument that matches the agent's result type. @@ -374,7 +419,7 @@ def run_sync( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, - ) -> AgentRun[AgentDepsT, ResultDataT]: ... + ) -> AgentRun[AgentDepsT, RunResultDataT]: ... def run_sync( self, @@ -388,7 +433,7 @@ def run_sync( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, - ) -> AgentRun[AgentDepsT, ResultDataT]: + ) -> AgentRun[AgentDepsT, Any]: """Run the agent with a user prompt synchronously. This is a convenience method that wraps [`self.run`][pydantic_ai.Agent.run] with `loop.run_until_complete(...)`. diff --git a/pydantic_ai_slim/pydantic_ai/messages.py b/pydantic_ai_slim/pydantic_ai/messages.py index ab159abd9..165ed9912 100644 --- a/pydantic_ai_slim/pydantic_ai/messages.py +++ b/pydantic_ai_slim/pydantic_ai/messages.py @@ -28,7 +28,7 @@ class SystemPromptPart: Only set if system prompt is dynamic, see [`system_prompt`][pydantic_ai.Agent.system_prompt] for more information. """ - part_kind: Literal['system-prompt'] = 'system-prompt' + part_kind: Literal['system-prompt'] = field(default='system-prompt', init=False, repr=False) """Part type identifier, this is available on all parts as a discriminator.""" @@ -46,7 +46,7 @@ class UserPromptPart: timestamp: datetime = field(default_factory=_now_utc) """The timestamp of the prompt.""" - part_kind: Literal['user-prompt'] = 'user-prompt' + part_kind: Literal['user-prompt'] = field(default='user-prompt', init=False, repr=False) """Part type identifier, this is available on all parts as a discriminator.""" @@ -69,7 +69,7 @@ class ToolReturnPart: timestamp: datetime = field(default_factory=_now_utc) """The timestamp, when the tool returned.""" - part_kind: Literal['tool-return'] = 'tool-return' + part_kind: Literal['tool-return'] = field(default='tool-return', init=False, repr=False) """Part type identifier, this is available on all parts as a discriminator.""" def model_response_str(self) -> str: @@ -123,7 +123,7 @@ class RetryPromptPart: timestamp: datetime = field(default_factory=_now_utc) """The timestamp, when the retry was triggered.""" - part_kind: Literal['retry-prompt'] = 'retry-prompt' + part_kind: Literal['retry-prompt'] = field(default='retry-prompt', init=False, repr=False) """Part type identifier, this is available on all parts as a discriminator.""" def model_response(self) -> str: @@ -149,7 +149,7 @@ class ModelRequest: parts: list[ModelRequestPart] """The parts of the user message.""" - kind: Literal['request'] = 'request' + kind: Literal['request'] = field(default='request', init=False, repr=False) """Message type identifier, this is available on all parts as a discriminator.""" @@ -160,7 +160,7 @@ class TextPart: content: str """The text content of the response.""" - part_kind: Literal['text'] = 'text' + part_kind: Literal['text'] = field(default='text', init=False, repr=False) """Part type identifier, this is available on all parts as a discriminator.""" def has_content(self) -> bool: @@ -184,7 +184,7 @@ class ToolCallPart: tool_call_id: str | None = None """Optional tool call identifier, this is used by some models including OpenAI.""" - part_kind: Literal['tool-call'] = 'tool-call' + part_kind: Literal['tool-call'] = field(default='tool-call', init=False, repr=False) """Part type identifier, this is available on all parts as a discriminator.""" def args_as_dict(self) -> dict[str, Any]: @@ -237,7 +237,7 @@ class ModelResponse: If the model provides a timestamp in the response (as OpenAI does) that will be used. """ - kind: Literal['response'] = 'response' + kind: Literal['response'] = field(default='response', init=False, repr=False) """Message type identifier, this is available on all parts as a discriminator.""" @@ -255,7 +255,7 @@ class TextPartDelta: content_delta: str """The incremental text content to add to the existing `TextPart` content.""" - part_delta_kind: Literal['text'] = 'text' + part_delta_kind: Literal['text'] = field(default='text', init=False, repr=False) """Part delta type identifier, used as a discriminator.""" def apply(self, part: ModelResponsePart) -> TextPart: @@ -295,7 +295,7 @@ class ToolCallPartDelta: Note this is never treated as a delta — it can replace None, but otherwise if a non-matching value is provided an error will be raised.""" - part_delta_kind: Literal['tool_call'] = 'tool_call' + part_delta_kind: Literal['tool_call'] = field(default='tool_call', init=False, repr=False) """Part delta type identifier, used as a discriminator.""" def as_part(self) -> ToolCallPart | None: @@ -426,7 +426,7 @@ class PartStartEvent: part: ModelResponsePart """The newly started `ModelResponsePart`.""" - event_kind: Literal['part_start'] = 'part_start' + event_kind: Literal['part_start'] = field(default='part_start', init=False, repr=False) """Event type identifier, used as a discriminator.""" @@ -440,7 +440,7 @@ class PartDeltaEvent: delta: ModelResponsePartDelta """The delta to apply to the specified part.""" - event_kind: Literal['part_delta'] = 'part_delta' + event_kind: Literal['part_delta'] = field(default='part_delta', init=False, repr=False) """Event type identifier, used as a discriminator.""" diff --git a/pydantic_graph/pydantic_graph/graph.py b/pydantic_graph/pydantic_graph/graph.py index 4f6706340..3a27ca8b5 100644 --- a/pydantic_graph/pydantic_graph/graph.py +++ b/pydantic_graph/pydantic_graph/graph.py @@ -137,6 +137,9 @@ def run( ) -> GraphRun[StateT, DepsT, T]: """Run the graph from a starting node until it ends. + The returned GraphRun can be awaited (or used as an async iterator) to drive the graph to completion. + TODO: Need to add a more detailed message here explaining that this can behave like a coroutine or a context manager etc. + Args: start_node: the first node to run, since the graph definition doesn't define the entry point in the graph, you need to provide the starting node. @@ -511,6 +514,8 @@ class GraphRun(Generic[StateT, DepsT, RunEndT]): """A stateful run of a graph. After being entered, can be used like an async generator to listen to / modify nodes as the run is executed. + + TODO: this requires some heavy weight API documentation. """ def __init__( @@ -588,6 +593,11 @@ async def _run() -> typing_extensions.Self: return _run().__await__() def __enter__(self) -> typing_extensions.Self: + """Open a span for the graph run. + + Note that we _require_ that the graph run is used as a context manager when iterating over nodes + so that we can ensure that the span covers the time range during which the iteration happens. + """ if self._started: raise exceptions.GraphRuntimeError('A GraphRun can only be started once.') diff --git a/tests/test_agent.py b/tests/test_agent.py index 44de33e91..f5f87834c 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -1340,17 +1340,15 @@ async def func() -> str: [ ModelRequest( parts=[ - SystemPromptPart(content='Foobar', part_kind='system-prompt'), - SystemPromptPart(content=dynamic_value, part_kind='system-prompt'), - UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt'), + SystemPromptPart(content='Foobar'), + SystemPromptPart(content=dynamic_value), + UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)), ], - kind='request', ), ModelResponse( - parts=[TextPart(content='success (no tool calls)', part_kind='text')], + parts=[TextPart(content='success (no tool calls)')], model_name='test', timestamp=IsNow(tz=timezone.utc), - kind='response', ), ] ) @@ -1363,30 +1361,25 @@ async def func() -> str: [ ModelRequest( parts=[ - SystemPromptPart(content='Foobar', part_kind='system-prompt'), + SystemPromptPart(content='Foobar'), SystemPromptPart( content='A', # Remains the same - part_kind='system-prompt', ), - UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt'), + UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)), ], - kind='request', ), ModelResponse( - parts=[TextPart(content='success (no tool calls)', part_kind='text')], + parts=[TextPart(content='success (no tool calls)')], model_name='test', timestamp=IsNow(tz=timezone.utc), - kind='response', ), ModelRequest( - parts=[UserPromptPart(content='World', timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt')], - kind='request', + parts=[UserPromptPart(content='World', timestamp=IsNow(tz=timezone.utc))], ), ModelResponse( - parts=[TextPart(content='success (no tool calls)', part_kind='text')], + parts=[TextPart(content='success (no tool calls)')], model_name='test', timestamp=IsNow(tz=timezone.utc), - kind='response', ), ] ) @@ -1412,21 +1405,18 @@ async def func(): [ ModelRequest( parts=[ - SystemPromptPart(content='Foobar', part_kind='system-prompt'), + SystemPromptPart(content='Foobar'), SystemPromptPart( content=dynamic_value, - part_kind='system-prompt', dynamic_ref=func.__qualname__, ), - UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt'), + UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)), ], - kind='request', ), ModelResponse( - parts=[TextPart(content='success (no tool calls)', part_kind='text')], + parts=[TextPart(content='success (no tool calls)')], model_name='test', timestamp=IsNow(tz=timezone.utc), - kind='response', ), ] ) @@ -1439,31 +1429,26 @@ async def func(): [ ModelRequest( parts=[ - SystemPromptPart(content='Foobar', part_kind='system-prompt'), + SystemPromptPart(content='Foobar'), SystemPromptPart( content='B', - part_kind='system-prompt', dynamic_ref=func.__qualname__, ), - UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt'), + UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)), ], - kind='request', ), ModelResponse( - parts=[TextPart(content='success (no tool calls)', part_kind='text')], + parts=[TextPart(content='success (no tool calls)')], model_name='test', timestamp=IsNow(tz=timezone.utc), - kind='response', ), ModelRequest( - parts=[UserPromptPart(content='World', timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt')], - kind='request', + parts=[UserPromptPart(content='World', timestamp=IsNow(tz=timezone.utc))], ), ModelResponse( - parts=[TextPart(content='success (no tool calls)', part_kind='text')], + parts=[TextPart(content='success (no tool calls)')], model_name='test', timestamp=IsNow(tz=timezone.utc), - kind='response', ), ] ) diff --git a/tests/test_parts_manager.py b/tests/test_parts_manager.py index a93e7fb31..13cb81235 100644 --- a/tests/test_parts_manager.py +++ b/tests/test_parts_manager.py @@ -24,56 +24,32 @@ def test_handle_text_deltas(vendor_part_id: str | None): assert manager.get_parts() == [] event = manager.handle_text_delta(vendor_part_id=vendor_part_id, content='hello ') - assert event == snapshot( - PartStartEvent(index=0, part=TextPart(content='hello ', part_kind='text'), event_kind='part_start') - ) - assert manager.get_parts() == snapshot([TextPart(content='hello ', part_kind='text')]) + assert event == snapshot(PartStartEvent(index=0, part=TextPart(content='hello '))) + assert manager.get_parts() == snapshot([TextPart(content='hello ')]) event = manager.handle_text_delta(vendor_part_id=vendor_part_id, content='world') - assert event == snapshot( - PartDeltaEvent( - index=0, delta=TextPartDelta(content_delta='world', part_delta_kind='text'), event_kind='part_delta' - ) - ) - assert manager.get_parts() == snapshot([TextPart(content='hello world', part_kind='text')]) + assert event == snapshot(PartDeltaEvent(index=0, delta=TextPartDelta(content_delta='world'))) + assert manager.get_parts() == snapshot([TextPart(content='hello world')]) def test_handle_dovetailed_text_deltas(): manager = ModelResponsePartsManager() event = manager.handle_text_delta(vendor_part_id='first', content='hello ') - assert event == snapshot( - PartStartEvent(index=0, part=TextPart(content='hello ', part_kind='text'), event_kind='part_start') - ) - assert manager.get_parts() == snapshot([TextPart(content='hello ', part_kind='text')]) + assert event == snapshot(PartStartEvent(index=0, part=TextPart(content='hello '))) + assert manager.get_parts() == snapshot([TextPart(content='hello ')]) event = manager.handle_text_delta(vendor_part_id='second', content='goodbye ') - assert event == snapshot( - PartStartEvent(index=1, part=TextPart(content='goodbye ', part_kind='text'), event_kind='part_start') - ) - assert manager.get_parts() == snapshot( - [TextPart(content='hello ', part_kind='text'), TextPart(content='goodbye ', part_kind='text')] - ) + assert event == snapshot(PartStartEvent(index=1, part=TextPart(content='goodbye '))) + assert manager.get_parts() == snapshot([TextPart(content='hello '), TextPart(content='goodbye ')]) event = manager.handle_text_delta(vendor_part_id='first', content='world') - assert event == snapshot( - PartDeltaEvent( - index=0, delta=TextPartDelta(content_delta='world', part_delta_kind='text'), event_kind='part_delta' - ) - ) - assert manager.get_parts() == snapshot( - [TextPart(content='hello world', part_kind='text'), TextPart(content='goodbye ', part_kind='text')] - ) + assert event == snapshot(PartDeltaEvent(index=0, delta=TextPartDelta(content_delta='world'))) + assert manager.get_parts() == snapshot([TextPart(content='hello world'), TextPart(content='goodbye ')]) event = manager.handle_text_delta(vendor_part_id='second', content='Samuel') - assert event == snapshot( - PartDeltaEvent( - index=1, delta=TextPartDelta(content_delta='Samuel', part_delta_kind='text'), event_kind='part_delta' - ) - ) - assert manager.get_parts() == snapshot( - [TextPart(content='hello world', part_kind='text'), TextPart(content='goodbye Samuel', part_kind='text')] - ) + assert event == snapshot(PartDeltaEvent(index=1, delta=TextPartDelta(content_delta='Samuel'))) + assert manager.get_parts() == snapshot([TextPart(content='hello world'), TextPart(content='goodbye Samuel')]) def test_handle_tool_call_deltas(): @@ -89,36 +65,33 @@ def test_handle_tool_call_deltas(): assert event == snapshot( PartStartEvent( index=0, - part=ToolCallPart(tool_name='tool', args='{"arg1":', tool_call_id=None, part_kind='tool-call'), - event_kind='part_start', + part=ToolCallPart(tool_name='tool', args='{"arg1":', tool_call_id=None), ) ) - assert manager.get_parts() == snapshot( - [ToolCallPart(tool_name='tool', args='{"arg1":', tool_call_id=None, part_kind='tool-call')] - ) + assert manager.get_parts() == snapshot([ToolCallPart(tool_name='tool', args='{"arg1":', tool_call_id=None)]) event = manager.handle_tool_call_delta(vendor_part_id='first', tool_name='1', args=None, tool_call_id=None) assert event == snapshot( PartDeltaEvent( index=0, delta=ToolCallPartDelta( - tool_name_delta='1', args_delta=None, tool_call_id=None, part_delta_kind='tool_call' + tool_name_delta='1', + args_delta=None, + tool_call_id=None, ), - event_kind='part_delta', ) ) - assert manager.get_parts() == snapshot( - [ToolCallPart(tool_name='tool1', args='{"arg1":', tool_call_id=None, part_kind='tool-call')] - ) + assert manager.get_parts() == snapshot([ToolCallPart(tool_name='tool1', args='{"arg1":', tool_call_id=None)]) event = manager.handle_tool_call_delta(vendor_part_id='first', tool_name=None, args='"value1"}', tool_call_id=None) assert event == snapshot( PartDeltaEvent( index=0, delta=ToolCallPartDelta( - tool_name_delta=None, args_delta='"value1"}', tool_call_id=None, part_delta_kind='tool_call' + tool_name_delta=None, + args_delta='"value1"}', + tool_call_id=None, ), - event_kind='part_delta', ) ) assert manager.get_parts() == snapshot( @@ -127,7 +100,6 @@ def test_handle_tool_call_deltas(): tool_name='tool1', args='{"arg1":"value1"}', tool_call_id=None, - part_kind='tool-call', ) ] ) @@ -144,7 +116,6 @@ def test_handle_tool_call_deltas_without_vendor_id(): tool_name='tool1', args='{"arg1":"value1"}', tool_call_id=None, - part_kind='tool-call', ) ] ) @@ -155,8 +126,8 @@ def test_handle_tool_call_deltas_without_vendor_id(): manager.handle_tool_call_delta(vendor_part_id=None, tool_name='tool2', args='"value1"}', tool_call_id=None) assert manager.get_parts() == snapshot( [ - ToolCallPart(tool_name='tool2', args='{"arg1":', tool_call_id=None, part_kind='tool-call'), - ToolCallPart(tool_name='tool2', args='"value1"}', tool_call_id=None, part_kind='tool-call'), + ToolCallPart(tool_name='tool2', args='{"arg1":', tool_call_id=None), + ToolCallPart(tool_name='tool2', args='"value1"}', tool_call_id=None), ] ) @@ -169,23 +140,20 @@ def test_handle_tool_call_part(): assert event == snapshot( PartStartEvent( index=0, - part=ToolCallPart(tool_name='tool1', args='{"arg1":', tool_call_id=None, part_kind='tool-call'), - event_kind='part_start', + part=ToolCallPart(tool_name='tool1', args='{"arg1":', tool_call_id=None), ) ) # Add a delta manager.handle_tool_call_delta(vendor_part_id='second', tool_name='tool1', args=None, tool_call_id=None) - assert manager.get_parts() == snapshot( - [ToolCallPart(tool_name='tool1', args='{"arg1":', tool_call_id=None, part_kind='tool-call')] - ) + assert manager.get_parts() == snapshot([ToolCallPart(tool_name='tool1', args='{"arg1":', tool_call_id=None)]) # Override it with handle_tool_call_part manager.handle_tool_call_part(vendor_part_id='second', tool_name='tool1', args='{}', tool_call_id=None) assert manager.get_parts() == snapshot( [ - ToolCallPart(tool_name='tool1', args='{"arg1":', tool_call_id=None, part_kind='tool-call'), - ToolCallPart(tool_name='tool1', args='{}', tool_call_id=None, part_kind='tool-call'), + ToolCallPart(tool_name='tool1', args='{"arg1":', tool_call_id=None), + ToolCallPart(tool_name='tool1', args='{}', tool_call_id=None), ] ) @@ -194,9 +162,10 @@ def test_handle_tool_call_part(): PartDeltaEvent( index=0, delta=ToolCallPartDelta( - tool_name_delta=None, args_delta='"value1"}', tool_call_id=None, part_delta_kind='tool_call' + tool_name_delta=None, + args_delta='"value1"}', + tool_call_id=None, ), - event_kind='part_delta', ) ) assert manager.get_parts() == snapshot( @@ -205,9 +174,8 @@ def test_handle_tool_call_part(): tool_name='tool1', args='{"arg1":"value1"}', tool_call_id=None, - part_kind='tool-call', ), - ToolCallPart(tool_name='tool1', args='{}', tool_call_id=None, part_kind='tool-call'), + ToolCallPart(tool_name='tool1', args='{}', tool_call_id=None), ] ) @@ -216,8 +184,7 @@ def test_handle_tool_call_part(): assert event == snapshot( PartStartEvent( index=2, - part=ToolCallPart(tool_name='tool1', args='{}', tool_call_id=None, part_kind='tool-call'), - event_kind='part_start', + part=ToolCallPart(tool_name='tool1', args='{}', tool_call_id=None), ) ) assert manager.get_parts() == snapshot( @@ -226,10 +193,9 @@ def test_handle_tool_call_part(): tool_name='tool1', args='{"arg1":"value1"}', tool_call_id=None, - part_kind='tool-call', ), - ToolCallPart(tool_name='tool1', args='{}', tool_call_id=None, part_kind='tool-call'), - ToolCallPart(tool_name='tool1', args='{}', tool_call_id=None, part_kind='tool-call'), + ToolCallPart(tool_name='tool1', args='{}', tool_call_id=None), + ToolCallPart(tool_name='tool1', args='{}', tool_call_id=None), ] ) @@ -240,10 +206,8 @@ def test_handle_mixed_deltas_without_text_part_id(text_vendor_part_id: str | Non manager = ModelResponsePartsManager() event = manager.handle_text_delta(vendor_part_id=text_vendor_part_id, content='hello ') - assert event == snapshot( - PartStartEvent(index=0, part=TextPart(content='hello ', part_kind='text'), event_kind='part_start') - ) - assert manager.get_parts() == snapshot([TextPart(content='hello ', part_kind='text')]) + assert event == snapshot(PartStartEvent(index=0, part=TextPart(content='hello '))) + assert manager.get_parts() == snapshot([TextPart(content='hello ')]) event = manager.handle_tool_call_delta( vendor_part_id=tool_vendor_part_id, tool_name='tool1', args='{"arg1":', tool_call_id='abc' @@ -251,8 +215,7 @@ def test_handle_mixed_deltas_without_text_part_id(text_vendor_part_id: str | Non assert event == snapshot( PartStartEvent( index=1, - part=ToolCallPart(tool_name='tool1', args='{"arg1":', tool_call_id='abc', part_kind='tool-call'), - event_kind='part_start', + part=ToolCallPart(tool_name='tool1', args='{"arg1":', tool_call_id='abc'), ) ) @@ -261,27 +224,22 @@ def test_handle_mixed_deltas_without_text_part_id(text_vendor_part_id: str | Non assert event == snapshot( PartStartEvent( index=2, - part=TextPart(content='world', part_kind='text'), - event_kind='part_start', + part=TextPart(content='world'), ) ) assert manager.get_parts() == snapshot( [ - TextPart(content='hello ', part_kind='text'), - ToolCallPart(tool_name='tool1', args='{"arg1":', tool_call_id='abc', part_kind='tool-call'), - TextPart(content='world', part_kind='text'), + TextPart(content='hello '), + ToolCallPart(tool_name='tool1', args='{"arg1":', tool_call_id='abc'), + TextPart(content='world'), ] ) else: - assert event == snapshot( - PartDeltaEvent( - index=0, delta=TextPartDelta(content_delta='world', part_delta_kind='text'), event_kind='part_delta' - ) - ) + assert event == snapshot(PartDeltaEvent(index=0, delta=TextPartDelta(content_delta='world'))) assert manager.get_parts() == snapshot( [ - TextPart(content='hello world', part_kind='text'), - ToolCallPart(tool_name='tool1', args='{"arg1":', tool_call_id='abc', part_kind='tool-call'), + TextPart(content='hello world'), + ToolCallPart(tool_name='tool1', args='{"arg1":', tool_call_id='abc'), ] ) @@ -313,7 +271,6 @@ def test_tool_call_id_delta(): tool_name='tool1', args='{"arg1":', tool_call_id=None, - part_kind='tool-call', ) ] ) @@ -325,7 +282,6 @@ def test_tool_call_id_delta(): tool_name='tool1', args='{"arg1":"value1"}', tool_call_id='id2', - part_kind='tool-call', ) ] ) @@ -346,7 +302,6 @@ def test_tool_call_id_delta_failure(apply_to_delta: bool): tool_name='tool1', args='{"arg1":', tool_call_id='id1', - part_kind='tool-call', ) ] ) diff --git a/tests/test_streaming.py b/tests/test_streaming.py index f95be4c13..5d07bf4cf 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -592,10 +592,8 @@ def another_tool(y: int) -> int: # pragma: no cover UserPromptPart( content='test early strategy with final ' 'result in middle', timestamp=IsNow(tz=datetime.timezone.utc), - part_kind='user-prompt', ) ], - kind='request', ), ModelResponse( parts=[ @@ -603,30 +601,25 @@ def another_tool(y: int) -> int: # pragma: no cover tool_name='regular_tool', args='{"x": 1}', tool_call_id=None, - part_kind='tool-call', ), ToolCallPart( tool_name='final_result', args='{"value": "final"}', tool_call_id=None, - part_kind='tool-call', ), ToolCallPart( tool_name='another_tool', args='{"y": 2}', tool_call_id=None, - part_kind='tool-call', ), ToolCallPart( tool_name='unknown_tool', args='{"value": "???"}', tool_call_id=None, - part_kind='tool-call', ), ], model_name='function:sf', timestamp=IsNow(tz=datetime.timezone.utc), - kind='response', ), ModelRequest( parts=[ @@ -635,21 +628,18 @@ def another_tool(y: int) -> int: # pragma: no cover content='Tool not executed - a final ' 'result was already processed.', tool_call_id=None, timestamp=IsNow(tz=datetime.timezone.utc), - part_kind='tool-return', ), ToolReturnPart( tool_name='final_result', content='Final result processed.', tool_call_id=None, timestamp=IsNow(tz=datetime.timezone.utc), - part_kind='tool-return', ), ToolReturnPart( tool_name='another_tool', content='Tool not executed - a final ' 'result was already processed.', tool_call_id=None, timestamp=IsNow(tz=datetime.timezone.utc), - part_kind='tool-return', ), RetryPromptPart( content='Unknown tool name: ' @@ -659,10 +649,8 @@ def another_tool(y: int) -> int: # pragma: no cover tool_name=None, tool_call_id=None, timestamp=IsNow(tz=datetime.timezone.utc), - part_kind='retry-prompt', ), ], - kind='request', ), ] ) From db56e315ff9225af79ec327cea6d5c9cf4d683b9 Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Tue, 11 Feb 2025 18:00:01 -0700 Subject: [PATCH 08/28] Update some comments etc. --- pydantic_ai_slim/pydantic_ai/agent.py | 48 +++++++++++++++++---------- 1 file changed, 31 insertions(+), 17 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 23cc43b75..65aae0000 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -245,18 +245,26 @@ def run( usage: _usage.Usage | None = None, infer_name: bool = True, ) -> AgentRun[AgentDepsT, Any]: - """Run the agent with a user prompt in async mode. + """Run the agent with a user prompt. This method builds an internal agent graph (using system prompts, tools and result schemas) and then - returns an AgentRun object. The AgentRun functions as a handle that can be used to iterate over the graph and - obtain the final result. The AgentRun also provides methods to access the full message history, new messages, - and usage statistics. - - The AgentRun can be awaited to get the final result of the run, or entered as a context manager to - obtain an iterator over the graph nodes. You can even use the AgentRun as an async generator to override the - execution of the graph if desired. See the documentation of AgentRun for more details. - - Example: + returns an `AgentRun` object. The `AgentRun` functions as a handle that can be used to iterate over the graph + and obtain the final result. The AgentRun also provides methods to access the full message history, + new messages, and usage statistics. + + The returned `AgentRun` object should always be immediately used in one of two ways: + * Via `await` (i.e., `await agent.run(...)`), which will execute the graph run and return the final result + * This is the API you should use if you just want the end result and are not interested in the execution details + or consuming streaming updates + * As a context manager (i.e., `with agent.run(...) as agent_run:`), which will return an async iterator over + the graph nodes, and which can also be used as an async generator to override the execution of the graph. + * This is the API you should use if you want to consume the graph execution in a streaming manner, + or if you want to consume the stream of events coming from individual requests to the LLM, or the stream + of events coming from the execution of tools. + + For more details, see the documentation of `AgentRun`. + + Example (with `await`): ```python from pydantic_ai import Agent @@ -268,7 +276,7 @@ async def main(): #> Paris ``` - Example: + Example (with a context manager): ```python from pydantic_ai import Agent @@ -381,7 +389,6 @@ async def main(): system_prompt_dynamic_functions=self._system_prompt_dynamic_functions, ) - # Actually run return AgentRun( graph.run( start_node, @@ -577,7 +584,7 @@ async def main(): while True: if isinstance(node, _agent_graph.ModelRequestNode): node = cast(_agent_graph.ModelRequestNode[AgentDepsT, Any], node) - graph_ctx = agent_run.graph_ctx() + graph_ctx = agent_run.get_graph_ctx() async with node.stream(graph_ctx) as streamed_response: async def stream_to_final( @@ -1201,13 +1208,15 @@ async def next( ] | End[MarkFinalResult[ResultDataT]] ): + # TODO: It would be nice to expose a synchronous interface for this, to be able to + # synchronously iterate over the agent graph. I don't think this would be hard to do, + # but I'm having a hard time coming up with an API that fits nicely along side the current `run_sync`. + # The use of `await` provides an easy way to signal that you just want the result, but it's less + # clear to me what the analogous thing should be for synchronous code. return await self.graph_run.next(node) - def graph_ctx(self) -> GraphRunContext[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any]]: - return GraphRunContext(self.graph_run.state, self.graph_run.deps) - def __await__(self) -> Generator[Any, Any, Self]: - """Run the agent graph until it ends, and return the final result.""" + """Run the agent graph until it ends, and return self.""" async def _run(): await self.graph_run @@ -1246,3 +1255,8 @@ async def __anext__( ): """Use the last returned node as the input to `Graph.next`.""" return await self.graph_run.__anext__() + + def get_graph_ctx( + self, + ) -> GraphRunContext[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any]]: + return GraphRunContext(self.graph_run.state, self.graph_run.deps) From 9af98e8b7b16d26613fa59c1df0b3e76cfab9ea0 Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Wed, 12 Feb 2025 12:02:51 -0700 Subject: [PATCH 09/28] Undo kind changes --- docs/agents.md | 20 ++++-- docs/api/models/function.md | 4 +- docs/message-history.md | 78 ++++++++++++++++++------ docs/tools.md | 34 +++++++++-- pydantic_ai_slim/pydantic_ai/agent.py | 51 +++++++++------- pydantic_ai_slim/pydantic_ai/messages.py | 28 ++++----- 6 files changed, 149 insertions(+), 66 deletions(-) diff --git a/docs/agents.md b/docs/agents.md index 39c54b327..19da9f3a2 100644 --- a/docs/agents.md +++ b/docs/agents.md @@ -472,17 +472,23 @@ with capture_run_messages() as messages: # (2)! UserPromptPart( content='Please get me the volume of a box with size 6.', timestamp=datetime.datetime(...), + part_kind='user-prompt', ) - ] + ], + kind='request', ), ModelResponse( parts=[ ToolCallPart( - tool_name='calc_volume', args={'size': 6}, tool_call_id=None + tool_name='calc_volume', + args={'size': 6}, + tool_call_id=None, + part_kind='tool-call', ) ], model_name='function:model_logic', timestamp=datetime.datetime(...), + kind='response', ), ModelRequest( parts=[ @@ -491,17 +497,23 @@ with capture_run_messages() as messages: # (2)! tool_name='calc_volume', tool_call_id=None, timestamp=datetime.datetime(...), + part_kind='retry-prompt', ) - ] + ], + kind='request', ), ModelResponse( parts=[ ToolCallPart( - tool_name='calc_volume', args={'size': 6}, tool_call_id=None + tool_name='calc_volume', + args={'size': 6}, + tool_call_id=None, + part_kind='tool-call', ) ], model_name='function:model_logic', timestamp=datetime.datetime(...), + kind='response', ), ] """ diff --git a/docs/api/models/function.md b/docs/api/models/function.md index 6049a1122..d24c87c18 100644 --- a/docs/api/models/function.md +++ b/docs/api/models/function.md @@ -28,8 +28,10 @@ async def model_function( UserPromptPart( content='Testing my agent...', timestamp=datetime.datetime(...), + part_kind='user-prompt', ) - ] + ], + kind='request', ) ] """ diff --git a/docs/message-history.md b/docs/message-history.md index fe94481c3..d538112f8 100644 --- a/docs/message-history.md +++ b/docs/message-history.md @@ -42,21 +42,29 @@ print(result.all_messages()) [ ModelRequest( parts=[ - SystemPromptPart(content='Be a helpful assistant.', dynamic_ref=None), + SystemPromptPart( + content='Be a helpful assistant.', + dynamic_ref=None, + part_kind='system-prompt', + ), UserPromptPart( content='Tell me a joke.', timestamp=datetime.datetime(...), + part_kind='user-prompt', ), - ] + ], + kind='request', ), ModelResponse( parts=[ TextPart( - content='Did you hear about the toothpaste scandal? They called it Colgate.' + content='Did you hear about the toothpaste scandal? They called it Colgate.', + part_kind='text', ) ], model_name='function:model_logic', timestamp=datetime.datetime(...), + kind='response', ), ] """ @@ -80,13 +88,17 @@ async def main(): ModelRequest( parts=[ SystemPromptPart( - content='Be a helpful assistant.', dynamic_ref=None + content='Be a helpful assistant.', + dynamic_ref=None, + part_kind='system-prompt', ), UserPromptPart( content='Tell me a joke.', timestamp=datetime.datetime(...), + part_kind='user-prompt', ), - ] + ], + kind='request', ) ] """ @@ -105,22 +117,28 @@ async def main(): ModelRequest( parts=[ SystemPromptPart( - content='Be a helpful assistant.', dynamic_ref=None + content='Be a helpful assistant.', + dynamic_ref=None, + part_kind='system-prompt', ), UserPromptPart( content='Tell me a joke.', timestamp=datetime.datetime(...), + part_kind='user-prompt', ), - ] + ], + kind='request', ), ModelResponse( parts=[ TextPart( - content='Did you hear about the toothpaste scandal? They called it Colgate.' + content='Did you hear about the toothpaste scandal? They called it Colgate.', + part_kind='text', ) ], model_name='function:stream_model_logic', timestamp=datetime.datetime(...), + kind='response', ), ] """ @@ -155,38 +173,50 @@ print(result2.all_messages()) [ ModelRequest( parts=[ - SystemPromptPart(content='Be a helpful assistant.', dynamic_ref=None), + SystemPromptPart( + content='Be a helpful assistant.', + dynamic_ref=None, + part_kind='system-prompt', + ), UserPromptPart( content='Tell me a joke.', timestamp=datetime.datetime(...), + part_kind='user-prompt', ), - ] + ], + kind='request', ), ModelResponse( parts=[ TextPart( - content='Did you hear about the toothpaste scandal? They called it Colgate.' + content='Did you hear about the toothpaste scandal? They called it Colgate.', + part_kind='text', ) ], model_name='function:model_logic', timestamp=datetime.datetime(...), + kind='response', ), ModelRequest( parts=[ UserPromptPart( content='Explain?', timestamp=datetime.datetime(...), + part_kind='user-prompt', ) - ] + ], + kind='request', ), ModelResponse( parts=[ TextPart( - content='This is an excellent joke invented by Samuel Colvin, it needs no explanation.' + content='This is an excellent joke invented by Samuel Colvin, it needs no explanation.', + part_kind='text', ) ], model_name='function:model_logic', timestamp=datetime.datetime(...), + kind='response', ), ] """ @@ -223,38 +253,50 @@ print(result2.all_messages()) [ ModelRequest( parts=[ - SystemPromptPart(content='Be a helpful assistant.', dynamic_ref=None), + SystemPromptPart( + content='Be a helpful assistant.', + dynamic_ref=None, + part_kind='system-prompt', + ), UserPromptPart( content='Tell me a joke.', timestamp=datetime.datetime(...), + part_kind='user-prompt', ), - ] + ], + kind='request', ), ModelResponse( parts=[ TextPart( - content='Did you hear about the toothpaste scandal? They called it Colgate.' + content='Did you hear about the toothpaste scandal? They called it Colgate.', + part_kind='text', ) ], model_name='function:model_logic', timestamp=datetime.datetime(...), + kind='response', ), ModelRequest( parts=[ UserPromptPart( content='Explain?', timestamp=datetime.datetime(...), + part_kind='user-prompt', ) - ] + ], + kind='request', ), ModelResponse( parts=[ TextPart( - content='This is an excellent joke invented by Samuel Colvin, it needs no explanation.' + content='This is an excellent joke invented by Samuel Colvin, it needs no explanation.', + part_kind='text', ) ], model_name='function:model_logic', timestamp=datetime.datetime(...), + kind='response', ), ] """ diff --git a/docs/tools.md b/docs/tools.md index 0468c125e..2de55701c 100644 --- a/docs/tools.md +++ b/docs/tools.md @@ -73,17 +73,25 @@ print(dice_result.all_messages()) SystemPromptPart( content="You're a dice game, you should roll the die and see if the number you get back matches the user's guess. If so, tell them they're a winner. Use the player's name in the response.", dynamic_ref=None, + part_kind='system-prompt', ), UserPromptPart( content='My guess is 4', timestamp=datetime.datetime(...), + part_kind='user-prompt', ), - ] + ], + kind='request', ), ModelResponse( - parts=[ToolCallPart(tool_name='roll_die', args={}, tool_call_id=None)], + parts=[ + ToolCallPart( + tool_name='roll_die', args={}, tool_call_id=None, part_kind='tool-call' + ) + ], model_name='function:model_logic', timestamp=datetime.datetime(...), + kind='response', ), ModelRequest( parts=[ @@ -92,13 +100,23 @@ print(dice_result.all_messages()) content='4', tool_call_id=None, timestamp=datetime.datetime(...), + part_kind='tool-return', ) - ] + ], + kind='request', ), ModelResponse( - parts=[ToolCallPart(tool_name='get_player_name', args={}, tool_call_id=None)], + parts=[ + ToolCallPart( + tool_name='get_player_name', + args={}, + tool_call_id=None, + part_kind='tool-call', + ) + ], model_name='function:model_logic', timestamp=datetime.datetime(...), + kind='response', ), ModelRequest( parts=[ @@ -107,17 +125,21 @@ print(dice_result.all_messages()) content='Anne', tool_call_id=None, timestamp=datetime.datetime(...), + part_kind='tool-return', ) - ] + ], + kind='request', ), ModelResponse( parts=[ TextPart( - content="Congratulations Anne, you guessed correctly! You're a winner!" + content="Congratulations Anne, you guessed correctly! You're a winner!", + part_kind='text', ) ], model_name='function:model_logic', timestamp=datetime.datetime(...), + kind='response', ), ] """ diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 65aae0000..661b3dfb1 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -283,33 +283,38 @@ async def main(): agent = Agent('openai:gpt-4o') async def main(): + nodes = [] with agent.run('What is the capital of France?') as agent_run: async for node in agent_run: - print(node) - ''' - ModelRequestNode( - request=ModelRequest( - parts=[ - UserPromptPart( - content='What is the capital of France?', - timestamp=datetime.datetime(...), - ) - ] - ) + nodes.append(node) + print(nodes) + ''' + [ + ModelRequestNode( + request=ModelRequest( + parts=[ + UserPromptPart( + content='What is the capital of France?', + timestamp=datetime.datetime(...), + part_kind='user-prompt', + ) + ], + kind='request', ) - ''' - ''' - HandleResponseNode( - model_response=ModelResponse( - parts=[TextPart(content='Paris')], - model_name='function:model_logic', - timestamp=datetime.datetime(...), - ) + ), + HandleResponseNode( + model_response=ModelResponse( + parts=[TextPart(content='Paris', part_kind='text')], + model_name='function:model_logic', + timestamp=datetime.datetime(...), + kind='response', ) - ''' - #> End(data=MarkFinalResult(data='Paris', tool_name=None)) - print(agent_run.data) - #> Paris + ), + End(data=MarkFinalResult(data='Paris', tool_name=None)), + ] + ''' + print(agent_run.data) + #> Paris ``` Args: diff --git a/pydantic_ai_slim/pydantic_ai/messages.py b/pydantic_ai_slim/pydantic_ai/messages.py index 165ed9912..c6775c838 100644 --- a/pydantic_ai_slim/pydantic_ai/messages.py +++ b/pydantic_ai_slim/pydantic_ai/messages.py @@ -28,7 +28,7 @@ class SystemPromptPart: Only set if system prompt is dynamic, see [`system_prompt`][pydantic_ai.Agent.system_prompt] for more information. """ - part_kind: Literal['system-prompt'] = field(default='system-prompt', init=False, repr=False) + part_kind: Literal['system-prompt'] = 'system-prompt' """Part type identifier, this is available on all parts as a discriminator.""" @@ -46,7 +46,7 @@ class UserPromptPart: timestamp: datetime = field(default_factory=_now_utc) """The timestamp of the prompt.""" - part_kind: Literal['user-prompt'] = field(default='user-prompt', init=False, repr=False) + part_kind: Literal['user-prompt'] = 'user-prompt' """Part type identifier, this is available on all parts as a discriminator.""" @@ -69,7 +69,7 @@ class ToolReturnPart: timestamp: datetime = field(default_factory=_now_utc) """The timestamp, when the tool returned.""" - part_kind: Literal['tool-return'] = field(default='tool-return', init=False, repr=False) + part_kind: Literal['tool-return'] = 'tool-return' """Part type identifier, this is available on all parts as a discriminator.""" def model_response_str(self) -> str: @@ -123,7 +123,7 @@ class RetryPromptPart: timestamp: datetime = field(default_factory=_now_utc) """The timestamp, when the retry was triggered.""" - part_kind: Literal['retry-prompt'] = field(default='retry-prompt', init=False, repr=False) + part_kind: Literal['retry-prompt'] = 'retry-prompt' """Part type identifier, this is available on all parts as a discriminator.""" def model_response(self) -> str: @@ -149,7 +149,7 @@ class ModelRequest: parts: list[ModelRequestPart] """The parts of the user message.""" - kind: Literal['request'] = field(default='request', init=False, repr=False) + kind: Literal['request'] = 'request' """Message type identifier, this is available on all parts as a discriminator.""" @@ -160,7 +160,7 @@ class TextPart: content: str """The text content of the response.""" - part_kind: Literal['text'] = field(default='text', init=False, repr=False) + part_kind: Literal['text'] = 'text' """Part type identifier, this is available on all parts as a discriminator.""" def has_content(self) -> bool: @@ -184,7 +184,7 @@ class ToolCallPart: tool_call_id: str | None = None """Optional tool call identifier, this is used by some models including OpenAI.""" - part_kind: Literal['tool-call'] = field(default='tool-call', init=False, repr=False) + part_kind: Literal['tool-call'] = 'tool-call' """Part type identifier, this is available on all parts as a discriminator.""" def args_as_dict(self) -> dict[str, Any]: @@ -237,7 +237,7 @@ class ModelResponse: If the model provides a timestamp in the response (as OpenAI does) that will be used. """ - kind: Literal['response'] = field(default='response', init=False, repr=False) + kind: Literal['response'] = 'response' """Message type identifier, this is available on all parts as a discriminator.""" @@ -255,7 +255,7 @@ class TextPartDelta: content_delta: str """The incremental text content to add to the existing `TextPart` content.""" - part_delta_kind: Literal['text'] = field(default='text', init=False, repr=False) + part_delta_kind: Literal['text'] = 'text' """Part delta type identifier, used as a discriminator.""" def apply(self, part: ModelResponsePart) -> TextPart: @@ -295,7 +295,7 @@ class ToolCallPartDelta: Note this is never treated as a delta — it can replace None, but otherwise if a non-matching value is provided an error will be raised.""" - part_delta_kind: Literal['tool_call'] = field(default='tool_call', init=False, repr=False) + part_delta_kind: Literal['tool_call'] = 'tool_call' """Part delta type identifier, used as a discriminator.""" def as_part(self) -> ToolCallPart | None: @@ -426,7 +426,7 @@ class PartStartEvent: part: ModelResponsePart """The newly started `ModelResponsePart`.""" - event_kind: Literal['part_start'] = field(default='part_start', init=False, repr=False) + event_kind: Literal['part_start'] = 'part_start' """Event type identifier, used as a discriminator.""" @@ -440,7 +440,7 @@ class PartDeltaEvent: delta: ModelResponsePartDelta """The delta to apply to the specified part.""" - event_kind: Literal['part_delta'] = field(default='part_delta', init=False, repr=False) + event_kind: Literal['part_delta'] = 'part_delta' """Event type identifier, used as a discriminator.""" @@ -456,7 +456,7 @@ class FunctionToolCallEvent: """The (function) tool call to make.""" call_id: str = field(init=False) """An ID used for matching details about the call to its result. If present, defaults to the part's tool_call_id.""" - event_kind: Literal['function_tool_call'] = field(default='function_tool_call', init=False, repr=False) + event_kind: Literal['function_tool_call'] = 'function_tool_call' """Event type identifier, used as a discriminator.""" def __post_init__(self): @@ -471,7 +471,7 @@ class FunctionToolResultEvent: """The result of the call to the function tool.""" call_id: str """An ID used to match the result to its original call.""" - event_kind: Literal['function_tool_result'] = field(default='function_tool_result', init=False, repr=False) + event_kind: Literal['function_tool_result'] = 'function_tool_result' """Event type identifier, used as a discriminator.""" From 78e85d6542fd06b2a879bac717f122014d20ef86 Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Wed, 12 Feb 2025 12:51:54 -0700 Subject: [PATCH 10/28] Introduce auxiliary types --- pydantic_ai_slim/pydantic_ai/agent.py | 252 ++++++++++++---------- pydantic_graph/pydantic_graph/__init__.py | 4 +- pydantic_graph/pydantic_graph/graph.py | 177 ++++++++------- tests/test_agent.py | 4 +- tests/typed_agent.py | 6 +- 5 files changed, 251 insertions(+), 192 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 661b3dfb1..dd0dabbd0 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -10,10 +10,10 @@ from typing import Any, Callable, Generic, cast, final, overload import logfire_api -from typing_extensions import Self, TypeVar, deprecated +from typing_extensions import TypeVar, deprecated -from pydantic_graph import BaseNode, Graph, GraphRun, GraphRunContext -from pydantic_graph.nodes import End +from pydantic_graph import BaseNode, End, Graph, GraphRun, GraphRunContext, GraphRunner +from pydantic_graph.graph import GraphRunResult from . import ( _agent_graph, @@ -215,7 +215,7 @@ def run( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, - ) -> AgentRun[AgentDepsT, ResultDataT]: ... + ) -> AgentRunner[AgentDepsT, ResultDataT]: ... @overload def run( @@ -230,7 +230,7 @@ def run( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, - ) -> AgentRun[AgentDepsT, RunResultDataT]: ... + ) -> AgentRunner[AgentDepsT, RunResultDataT]: ... def run( self, @@ -244,7 +244,7 @@ def run( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, - ) -> AgentRun[AgentDepsT, Any]: + ) -> AgentRunner[AgentDepsT, Any]: """Run the agent with a user prompt. This method builds an internal agent graph (using system prompts, tools and result schemas) and then @@ -313,7 +313,7 @@ async def main(): End(data=MarkFinalResult(data='Paris', tool_name=None)), ] ''' - print(agent_run.data) + print(agent_run.final_result.data) #> Paris ``` @@ -394,7 +394,7 @@ async def main(): system_prompt_dynamic_functions=self._system_prompt_dynamic_functions, ) - return AgentRun( + return AgentRunner( graph.run( start_node, state=state, @@ -416,7 +416,7 @@ def run_sync( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, - ) -> AgentRun[AgentDepsT, ResultDataT]: ... + ) -> AgentRunResult[AgentDepsT, ResultDataT]: ... @overload def run_sync( @@ -431,7 +431,7 @@ def run_sync( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, - ) -> AgentRun[AgentDepsT, RunResultDataT]: ... + ) -> AgentRunResult[AgentDepsT, RunResultDataT]: ... def run_sync( self, @@ -445,7 +445,7 @@ def run_sync( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, - ) -> AgentRun[AgentDepsT, Any]: + ) -> AgentRunResult[AgentDepsT, Any]: """Run the agent with a user prompt synchronously. This is a convenience method that wraps [`self.run`][pydantic_ai.Agent.run] with `loop.run_until_complete(...)`. @@ -1113,17 +1113,140 @@ def _prepare_result_schema( return self._result_schema # pyright: ignore[reportReturnType] +@dataclasses.dataclass +class AgentRunner(Generic[AgentDepsT, ResultDataT]): + graph_runner: GraphRunner[ + _agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], MarkFinalResult[ResultDataT] + ] + + def __await__(self) -> Generator[Any, Any, AgentRunResult[AgentDepsT, ResultDataT]]: + """Run the agent graph until it ends, and return self.""" + + async def _run(): + graph_run_result = await self.graph_runner + return AgentRunResult(graph_run_result) + + return _run().__await__() + + def __enter__(self) -> AgentRun[AgentDepsT, ResultDataT]: + graph_run = self.graph_runner.__enter__() + return AgentRun(graph_run) + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + self.graph_runner.__exit__(exc_type, exc_val, exc_tb) + + @dataclasses.dataclass class AgentRun(Generic[AgentDepsT, ResultDataT]): - graph_run: GraphRun[ + """A stateful, iterable run of an agent.""" + + _graph_run: GraphRun[ _agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], MarkFinalResult[ResultDataT] ] + @property + def final_result(self) -> AgentRunResult[AgentDepsT, ResultDataT]: + graph_run_result = self._graph_run.final_result + return AgentRunResult(graph_run_result) + + def __aiter__( + self, + ) -> AsyncIterator[ + BaseNode[ + _agent_graph.GraphAgentState, + _agent_graph.GraphAgentDeps[AgentDepsT, Any], + MarkFinalResult[ResultDataT], + ] + | End[MarkFinalResult[ResultDataT]] + ]: + return self + + async def __anext__( + self, + ) -> ( + BaseNode[ + _agent_graph.GraphAgentState, + _agent_graph.GraphAgentDeps[AgentDepsT, Any], + MarkFinalResult[ResultDataT], + ] + | End[MarkFinalResult[ResultDataT]] + ): + """Use the last returned node as the input to `Graph.next`.""" + return await self._graph_run.__anext__() + + async def next( + self, + node: BaseNode[ + _agent_graph.GraphAgentState, + _agent_graph.GraphAgentDeps[AgentDepsT, Any], + MarkFinalResult[ResultDataT], + ], + ) -> ( + BaseNode[ + _agent_graph.GraphAgentState, + _agent_graph.GraphAgentDeps[AgentDepsT, Any], + MarkFinalResult[ResultDataT], + ] + | End[MarkFinalResult[ResultDataT]] + ): + # TODO: It would be nice to expose a synchronous interface for this, to be able to + # synchronously iterate over the agent graph. I don't think this would be hard to do, + # but I'm having a hard time coming up with an API that fits nicely along side the current `run_sync`. + # The use of `await` provides an easy way to signal that you just want the result, but it's less + # clear to me what the analogous thing should be for synchronous code. + return await self._graph_run.next(node) + + def usage(self) -> _usage.Usage: + return self._graph_run.state.usage + + def get_graph_ctx( + self, + ) -> GraphRunContext[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any]]: + return GraphRunContext[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any]]( + self._graph_run.state, self._graph_run.deps + ) + + +@dataclasses.dataclass +class AgentRunResult(Generic[AgentDepsT, ResultDataT]): + """The final result of an agent run.""" + + graph_run_result: GraphRunResult[ + _agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], MarkFinalResult[ResultDataT] + ] + + @property + def result(self) -> MarkFinalResult[ResultDataT]: + return self.graph_run_result.result + + @property + def data(self) -> ResultDataT: + return self.result.data + + @property + def _result_tool_name(self) -> str | None: + return self.result.tool_name + + def _set_result_tool_return(self, return_content: str) -> list[_messages.ModelMessage]: + """Set return content for the result tool. + + Useful if you want to continue the conversation and want to set the response to the result tool call. + """ + if not self.result.tool_name: + raise ValueError('Cannot set result tool return content when the return type is `str`.') + messages = deepcopy(self.graph_run_result.state.message_history) + last_message = messages[-1] + for part in last_message.parts: + if isinstance(part, _messages.ToolReturnPart) and part.tool_name == self.result.tool_name: + part.content = return_content + return messages + raise LookupError(f'No tool call found with tool name {self.result.tool_name!r}.') + def all_messages(self, *, result_tool_return_content: str | None = None) -> list[_messages.ModelMessage]: if result_tool_return_content is not None: return self._set_result_tool_return(result_tool_return_content) else: - return self.graph_run.state.message_history + return self.graph_run_result.state.message_history def all_messages_json(self, *, result_tool_return_content: str | None = None) -> bytes: """Return all messages from [`all_messages`][pydantic_ai.result._BaseRunResult.all_messages] as JSON bytes. @@ -1143,7 +1266,7 @@ def all_messages_json(self, *, result_tool_return_content: str | None = None) -> @property def _new_message_index(self) -> int: - return self.graph_run.deps.new_message_index + return self.graph_run_result.deps.new_message_index def new_messages(self, *, result_tool_return_content: str | None = None) -> list[_messages.ModelMessage]: return self.all_messages(result_tool_return_content=result_tool_return_content)[self._new_message_index :] @@ -1165,103 +1288,4 @@ def new_messages_json(self, *, result_tool_return_content: str | None = None) -> ) def usage(self) -> _usage.Usage: - return self.graph_run.state.usage - - def _set_result_tool_return(self, return_content: str) -> list[_messages.ModelMessage]: - """Set return content for the result tool. - - Useful if you want to continue the conversation and want to set the response to the result tool call. - """ - if not self.result.tool_name: - raise ValueError('Cannot set result tool return content when the return type is `str`.') - messages = deepcopy(self.graph_run.state.message_history) - last_message = messages[-1] - for part in last_message.parts: - if isinstance(part, _messages.ToolReturnPart) and part.tool_name == self.result.tool_name: - part.content = return_content - return messages - raise LookupError(f'No tool call found with tool name {self.result.tool_name!r}.') - - @property - def is_ended(self) -> bool: - return self.graph_run.is_ended - - @property - def result(self) -> MarkFinalResult[ResultDataT]: - return self.graph_run.result - - @property - def _result_tool_name(self) -> str | None: - return self.graph_run.result.tool_name - - @property - def data(self) -> ResultDataT: - return self.result.data - - async def next( - self, - node: BaseNode[ - _agent_graph.GraphAgentState, - _agent_graph.GraphAgentDeps[AgentDepsT, Any], - MarkFinalResult[ResultDataT], - ], - ) -> ( - BaseNode[ - _agent_graph.GraphAgentState, - _agent_graph.GraphAgentDeps[AgentDepsT, Any], - MarkFinalResult[ResultDataT], - ] - | End[MarkFinalResult[ResultDataT]] - ): - # TODO: It would be nice to expose a synchronous interface for this, to be able to - # synchronously iterate over the agent graph. I don't think this would be hard to do, - # but I'm having a hard time coming up with an API that fits nicely along side the current `run_sync`. - # The use of `await` provides an easy way to signal that you just want the result, but it's less - # clear to me what the analogous thing should be for synchronous code. - return await self.graph_run.next(node) - - def __await__(self) -> Generator[Any, Any, Self]: - """Run the agent graph until it ends, and return self.""" - - async def _run(): - await self.graph_run - return self - - return _run().__await__() - - def __enter__(self) -> Self: - self.graph_run.__enter__() - return self - - def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: - self.graph_run.__exit__(exc_type, exc_val, exc_tb) - - def __aiter__( - self, - ) -> AsyncIterator[ - BaseNode[ - _agent_graph.GraphAgentState, - _agent_graph.GraphAgentDeps[AgentDepsT, Any], - MarkFinalResult[ResultDataT], - ] - | End[MarkFinalResult[ResultDataT]] - ]: - return self - - async def __anext__( - self, - ) -> ( - BaseNode[ - _agent_graph.GraphAgentState, - _agent_graph.GraphAgentDeps[AgentDepsT, Any], - MarkFinalResult[ResultDataT], - ] - | End[MarkFinalResult[ResultDataT]] - ): - """Use the last returned node as the input to `Graph.next`.""" - return await self.graph_run.__anext__() - - def get_graph_ctx( - self, - ) -> GraphRunContext[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any]]: - return GraphRunContext(self.graph_run.state, self.graph_run.deps) + return self.graph_run_result.state.usage diff --git a/pydantic_graph/pydantic_graph/__init__.py b/pydantic_graph/pydantic_graph/__init__.py index f5f2a01c0..a0fe8a0fb 100644 --- a/pydantic_graph/pydantic_graph/__init__.py +++ b/pydantic_graph/pydantic_graph/__init__.py @@ -1,11 +1,13 @@ from .exceptions import GraphRuntimeError, GraphSetupError -from .graph import Graph, GraphRun +from .graph import Graph, GraphRun, GraphRunner, GraphRunResult from .nodes import BaseNode, Edge, End, GraphRunContext from .state import EndStep, HistoryStep, NodeStep __all__ = ( 'Graph', 'GraphRun', + 'GraphRunner', + 'GraphRunResult', 'BaseNode', 'End', 'GraphRunContext', diff --git a/pydantic_graph/pydantic_graph/graph.py b/pydantic_graph/pydantic_graph/graph.py index 3a27ca8b5..3fce05723 100644 --- a/pydantic_graph/pydantic_graph/graph.py +++ b/pydantic_graph/pydantic_graph/graph.py @@ -30,7 +30,7 @@ logfire._internal.stack_info.NON_USER_CODE_PREFIXES += (str(Path(__file__).parent.absolute()),) -__all__ = ('Graph', 'GraphRun') +__all__ = ('Graph', 'GraphRun', 'GraphRunner', 'GraphRunResult') _logfire = logfire_api.Logfire(otel_scope='pydantic-graph') @@ -134,7 +134,7 @@ def run( deps: DepsT = None, infer_name: bool = True, span: LogfireSpan | None = None, - ) -> GraphRun[StateT, DepsT, T]: + ) -> GraphRunner[StateT, DepsT, T]: """Run the graph from a starting node until it ends. The returned GraphRun can be awaited (or used as an async iterator) to drive the graph to completion. @@ -175,7 +175,7 @@ async def main(): if infer_name and self.name is None: self._infer_name(inspect.currentframe()) - return GraphRun[StateT, DepsT, T]( + return GraphRunner[StateT, DepsT, T]( self, start_node, history=[], @@ -192,7 +192,7 @@ def run_sync( state: StateT = None, deps: DepsT = None, infer_name: bool = True, - ) -> GraphRun[StateT, DepsT, T]: + ) -> GraphRunResult[StateT, DepsT, T]: """Run the graph synchronously. This is a convenience method that wraps [`self.run`][pydantic_graph.Graph.run] with `loop.run_until_complete(...)`. @@ -510,18 +510,13 @@ def _infer_name(self, function_frame: types.FrameType | None) -> None: return -class GraphRun(Generic[StateT, DepsT, RunEndT]): - """A stateful run of a graph. - - After being entered, can be used like an async generator to listen to / modify nodes as the run is executed. - - TODO: this requires some heavy weight API documentation. - """ +class GraphRunner(Generic[StateT, DepsT, RunEndT]): + """This object _MUST_ either be awaited or entered as a contextmanager to get a GraphRun.""" def __init__( self, graph: Graph[StateT, DepsT, RunEndT], - first_node: BaseNode[StateT, DepsT, RunEndT], + start_node: BaseNode[StateT, DepsT, RunEndT], *, history: list[HistoryStep[StateT, RunEndT]], state: StateT, @@ -529,97 +524,135 @@ def __init__( auto_instrument: bool, span: LogfireSpan | None = None, ): - self.graph = graph - self.history = history - self.state = state - self.deps = deps + self._graph = graph + self._start_node = start_node + self._history = history + self._state = state + self._deps = deps self._auto_instrument = auto_instrument self._span = span - self._next_node = first_node - self._started: bool = False - self._result: End[RunEndT] | None = None - - @property - def is_ended(self) -> bool: - return self._result is not None - - @property - def result(self) -> RunEndT: - if self._result is None: - if self._started: - raise exceptions.GraphRuntimeError( - 'This GraphRun has not yet ended. Continue iterating with `async for` or `GraphRun.next`' - ' to complete the run before accessing the result.' - ) - else: - raise exceptions.GraphRuntimeError( - 'This GraphRun has not been started. Did you forget to `await` the run?' - ) - return self._result.data + self._entered = False - async def next( - self: GraphRun[StateT, DepsT, T], node: BaseNode[StateT, DepsT, T] - ) -> BaseNode[StateT, DepsT, T] | End[T]: - """Note: this method behaves very similarly to an async generator's `asend` method.""" - if not self._started: - raise exceptions.GraphRuntimeError( - 'You must enter the GraphRun as a contextmanager (using `with ...`)' - ' before you can iterate over it or call `next` on it.' - ) - - history = self.history - state = self.state - deps = self.deps - - next_node = await self.graph.next(node, history, state=state, deps=deps, infer_name=False) - - if isinstance(next_node, End): - self._result = next_node - else: - self._next_node = next_node - return next_node - - def __await__(self) -> Generator[Any, Any, typing_extensions.Self]: + def __await__(self) -> Generator[Any, Any, GraphRunResult[StateT, DepsT, RunEndT]]: """Run the graph until it ends, and return the final result.""" - async def _run() -> typing_extensions.Self: - with self: - async for _next_node in self: + async def _run() -> GraphRunResult[StateT, DepsT, RunEndT]: + with self as graph_run: + async for _node in graph_run: pass - return self + return graph_run.final_result return _run().__await__() - def __enter__(self) -> typing_extensions.Self: - """Open a span for the graph run. + def __enter__(self) -> GraphRun[StateT, DepsT, RunEndT]: + """Obtain an iterable graph run. Note that we _require_ that the graph run is used as a context manager when iterating over nodes so that we can ensure that the span covers the time range during which the iteration happens. """ - if self._started: - raise exceptions.GraphRuntimeError('A GraphRun can only be started once.') + if self._entered: + raise exceptions.GraphRuntimeError('A GraphRunner should only be entered once.') if self._auto_instrument and self._span is None: - self._span = logfire_api.span('run graph {graph.name}', graph=self.graph) + self._span = logfire_api.span('run graph {graph.name}', graph=self._graph) if self._span is not None: self._span.__enter__() - self._started = True - return self + self._entered = True + + return GraphRun( + self._graph, + self._start_node, + history=self._history, + state=self._state, + deps=self._deps, + auto_instrument=self._auto_instrument, + span=self._span, + ) def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: if self._span is not None: self._span.__exit__(exc_type, exc_val, exc_tb) self._span = None # make it more obvious if you try to use it after exiting + +class GraphRun(Generic[StateT, DepsT, RunEndT]): + """A stateful, iterable run of a graph. + + After being entered, can be used like an async generator to listen to / modify nodes as the run is executed. + + TODO: this requires some heavy weight API documentation. + """ + + def __init__( + self, + graph: Graph[StateT, DepsT, RunEndT], + start_node: BaseNode[StateT, DepsT, RunEndT], + *, + history: list[HistoryStep[StateT, RunEndT]], + state: StateT, + deps: DepsT, + auto_instrument: bool, + span: LogfireSpan | None = None, + ): + self.graph = graph + self.history = history + self.state = state + self.deps = deps + self._auto_instrument = auto_instrument + self._span = span + + self._next_node: BaseNode[StateT, DepsT, RunEndT] | End[RunEndT] = start_node + + @property + def final_result(self) -> GraphRunResult[StateT, DepsT, RunEndT]: + if not isinstance(self._next_node, End): + raise exceptions.GraphRuntimeError('This GraphRun has not finished running.') + return GraphRunResult( + self._next_node.data, graph=self.graph, history=self.history, state=self.state, deps=self.deps + ) + + async def next( + self: GraphRun[StateT, DepsT, T], node: BaseNode[StateT, DepsT, T] + ) -> BaseNode[StateT, DepsT, T] | End[T]: + """Note: this method behaves very similarly to an async generator's `asend` method.""" + # TODO: replace the End[T] return with a RunResult[T] type which includes extra data. + + history = self.history + state = self.state + deps = self.deps + + self._next_node = await self.graph.next(node, history, state=state, deps=deps, infer_name=False) + + return self._next_node + def __aiter__(self) -> AsyncIterator[BaseNode[StateT, DepsT, RunEndT] | End[RunEndT]]: return self async def __anext__(self) -> BaseNode[StateT, DepsT, RunEndT] | End[RunEndT]: """Use the last returned node as the input to `Graph.next`.""" - if self._result: + if isinstance(self._next_node, End): raise StopAsyncIteration return await self.next(self._next_node) + + +class GraphRunResult(Generic[StateT, DepsT, RunEndT]): + """The final result of running a graph.""" + + def __init__( + self, + result: RunEndT, + *, + graph: Graph[StateT, DepsT, RunEndT], + history: list[HistoryStep[StateT, RunEndT]], + state: StateT, + deps: DepsT, + ): + self.result = result + self.graph = graph + self.history = history + self.state = state + self.deps = deps diff --git a/tests/test_agent.py b/tests/test_agent.py index f5f87834c..e2ca4d018 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -279,7 +279,7 @@ def return_tuple(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: ), ] ) - assert result.graph_run.result.tool_name == 'final_result' + assert result.result.tool_name == 'final_result' assert result.all_messages(result_tool_return_content='foobar')[-1] == snapshot( ModelRequest( parts=[ToolReturnPart(tool_name='final_result', content='foobar', timestamp=IsNow(tz=timezone.utc))] @@ -312,7 +312,7 @@ def test_result_tool_return_content_no_tool(): result = agent.run_sync('Hello') assert result.data == 0 - result.graph_run.result.tool_name = 'wrong' + result.result.tool_name = 'wrong' with pytest.raises(LookupError, match=re.escape("No tool call found with tool name 'wrong'.")): result.all_messages(result_tool_return_content='foobar') diff --git a/tests/typed_agent.py b/tests/typed_agent.py index dbbe04411..6e6607325 100644 --- a/tests/typed_agent.py +++ b/tests/typed_agent.py @@ -8,7 +8,7 @@ from typing_extensions import assert_type from pydantic_ai import Agent, ModelRetry, RunContext, Tool -from pydantic_ai.agent import AgentRun +from pydantic_ai.agent import AgentRunResult from pydantic_ai.tools import ToolDefinition @@ -139,7 +139,7 @@ async def result_validator_wrong(ctx: RunContext[int], result: str) -> str: def run_sync() -> None: result = typed_agent.run_sync('testing', deps=MyDeps(foo=1, bar=2)) - assert_type(result, AgentRun[MyDeps, str]) + assert_type(result, AgentRunResult[MyDeps, str]) assert_type(result.data, str) @@ -176,7 +176,7 @@ class Bar: def run_sync3() -> None: result = union_agent.run_sync('testing') - assert_type(result, AgentRun[None, Union[Foo, Bar]]) + assert_type(result, AgentRunResult[None, Union[Foo, Bar]]) assert_type(result.data, Union[Foo, Bar]) From ef8895a4275df0332bb1044097cd7b0f3ef2de21 Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Mon, 17 Feb 2025 17:36:21 -0700 Subject: [PATCH 11/28] Address some feedback --- pydantic_ai_slim/pydantic_ai/_agent_graph.py | 3 +- pydantic_ai_slim/pydantic_ai/agent.py | 105 ++++++++++++++++++- pydantic_graph/pydantic_graph/graph.py | 20 ++-- 3 files changed, 109 insertions(+), 19 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index a7ef0a1b0..63ed443ea 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -325,8 +325,7 @@ async def run( async with self.stream(ctx): pass - # the stream should set `self._next_node` before it ends: - assert (next_node := self._next_node) is not None + assert (next_node := self._next_node) is not None, 'the stream should set `self._next_node` before it ends' return next_node @asynccontextmanager diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index dd0dabbd0..f2616d4da 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -7,12 +7,12 @@ from contextlib import AbstractAsyncContextManager, asynccontextmanager, contextmanager from copy import deepcopy from types import FrameType -from typing import Any, Callable, Generic, cast, final, overload +from typing import Any, Callable, Generic, NoReturn, cast, final, overload import logfire_api from typing_extensions import TypeVar, deprecated -from pydantic_graph import BaseNode, End, Graph, GraphRun, GraphRunContext, GraphRunner +from pydantic_graph import BaseNode, End, Graph, GraphRun, GraphRunContext, GraphRunner, GraphRuntimeError from pydantic_graph.graph import GraphRunResult from . import ( @@ -1136,7 +1136,7 @@ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: self.graph_runner.__exit__(exc_type, exc_val, exc_tb) -@dataclasses.dataclass +@dataclasses.dataclass(repr=False) class AgentRun(Generic[AgentDepsT, ResultDataT]): """A stateful, iterable run of an agent.""" @@ -1206,6 +1206,15 @@ def get_graph_ctx( self._graph_run.state, self._graph_run.deps ) + def __repr__(self): + try: + result_repr = repr(self._graph_run.final_result.result) + except GraphRuntimeError: + result_repr = '' + usage = self.usage() + kws = [f'result={result_repr}', f'usage={usage}'] + return '<{} {}>'.format(type(self).__name__, ' '.join(kws)) + @dataclasses.dataclass class AgentRunResult(Generic[AgentDepsT, ResultDataT]): @@ -1289,3 +1298,93 @@ def new_messages_json(self, *, result_tool_return_content: str | None = None) -> def usage(self) -> _usage.Usage: return self.graph_run_result.state.usage + + +# PyCharm doesn't respect `if not TYPE_CHECKING:`, so it's harder to add behaviors that we +# don't want picked up by type-checking. In order to ensure that PyCharm doesn't think that +# it's allowed to iterate over an `AgentRunResult`, we add this `__aiter__` implementation +# via setattr on the class, which ensures it is not caught by any type-checkers. + + +def _agent_run_result_aiter(self: AgentRunResult[Any, Any]) -> NoReturn: + raise TypeError( + 'Did you try `async for result in await agent.run(...)`?\n' + 'If so, you need to drop the `await` keyword and use `with` to access the agent run.\n' + 'You can fix this error by changing `async for result in await agent.run(...):` to \n' + '\n' + 'with agent.run(...) as agent_run:\n' + ' async for node in agent_run:\n' + ' ...' + ) from TypeError(f"'async for' requires an object with __aiter__ method, got {type(self).__name__}") + + +def _agent_run_result_iter(self: AgentRunResult[Any, Any]) -> NoReturn: + raise TypeError( + 'Did you try `for result in await agent.run(...)`?\n' + 'If so, you need to drop the `await` keyword, use `with` to access the agent run, and use `async for` to iterate.\n' + 'You can fix this by changing `for result in await agent.run(...):` to \n' + '\n' + 'with agent.run(...) as agent_run:\n' + ' async for node in agent_run:\n' + ' ...' + ) from TypeError(f"'{type(self).__name__}' object is not iterable") + + +def _agent_run_result_aenter(self: AgentRunResult[Any, Any]) -> NoReturn: + raise TypeError( + 'Did you try `async with await agent.run(...):`?\n' + 'If so, you need to drop the `await` keyword and drop the `async` in `async with`.\n' + 'You can fix this error by changing `async with await agent.run(...):` to `with agent.run(...):`.' + ) from AttributeError('__aenter__') + + +def _agent_run_result_enter(self: AgentRunResult[Any, Any]) -> NoReturn: + raise TypeError( + 'Did you try `with await agent.run(...):`?\n' + 'If so, you need to drop the `await` keyword.\n' + 'You can fix this error by changing `with await agent.run(...):` to `with agent.run(...):`.' + ) from AttributeError('__enter__') + + +def _agent_runner_aiter(self: AgentRunner[Any, Any]) -> NoReturn: + raise TypeError( + 'Did you try `async for node in agent.run(...):`?\n' + 'If so, you need to use `with` to access the agent run.\n' + 'You can fix this error by changing `async for node in agent.run(...):` to \n' + '\n' + 'with agent.run(...) as agent_run:\n' + ' async for result in agent_run:\n' + ' ...' + ) from TypeError(f"'async for' requires an object with __aiter__ method, got {type(self).__name__}") + + +def _agent_runner_iter(self: AgentRunner[Any, Any]) -> NoReturn: + raise TypeError( + 'Did you try `for node in agent.run(...):`?\n' + 'If so, you need to use `with` to access the agent run, and `async for` to iterate over it.\n' + 'You can fix this error by changing `for node in agent.run(...):` to \n' + '\n' + 'with agent.run(...) as agent_run:\n' + ' async for result in agent_run:\n' + ' ...' + ) from TypeError(f"'{type(self).__name__}' object is not iterable") + + +def _agent_runner_aenter(self: AgentRunner[Any, Any]) -> NoReturn: + raise TypeError( + 'Did you try `async with agent.run(...):`?\n' + 'If so, you need to drop the `async` in `async with`.\n' + 'You can fix this error by changing `async with agent.run(...):` to `with agent.run(...):`.' + ) from AttributeError('__aenter__') + + +setattr(AgentRunResult, '__aiter__', _agent_run_result_aiter) +setattr(AgentRunResult, '__iter__', _agent_run_result_iter) +setattr(AgentRunResult, '__aenter__', _agent_run_result_aenter) +setattr(AgentRunResult, '__aexit__', lambda *args, **kwargs: None) # pyright: ignore[reportUnknownLambdaType] +setattr(AgentRunResult, '__enter__', _agent_run_result_enter) +setattr(AgentRunResult, '__exit__', lambda *args, **kwargs: None) # pyright: ignore[reportUnknownLambdaType] +setattr(AgentRunner, '__aiter__', _agent_runner_aiter) +setattr(AgentRunner, '__iter__', _agent_runner_iter) +setattr(AgentRunner, '__aenter__', _agent_runner_aenter) +setattr(AgentRunner, '__aexit__', lambda *args, **kwargs: None) # pyright: ignore[reportUnknownLambdaType] diff --git a/pydantic_graph/pydantic_graph/graph.py b/pydantic_graph/pydantic_graph/graph.py index 3fce05723..269f875ce 100644 --- a/pydantic_graph/pydantic_graph/graph.py +++ b/pydantic_graph/pydantic_graph/graph.py @@ -639,20 +639,12 @@ async def __anext__(self) -> BaseNode[StateT, DepsT, RunEndT] | End[RunEndT]: return await self.next(self._next_node) +@dataclass class GraphRunResult(Generic[StateT, DepsT, RunEndT]): """The final result of running a graph.""" - def __init__( - self, - result: RunEndT, - *, - graph: Graph[StateT, DepsT, RunEndT], - history: list[HistoryStep[StateT, RunEndT]], - state: StateT, - deps: DepsT, - ): - self.result = result - self.graph = graph - self.history = history - self.state = state - self.deps = deps + result: RunEndT + graph: Graph[StateT, DepsT, RunEndT] + history: list[HistoryStep[StateT, RunEndT]] + state: StateT + deps: DepsT From 13e3b86d7b0bc6267cb1a3f4984ed79d2cf4a717 Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Mon, 17 Feb 2025 17:41:23 -0700 Subject: [PATCH 12/28] result -> node --- pydantic_ai_slim/pydantic_ai/agent.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index f2616d4da..e8e29a2b0 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -1308,9 +1308,9 @@ def usage(self) -> _usage.Usage: def _agent_run_result_aiter(self: AgentRunResult[Any, Any]) -> NoReturn: raise TypeError( - 'Did you try `async for result in await agent.run(...)`?\n' + 'Did you try `async for node in await agent.run(...)`?\n' 'If so, you need to drop the `await` keyword and use `with` to access the agent run.\n' - 'You can fix this error by changing `async for result in await agent.run(...):` to \n' + 'You can fix this error by changing `async for node in await agent.run(...):` to \n' '\n' 'with agent.run(...) as agent_run:\n' ' async for node in agent_run:\n' @@ -1320,9 +1320,9 @@ def _agent_run_result_aiter(self: AgentRunResult[Any, Any]) -> NoReturn: def _agent_run_result_iter(self: AgentRunResult[Any, Any]) -> NoReturn: raise TypeError( - 'Did you try `for result in await agent.run(...)`?\n' + 'Did you try `for node in await agent.run(...)`?\n' 'If so, you need to drop the `await` keyword, use `with` to access the agent run, and use `async for` to iterate.\n' - 'You can fix this by changing `for result in await agent.run(...):` to \n' + 'You can fix this by changing `for node in await agent.run(...):` to \n' '\n' 'with agent.run(...) as agent_run:\n' ' async for node in agent_run:\n' @@ -1353,7 +1353,7 @@ def _agent_runner_aiter(self: AgentRunner[Any, Any]) -> NoReturn: 'You can fix this error by changing `async for node in agent.run(...):` to \n' '\n' 'with agent.run(...) as agent_run:\n' - ' async for result in agent_run:\n' + ' async for node in agent_run:\n' ' ...' ) from TypeError(f"'async for' requires an object with __aiter__ method, got {type(self).__name__}") @@ -1365,7 +1365,7 @@ def _agent_runner_iter(self: AgentRunner[Any, Any]) -> NoReturn: 'You can fix this error by changing `for node in agent.run(...):` to \n' '\n' 'with agent.run(...) as agent_run:\n' - ' async for result in agent_run:\n' + ' async for node in agent_run:\n' ' ...' ) from TypeError(f"'{type(self).__name__}' object is not iterable") From a08aafa5faafa249731483ec5dce5bd13b6bda06 Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Mon, 17 Feb 2025 17:42:31 -0700 Subject: [PATCH 13/28] Rename MarkFinalResult to FinalResult --- pydantic_ai_slim/pydantic_ai/_agent_graph.py | 24 +++++++------- pydantic_ai_slim/pydantic_ai/agent.py | 34 ++++++++++---------- pydantic_ai_slim/pydantic_ai/result.py | 2 +- 3 files changed, 30 insertions(+), 30 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index 63ed443ea..b42aecbbc 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -24,7 +24,7 @@ usage as _usage, ) from .models import ModelRequestParameters, StreamedResponse -from .result import MarkFinalResult, ResultDataT +from .result import FinalResult, ResultDataT from .settings import ModelSettings, merge_model_settings from .tools import ( RunContext, @@ -314,14 +314,14 @@ class HandleResponseNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], N model_response: _messages.ModelResponse _stream: AsyncIterator[_messages.HandleResponseEvent] | None = field(default=None, repr=False) - _next_node: ModelRequestNode[DepsT, NodeRunEndT] | End[MarkFinalResult[NodeRunEndT]] | None = field( + _next_node: ModelRequestNode[DepsT, NodeRunEndT] | End[FinalResult[NodeRunEndT]] | None = field( default=None, repr=False ) _tool_responses: list[_messages.ModelRequestPart] = field(default_factory=list, repr=False) async def run( self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]] - ) -> Union[ModelRequestNode[DepsT, NodeRunEndT], End[MarkFinalResult[NodeRunEndT]]]: # noqa UP007 + ) -> Union[ModelRequestNode[DepsT, NodeRunEndT], End[FinalResult[NodeRunEndT]]]: # noqa UP007 async with self.stream(ctx): pass @@ -398,7 +398,7 @@ async def _handle_tool_calls( result_schema = ctx.deps.result_schema # first look for the result tool call - final_result: MarkFinalResult[NodeRunEndT] | None = None + final_result: FinalResult[NodeRunEndT] | None = None parts: list[_messages.ModelRequestPart] = [] if result_schema is not None: if match := result_schema.find_tool(tool_calls): @@ -412,7 +412,7 @@ async def _handle_tool_calls( ctx.state.increment_retries(ctx.deps.max_result_retries) parts.append(e.tool_retry) else: - final_result = MarkFinalResult(result_data, call.tool_name) + final_result = FinalResult(result_data, call.tool_name) # Then build the other request parts based on end strategy tool_responses: list[_messages.ModelRequestPart] = self._tool_responses @@ -431,9 +431,9 @@ async def _handle_tool_calls( def _handle_final_result( self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], - final_result: MarkFinalResult[NodeRunEndT], + final_result: FinalResult[NodeRunEndT], tool_responses: list[_messages.ModelRequestPart], - ) -> End[MarkFinalResult[NodeRunEndT]]: + ) -> End[FinalResult[NodeRunEndT]]: run_span = ctx.deps.run_span usage = ctx.state.usage messages = ctx.state.message_history @@ -452,7 +452,7 @@ async def _handle_text_response( self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], texts: list[str], - ) -> ModelRequestNode[DepsT, NodeRunEndT] | End[MarkFinalResult[NodeRunEndT]]: + ) -> ModelRequestNode[DepsT, NodeRunEndT] | End[FinalResult[NodeRunEndT]]: result_schema = ctx.deps.result_schema text = '\n\n'.join(texts) @@ -465,7 +465,7 @@ async def _handle_text_response( return ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[e.tool_retry])) else: # The following cast is safe because we know `str` is an allowed result type - return self._handle_final_result(ctx, MarkFinalResult(result_data, tool_name=None), []) + return self._handle_final_result(ctx, FinalResult(result_data, tool_name=None), []) else: ctx.state.increment_retries(ctx.deps.max_result_retries) return ModelRequestNode[DepsT, NodeRunEndT]( @@ -656,18 +656,18 @@ def get_captured_run_messages() -> _RunMessages: def build_agent_graph( name: str | None, deps_type: type[DepsT], result_type: type[ResultT] -) -> Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], MarkFinalResult[ResultT]]: +) -> Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], FinalResult[ResultT]]: # We'll define the known node classes: nodes = ( UserPromptNode[DepsT], ModelRequestNode[DepsT], HandleResponseNode[DepsT], ) - graph = Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], MarkFinalResult[ResultT]]( + graph = Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], FinalResult[ResultT]]( nodes=nodes, name=name or 'Agent', state_type=GraphAgentState, - run_end_type=MarkFinalResult[result_type], + run_end_type=FinalResult[result_type], auto_instrument=False, ) return graph diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index e8e29a2b0..2903648ae 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -27,7 +27,7 @@ usage as _usage, ) from ._agent_graph import EndStrategy, capture_run_messages # imported for re-export -from .result import MarkFinalResult, ResultDataT, StreamedRunResult +from .result import FinalResult, ResultDataT, StreamedRunResult from .settings import ModelSettings, merge_model_settings from .tools import ( AgentDepsT, @@ -310,7 +310,7 @@ async def main(): kind='response', ) ), - End(data=MarkFinalResult(data='Paris', tool_name=None)), + End(data=FinalResult(data='Paris', tool_name=None)), ] ''' print(agent_run.final_result.data) @@ -594,18 +594,18 @@ async def main(): async def stream_to_final( s: models.StreamedResponse, - ) -> MarkFinalResult[models.StreamedResponse] | None: + ) -> FinalResult[models.StreamedResponse] | None: result_schema = graph_ctx.deps.result_schema async for maybe_part_event in streamed_response: if isinstance(maybe_part_event, _messages.PartStartEvent): new_part = maybe_part_event.part if isinstance(new_part, _messages.TextPart): if _agent_graph.allow_text_result(result_schema): - return MarkFinalResult(s, None) + return FinalResult(s, None) elif isinstance(new_part, _messages.ToolCallPart): if result_schema is not None and (match := result_schema.find_tool([new_part])): call, _ = match - return MarkFinalResult(s, call.tool_name) + return FinalResult(s, call.tool_name) return None final_result_details = await stream_to_final(streamed_response) @@ -1097,7 +1097,7 @@ def last_run_messages(self) -> list[_messages.ModelMessage]: def _build_graph( self, result_type: type[RunResultDataT] | None - ) -> Graph[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], MarkFinalResult[Any]]: + ) -> Graph[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], FinalResult[Any]]: return _agent_graph.build_agent_graph(self.name, self._deps_type, result_type or self.result_type) def _prepare_result_schema( @@ -1116,7 +1116,7 @@ def _prepare_result_schema( @dataclasses.dataclass class AgentRunner(Generic[AgentDepsT, ResultDataT]): graph_runner: GraphRunner[ - _agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], MarkFinalResult[ResultDataT] + _agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], FinalResult[ResultDataT] ] def __await__(self) -> Generator[Any, Any, AgentRunResult[AgentDepsT, ResultDataT]]: @@ -1141,7 +1141,7 @@ class AgentRun(Generic[AgentDepsT, ResultDataT]): """A stateful, iterable run of an agent.""" _graph_run: GraphRun[ - _agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], MarkFinalResult[ResultDataT] + _agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], FinalResult[ResultDataT] ] @property @@ -1155,9 +1155,9 @@ def __aiter__( BaseNode[ _agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], - MarkFinalResult[ResultDataT], + FinalResult[ResultDataT], ] - | End[MarkFinalResult[ResultDataT]] + | End[FinalResult[ResultDataT]] ]: return self @@ -1167,9 +1167,9 @@ async def __anext__( BaseNode[ _agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], - MarkFinalResult[ResultDataT], + FinalResult[ResultDataT], ] - | End[MarkFinalResult[ResultDataT]] + | End[FinalResult[ResultDataT]] ): """Use the last returned node as the input to `Graph.next`.""" return await self._graph_run.__anext__() @@ -1179,15 +1179,15 @@ async def next( node: BaseNode[ _agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], - MarkFinalResult[ResultDataT], + FinalResult[ResultDataT], ], ) -> ( BaseNode[ _agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], - MarkFinalResult[ResultDataT], + FinalResult[ResultDataT], ] - | End[MarkFinalResult[ResultDataT]] + | End[FinalResult[ResultDataT]] ): # TODO: It would be nice to expose a synchronous interface for this, to be able to # synchronously iterate over the agent graph. I don't think this would be hard to do, @@ -1221,11 +1221,11 @@ class AgentRunResult(Generic[AgentDepsT, ResultDataT]): """The final result of an agent run.""" graph_run_result: GraphRunResult[ - _agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], MarkFinalResult[ResultDataT] + _agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], FinalResult[ResultDataT] ] @property - def result(self) -> MarkFinalResult[ResultDataT]: + def result(self) -> FinalResult[ResultDataT]: return self.graph_run_result.result @property diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py index 0dab89840..c569d1bb5 100644 --- a/pydantic_ai_slim/pydantic_ai/result.py +++ b/pydantic_ai_slim/pydantic_ai/result.py @@ -292,7 +292,7 @@ async def _marked_completed(self, message: _messages.ModelResponse) -> None: @dataclass -class MarkFinalResult(Generic[ResultDataT]): +class FinalResult(Generic[ResultDataT]): """Marker class to indicate that the result is the final result. This allows us to use `isinstance`, which wouldn't be possible if we were returning `ResultDataT` directly. From ff6f699221d2114ae026da8a2bfd7a006a75c054 Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Mon, 17 Feb 2025 19:20:32 -0700 Subject: [PATCH 14/28] Remove GraphRunner/AgentRunner and add .iter() API --- pydantic_ai_slim/pydantic_ai/agent.py | 232 ++++++++-------------- pydantic_graph/pydantic_graph/__init__.py | 3 +- pydantic_graph/pydantic_graph/graph.py | 141 +++++-------- 3 files changed, 136 insertions(+), 240 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 2903648ae..2e2695c7b 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -3,16 +3,16 @@ import asyncio import dataclasses import inspect -from collections.abc import AsyncIterator, Awaitable, Generator, Iterator, Sequence +from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence from contextlib import AbstractAsyncContextManager, asynccontextmanager, contextmanager from copy import deepcopy from types import FrameType -from typing import Any, Callable, Generic, NoReturn, cast, final, overload +from typing import Any, Callable, Generic, cast, final, overload import logfire_api from typing_extensions import TypeVar, deprecated -from pydantic_graph import BaseNode, End, Graph, GraphRun, GraphRunContext, GraphRunner, GraphRuntimeError +from pydantic_graph import BaseNode, End, Graph, GraphRun, GraphRunContext, GraphRuntimeError from pydantic_graph.graph import GraphRunResult from . import ( @@ -203,7 +203,7 @@ def __init__( self._register_tool(Tool(tool)) @overload - def run( + async def run( self, user_prompt: str, *, @@ -215,10 +215,10 @@ def run( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, - ) -> AgentRunner[AgentDepsT, ResultDataT]: ... + ) -> AgentRunResult[AgentDepsT, ResultDataT]: ... @overload - def run( + async def run( self, user_prompt: str, *, @@ -230,9 +230,9 @@ def run( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, - ) -> AgentRunner[AgentDepsT, RunResultDataT]: ... + ) -> AgentRunResult[AgentDepsT, RunResultDataT]: ... - def run( + async def run( self, user_prompt: str, *, @@ -244,27 +244,10 @@ def run( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, - ) -> AgentRunner[AgentDepsT, Any]: + ) -> AgentRunResult[AgentDepsT, Any]: """Run the agent with a user prompt. - This method builds an internal agent graph (using system prompts, tools and result schemas) and then - returns an `AgentRun` object. The `AgentRun` functions as a handle that can be used to iterate over the graph - and obtain the final result. The AgentRun also provides methods to access the full message history, - new messages, and usage statistics. - - The returned `AgentRun` object should always be immediately used in one of two ways: - * Via `await` (i.e., `await agent.run(...)`), which will execute the graph run and return the final result - * This is the API you should use if you just want the end result and are not interested in the execution details - or consuming streaming updates - * As a context manager (i.e., `with agent.run(...) as agent_run:`), which will return an async iterator over - the graph nodes, and which can also be used as an async generator to override the execution of the graph. - * This is the API you should use if you want to consume the graph execution in a streaming manner, - or if you want to consume the stream of events coming from individual requests to the LLM, or the stream - of events coming from the execution of tools. - - For more details, see the documentation of `AgentRun`. - - Example (with `await`): + Example: ```python from pydantic_ai import Agent @@ -276,7 +259,66 @@ async def main(): #> Paris ``` - Example (with a context manager): + Args: + result_type: Custom result type to use for this run, `result_type` may only be used if the agent has no + result validators since result validators would expect an argument that matches the agent's result type. + user_prompt: User input to start/continue the conversation. + message_history: History of the conversation so far. + model: Optional model to use for this run, required if `model` was not set when creating the agent. + deps: Optional dependencies to use for this run. + model_settings: Optional settings to use for this model's request. + usage_limits: Optional limits on model request count or token usage. + usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. + infer_name: Whether to try to infer the agent name from the call frame if it's not set. + + Returns: + The result of the run. + """ + if infer_name and self.name is None: + self._infer_name(inspect.currentframe()) + with self.iter( + user_prompt=user_prompt, + result_type=result_type, + message_history=message_history, + model=model, + deps=deps, + model_settings=model_settings, + usage_limits=usage_limits, + usage=usage, + ) as agent_run: + async for _ in agent_run: + pass + return agent_run.final_result + + @contextmanager + def iter( + self, + user_prompt: str, + *, + result_type: type[RunResultDataT] | None = None, + message_history: list[_messages.ModelMessage] | None = None, + model: models.Model | models.KnownModelName | None = None, + deps: AgentDepsT = None, + model_settings: ModelSettings | None = None, + usage_limits: _usage.UsageLimits | None = None, + usage: _usage.Usage | None = None, + infer_name: bool = True, + ) -> Iterator[AgentRun[AgentDepsT, Any]]: + """Get an AgentRun for the agent with a user prompt which can be iterated over. + + This method builds an internal agent graph (using system prompts, tools and result schemas) and then + returns an `AgentRun` object. The `AgentRun` functions as a handle that can be used to iterate over the graph + and obtain the final result. The AgentRun also provides methods to access the full message history, + new messages, and usage statistics. + + The returned `AgentRun` object can be async iterated over to get the nodes of the graph as they are executed.: + This is the API you should use if you want to consume the graph execution in a streaming manner, + or if you want to consume the stream of events coming from individual requests to the LLM, or the stream + of events coming from the execution of tools. + + For more details, see the documentation of `AgentRun`. + + Example: ```python from pydantic_ai import Agent @@ -284,7 +326,7 @@ async def main(): async def main(): nodes = [] - with agent.run('What is the capital of France?') as agent_run: + with agent.iter('What is the capital of France?') as agent_run: async for node in agent_run: nodes.append(node) print(nodes) @@ -394,15 +436,14 @@ async def main(): system_prompt_dynamic_functions=self._system_prompt_dynamic_functions, ) - return AgentRunner( - graph.run( - start_node, - state=state, - deps=graph_deps, - infer_name=False, - span=run_span, - ) - ) + with graph.iter( + start_node, + state=state, + deps=graph_deps, + infer_name=False, + span=run_span, + ) as graph_run: + yield AgentRun(graph_run) @overload def run_sync( @@ -572,7 +613,7 @@ async def main(): self._infer_name(frame.f_back) yielded = False - with self.run( + with self.iter( user_prompt, result_type=result_type, message_history=message_history, @@ -1113,29 +1154,6 @@ def _prepare_result_schema( return self._result_schema # pyright: ignore[reportReturnType] -@dataclasses.dataclass -class AgentRunner(Generic[AgentDepsT, ResultDataT]): - graph_runner: GraphRunner[ - _agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], FinalResult[ResultDataT] - ] - - def __await__(self) -> Generator[Any, Any, AgentRunResult[AgentDepsT, ResultDataT]]: - """Run the agent graph until it ends, and return self.""" - - async def _run(): - graph_run_result = await self.graph_runner - return AgentRunResult(graph_run_result) - - return _run().__await__() - - def __enter__(self) -> AgentRun[AgentDepsT, ResultDataT]: - graph_run = self.graph_runner.__enter__() - return AgentRun(graph_run) - - def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: - self.graph_runner.__exit__(exc_type, exc_val, exc_tb) - - @dataclasses.dataclass(repr=False) class AgentRun(Generic[AgentDepsT, ResultDataT]): """A stateful, iterable run of an agent.""" @@ -1298,93 +1316,3 @@ def new_messages_json(self, *, result_tool_return_content: str | None = None) -> def usage(self) -> _usage.Usage: return self.graph_run_result.state.usage - - -# PyCharm doesn't respect `if not TYPE_CHECKING:`, so it's harder to add behaviors that we -# don't want picked up by type-checking. In order to ensure that PyCharm doesn't think that -# it's allowed to iterate over an `AgentRunResult`, we add this `__aiter__` implementation -# via setattr on the class, which ensures it is not caught by any type-checkers. - - -def _agent_run_result_aiter(self: AgentRunResult[Any, Any]) -> NoReturn: - raise TypeError( - 'Did you try `async for node in await agent.run(...)`?\n' - 'If so, you need to drop the `await` keyword and use `with` to access the agent run.\n' - 'You can fix this error by changing `async for node in await agent.run(...):` to \n' - '\n' - 'with agent.run(...) as agent_run:\n' - ' async for node in agent_run:\n' - ' ...' - ) from TypeError(f"'async for' requires an object with __aiter__ method, got {type(self).__name__}") - - -def _agent_run_result_iter(self: AgentRunResult[Any, Any]) -> NoReturn: - raise TypeError( - 'Did you try `for node in await agent.run(...)`?\n' - 'If so, you need to drop the `await` keyword, use `with` to access the agent run, and use `async for` to iterate.\n' - 'You can fix this by changing `for node in await agent.run(...):` to \n' - '\n' - 'with agent.run(...) as agent_run:\n' - ' async for node in agent_run:\n' - ' ...' - ) from TypeError(f"'{type(self).__name__}' object is not iterable") - - -def _agent_run_result_aenter(self: AgentRunResult[Any, Any]) -> NoReturn: - raise TypeError( - 'Did you try `async with await agent.run(...):`?\n' - 'If so, you need to drop the `await` keyword and drop the `async` in `async with`.\n' - 'You can fix this error by changing `async with await agent.run(...):` to `with agent.run(...):`.' - ) from AttributeError('__aenter__') - - -def _agent_run_result_enter(self: AgentRunResult[Any, Any]) -> NoReturn: - raise TypeError( - 'Did you try `with await agent.run(...):`?\n' - 'If so, you need to drop the `await` keyword.\n' - 'You can fix this error by changing `with await agent.run(...):` to `with agent.run(...):`.' - ) from AttributeError('__enter__') - - -def _agent_runner_aiter(self: AgentRunner[Any, Any]) -> NoReturn: - raise TypeError( - 'Did you try `async for node in agent.run(...):`?\n' - 'If so, you need to use `with` to access the agent run.\n' - 'You can fix this error by changing `async for node in agent.run(...):` to \n' - '\n' - 'with agent.run(...) as agent_run:\n' - ' async for node in agent_run:\n' - ' ...' - ) from TypeError(f"'async for' requires an object with __aiter__ method, got {type(self).__name__}") - - -def _agent_runner_iter(self: AgentRunner[Any, Any]) -> NoReturn: - raise TypeError( - 'Did you try `for node in agent.run(...):`?\n' - 'If so, you need to use `with` to access the agent run, and `async for` to iterate over it.\n' - 'You can fix this error by changing `for node in agent.run(...):` to \n' - '\n' - 'with agent.run(...) as agent_run:\n' - ' async for node in agent_run:\n' - ' ...' - ) from TypeError(f"'{type(self).__name__}' object is not iterable") - - -def _agent_runner_aenter(self: AgentRunner[Any, Any]) -> NoReturn: - raise TypeError( - 'Did you try `async with agent.run(...):`?\n' - 'If so, you need to drop the `async` in `async with`.\n' - 'You can fix this error by changing `async with agent.run(...):` to `with agent.run(...):`.' - ) from AttributeError('__aenter__') - - -setattr(AgentRunResult, '__aiter__', _agent_run_result_aiter) -setattr(AgentRunResult, '__iter__', _agent_run_result_iter) -setattr(AgentRunResult, '__aenter__', _agent_run_result_aenter) -setattr(AgentRunResult, '__aexit__', lambda *args, **kwargs: None) # pyright: ignore[reportUnknownLambdaType] -setattr(AgentRunResult, '__enter__', _agent_run_result_enter) -setattr(AgentRunResult, '__exit__', lambda *args, **kwargs: None) # pyright: ignore[reportUnknownLambdaType] -setattr(AgentRunner, '__aiter__', _agent_runner_aiter) -setattr(AgentRunner, '__iter__', _agent_runner_iter) -setattr(AgentRunner, '__aenter__', _agent_runner_aenter) -setattr(AgentRunner, '__aexit__', lambda *args, **kwargs: None) # pyright: ignore[reportUnknownLambdaType] diff --git a/pydantic_graph/pydantic_graph/__init__.py b/pydantic_graph/pydantic_graph/__init__.py index a0fe8a0fb..079325f59 100644 --- a/pydantic_graph/pydantic_graph/__init__.py +++ b/pydantic_graph/pydantic_graph/__init__.py @@ -1,12 +1,11 @@ from .exceptions import GraphRuntimeError, GraphSetupError -from .graph import Graph, GraphRun, GraphRunner, GraphRunResult +from .graph import Graph, GraphRun, GraphRunResult from .nodes import BaseNode, Edge, End, GraphRunContext from .state import EndStep, HistoryStep, NodeStep __all__ = ( 'Graph', 'GraphRun', - 'GraphRunner', 'GraphRunResult', 'BaseNode', 'End', diff --git a/pydantic_graph/pydantic_graph/graph.py b/pydantic_graph/pydantic_graph/graph.py index 269f875ce..7ea686180 100644 --- a/pydantic_graph/pydantic_graph/graph.py +++ b/pydantic_graph/pydantic_graph/graph.py @@ -3,8 +3,8 @@ import asyncio import inspect import types -from collections.abc import AsyncIterator, Generator, Sequence -from contextlib import ExitStack +from collections.abc import AsyncIterator, Iterator, Sequence +from contextlib import ExitStack, contextmanager from dataclasses import dataclass, field from functools import cached_property from time import perf_counter @@ -30,7 +30,7 @@ logfire._internal.stack_info.NON_USER_CODE_PREFIXES += (str(Path(__file__).parent.absolute()),) -__all__ = ('Graph', 'GraphRun', 'GraphRunner', 'GraphRunResult') +__all__ = ('Graph', 'GraphRun', 'GraphRunResult') _logfire = logfire_api.Logfire(otel_scope='pydantic-graph') @@ -126,7 +126,7 @@ def __init__( self._validate_edges() - def run( + async def run( self: Graph[StateT, DepsT, T], start_node: BaseNode[StateT, DepsT, T], *, @@ -134,12 +134,9 @@ def run( deps: DepsT = None, infer_name: bool = True, span: LogfireSpan | None = None, - ) -> GraphRunner[StateT, DepsT, T]: + ) -> GraphRunResult[StateT, DepsT, T]: """Run the graph from a starting node until it ends. - The returned GraphRun can be awaited (or used as an async iterator) to drive the graph to completion. - TODO: Need to add a more detailed message here explaining that this can behave like a coroutine or a context manager etc. - Args: start_node: the first node to run, since the graph definition doesn't define the entry point in the graph, you need to provide the starting node. @@ -175,15 +172,56 @@ async def main(): if infer_name and self.name is None: self._infer_name(inspect.currentframe()) - return GraphRunner[StateT, DepsT, T]( - self, - start_node, - history=[], - state=state, - deps=deps, - auto_instrument=self._auto_instrument, - span=span, - ) + with self.iter(start_node, state=state, deps=deps, infer_name=infer_name, span=span) as graph_run: + async for _node in graph_run: + pass + + return graph_run.final_result + + @contextmanager + def iter( + self: Graph[StateT, DepsT, T], + start_node: BaseNode[StateT, DepsT, T], + *, + state: StateT = None, + deps: DepsT = None, + infer_name: bool = True, + span: LogfireSpan | None = None, + ) -> Iterator[GraphRun[StateT, DepsT, T]]: + """A contextmanager that yields a GraphRun that can be async iterated over to drive the graph to completion. + + Args: + start_node: the first node to run, since the graph definition doesn't define the entry point in the graph, + you need to provide the starting node. + state: The initial state of the graph. + deps: The dependencies of the graph. + infer_name: Whether to infer the graph name from the calling frame. + span: The span to use for the graph run. If not provided, a new span will be created. + + Yields: + A GraphRun that can be async iterated over to drive the graph to completion. + + Here's an example of running the graph from [above][pydantic_graph.graph.Graph]: + # TODO: Need to add an example here akin to that from `Graph.run` above + """ + if infer_name and self.name is None: + self._infer_name(inspect.currentframe()) + + if self._auto_instrument and span is None: + span = logfire_api.span('run graph {graph.name}', graph=self) + + with ExitStack() as stack: + if span is not None: + stack.enter_context(span) + yield GraphRun[StateT, DepsT, T]( + self, + start_node, + history=[], + state=state, + deps=deps, + auto_instrument=self._auto_instrument, + span=span, + ) def run_sync( self: Graph[StateT, DepsT, T], @@ -510,75 +548,6 @@ def _infer_name(self, function_frame: types.FrameType | None) -> None: return -class GraphRunner(Generic[StateT, DepsT, RunEndT]): - """This object _MUST_ either be awaited or entered as a contextmanager to get a GraphRun.""" - - def __init__( - self, - graph: Graph[StateT, DepsT, RunEndT], - start_node: BaseNode[StateT, DepsT, RunEndT], - *, - history: list[HistoryStep[StateT, RunEndT]], - state: StateT, - deps: DepsT, - auto_instrument: bool, - span: LogfireSpan | None = None, - ): - self._graph = graph - self._start_node = start_node - self._history = history - self._state = state - self._deps = deps - self._auto_instrument = auto_instrument - self._span = span - - self._entered = False - - def __await__(self) -> Generator[Any, Any, GraphRunResult[StateT, DepsT, RunEndT]]: - """Run the graph until it ends, and return the final result.""" - - async def _run() -> GraphRunResult[StateT, DepsT, RunEndT]: - with self as graph_run: - async for _node in graph_run: - pass - - return graph_run.final_result - - return _run().__await__() - - def __enter__(self) -> GraphRun[StateT, DepsT, RunEndT]: - """Obtain an iterable graph run. - - Note that we _require_ that the graph run is used as a context manager when iterating over nodes - so that we can ensure that the span covers the time range during which the iteration happens. - """ - if self._entered: - raise exceptions.GraphRuntimeError('A GraphRunner should only be entered once.') - - if self._auto_instrument and self._span is None: - self._span = logfire_api.span('run graph {graph.name}', graph=self._graph) - - if self._span is not None: - self._span.__enter__() - - self._entered = True - - return GraphRun( - self._graph, - self._start_node, - history=self._history, - state=self._state, - deps=self._deps, - auto_instrument=self._auto_instrument, - span=self._span, - ) - - def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: - if self._span is not None: - self._span.__exit__(exc_type, exc_val, exc_tb) - self._span = None # make it more obvious if you try to use it after exiting - - class GraphRun(Generic[StateT, DepsT, RunEndT]): """A stateful, iterable run of a graph. From 41bb069b44514c7b36332fdd282403e56dcc4e4d Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Mon, 17 Feb 2025 19:27:35 -0700 Subject: [PATCH 15/28] Make result private --- pydantic_ai_slim/pydantic_ai/agent.py | 12 ++++++------ tests/test_agent.py | 4 ++-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 2e2695c7b..77c032219 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -1243,31 +1243,31 @@ class AgentRunResult(Generic[AgentDepsT, ResultDataT]): ] @property - def result(self) -> FinalResult[ResultDataT]: + def _result(self) -> FinalResult[ResultDataT]: return self.graph_run_result.result @property def data(self) -> ResultDataT: - return self.result.data + return self._result.data @property def _result_tool_name(self) -> str | None: - return self.result.tool_name + return self._result.tool_name def _set_result_tool_return(self, return_content: str) -> list[_messages.ModelMessage]: """Set return content for the result tool. Useful if you want to continue the conversation and want to set the response to the result tool call. """ - if not self.result.tool_name: + if not self._result.tool_name: raise ValueError('Cannot set result tool return content when the return type is `str`.') messages = deepcopy(self.graph_run_result.state.message_history) last_message = messages[-1] for part in last_message.parts: - if isinstance(part, _messages.ToolReturnPart) and part.tool_name == self.result.tool_name: + if isinstance(part, _messages.ToolReturnPart) and part.tool_name == self._result.tool_name: part.content = return_content return messages - raise LookupError(f'No tool call found with tool name {self.result.tool_name!r}.') + raise LookupError(f'No tool call found with tool name {self._result.tool_name!r}.') def all_messages(self, *, result_tool_return_content: str | None = None) -> list[_messages.ModelMessage]: if result_tool_return_content is not None: diff --git a/tests/test_agent.py b/tests/test_agent.py index e2ca4d018..a478b5633 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -279,7 +279,7 @@ def return_tuple(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: ), ] ) - assert result.result.tool_name == 'final_result' + assert result._result.tool_name == 'final_result' # pyright: ignore[reportPrivateUsage] assert result.all_messages(result_tool_return_content='foobar')[-1] == snapshot( ModelRequest( parts=[ToolReturnPart(tool_name='final_result', content='foobar', timestamp=IsNow(tz=timezone.utc))] @@ -312,7 +312,7 @@ def test_result_tool_return_content_no_tool(): result = agent.run_sync('Hello') assert result.data == 0 - result.result.tool_name = 'wrong' + result._result.tool_name = 'wrong' # pyright: ignore[reportPrivateUsage] with pytest.raises(LookupError, match=re.escape("No tool call found with tool name 'wrong'.")): result.all_messages(result_tool_return_content='foobar') From b5650885061ac9b406d15ba7d9ff67f5a9bf9db4 Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Mon, 17 Feb 2025 21:14:03 -0700 Subject: [PATCH 16/28] Reduce diff to main and add some docstrings --- pydantic_ai_slim/pydantic_ai/_agent_graph.py | 44 +++--- pydantic_ai_slim/pydantic_ai/agent.py | 76 +++++----- .../pydantic_ai/models/__init__.py | 7 + pydantic_ai_slim/pydantic_ai/result.py | 12 +- pydantic_graph/pydantic_graph/graph.py | 67 +++++++-- pydantic_graph/pydantic_graph/nodes.py | 2 +- pyproject.toml | 2 +- tests/test_agent.py | 50 ++++--- tests/test_parts_manager.py | 133 ++++++++++++------ tests/test_streaming.py | 12 ++ 10 files changed, 264 insertions(+), 141 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index b42aecbbc..d1185246c 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -21,10 +21,10 @@ exceptions, messages as _messages, models, + result, usage as _usage, ) -from .models import ModelRequestParameters, StreamedResponse -from .result import FinalResult, ResultDataT +from .result import ResultDataT from .settings import ModelSettings, merge_model_settings from .tools import ( RunContext, @@ -228,8 +228,9 @@ async def run( @asynccontextmanager async def stream( - self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]] - ) -> AsyncIterator[StreamedResponse]: + self, + ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]], + ) -> AsyncIterator[models.StreamedResponse]: if self._did_stream: raise exceptions.AgentRunError('stream() can only be called once') @@ -272,7 +273,7 @@ async def _make_request( async def _prepare_request( self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]] - ) -> tuple[ModelSettings | None, ModelRequestParameters]: + ) -> tuple[ModelSettings | None, models.ModelRequestParameters]: ctx.state.message_history.append(self.request) # Check usage @@ -309,19 +310,20 @@ def _finish_handling( @dataclasses.dataclass class HandleResponseNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT]): - """Process the response from a model, decide whether to end the run or make a new request.""" + """Process a model response, and decide whether to end the run or make a new request.""" model_response: _messages.ModelResponse _stream: AsyncIterator[_messages.HandleResponseEvent] | None = field(default=None, repr=False) - _next_node: ModelRequestNode[DepsT, NodeRunEndT] | End[FinalResult[NodeRunEndT]] | None = field( + _next_node: ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]] | None = field( default=None, repr=False ) _tool_responses: list[_messages.ModelRequestPart] = field(default_factory=list, repr=False) async def run( self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]] - ) -> Union[ModelRequestNode[DepsT, NodeRunEndT], End[FinalResult[NodeRunEndT]]]: # noqa UP007 + ) -> Union[ModelRequestNode[DepsT, NodeRunEndT], End[result.FinalResult[NodeRunEndT]]]: # noqa UP007 + """TODO: Docstring?""" async with self.stream(ctx): pass @@ -332,6 +334,7 @@ async def run( async def stream( self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]] ) -> AsyncIterator[AsyncIterator[_messages.HandleResponseEvent]]: + """TODO: Docstring.""" with _logfire.span('handle model response', run_step=ctx.state.run_step) as handle_span: stream = self._run_stream(ctx) yield stream @@ -398,7 +401,7 @@ async def _handle_tool_calls( result_schema = ctx.deps.result_schema # first look for the result tool call - final_result: FinalResult[NodeRunEndT] | None = None + final_result: result.FinalResult[NodeRunEndT] | None = None parts: list[_messages.ModelRequestPart] = [] if result_schema is not None: if match := result_schema.find_tool(tool_calls): @@ -412,7 +415,7 @@ async def _handle_tool_calls( ctx.state.increment_retries(ctx.deps.max_result_retries) parts.append(e.tool_retry) else: - final_result = FinalResult(result_data, call.tool_name) + final_result = result.FinalResult(result_data, call.tool_name) # Then build the other request parts based on end strategy tool_responses: list[_messages.ModelRequestPart] = self._tool_responses @@ -431,9 +434,9 @@ async def _handle_tool_calls( def _handle_final_result( self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], - final_result: FinalResult[NodeRunEndT], + final_result: result.FinalResult[NodeRunEndT], tool_responses: list[_messages.ModelRequestPart], - ) -> End[FinalResult[NodeRunEndT]]: + ) -> End[result.FinalResult[NodeRunEndT]]: run_span = ctx.deps.run_span usage = ctx.state.usage messages = ctx.state.message_history @@ -452,7 +455,7 @@ async def _handle_text_response( self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], texts: list[str], - ) -> ModelRequestNode[DepsT, NodeRunEndT] | End[FinalResult[NodeRunEndT]]: + ) -> ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]]: result_schema = ctx.deps.result_schema text = '\n\n'.join(texts) @@ -465,7 +468,7 @@ async def _handle_text_response( return ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[e.tool_retry])) else: # The following cast is safe because we know `str` is an allowed result type - return self._handle_final_result(ctx, FinalResult(result_data, tool_name=None), []) + return self._handle_final_result(ctx, result.FinalResult(result_data, tool_name=None), []) else: ctx.state.increment_retries(ctx.deps.max_result_retries) return ModelRequestNode[DepsT, NodeRunEndT]( @@ -480,6 +483,7 @@ async def _handle_text_response( def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]) -> RunContext[DepsT]: + """TODO: Docstring.""" return RunContext[DepsT]( deps=ctx.deps.user_deps, model=ctx.deps.model, @@ -496,11 +500,11 @@ async def process_function_tools( ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], output_parts: list[_messages.ModelRequestPart], ) -> AsyncIterator[_messages.HandleResponseEvent]: - """Process function (non-result) tool calls in parallel. + """Process function (i.e., non-result) tool calls in parallel. Also add stub return parts for any other tools that need it. - Because async iterators can't have return values, we use `parts` as an output argument. + Because async iterators can't have return values, we use `output_parts` as an output argument. """ stub_function_tools = bool(result_tool_name) and ctx.deps.end_strategy == 'early' result_schema = ctx.deps.result_schema @@ -603,6 +607,7 @@ async def _validate_result( def allow_text_result(result_schema: _result.ResultSchema[Any] | None) -> bool: + """TODO: Docstring.""" return result_schema is None or result_schema.allow_text_result @@ -656,18 +661,19 @@ def get_captured_run_messages() -> _RunMessages: def build_agent_graph( name: str | None, deps_type: type[DepsT], result_type: type[ResultT] -) -> Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], FinalResult[ResultT]]: +) -> Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[ResultT]]: + """TODO: Docstring.""" # We'll define the known node classes: nodes = ( UserPromptNode[DepsT], ModelRequestNode[DepsT], HandleResponseNode[DepsT], ) - graph = Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], FinalResult[ResultT]]( + graph = Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[ResultT]]( nodes=nodes, name=name or 'Agent', state_type=GraphAgentState, - run_end_type=FinalResult[result_type], + run_end_type=result.FinalResult[result_type], auto_instrument=False, ) return graph diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 77c032219..fb6922098 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -12,7 +12,7 @@ import logfire_api from typing_extensions import TypeVar, deprecated -from pydantic_graph import BaseNode, End, Graph, GraphRun, GraphRunContext, GraphRuntimeError +from pydantic_graph import BaseNode, End, Graph, GraphRun, GraphRunContext from pydantic_graph.graph import GraphRunResult from . import ( @@ -245,7 +245,10 @@ async def run( usage: _usage.Usage | None = None, infer_name: bool = True, ) -> AgentRunResult[AgentDepsT, Any]: - """Run the agent with a user prompt. + """Run the agent with a user prompt in async mode. + + This method builds an internal agent graph (using system prompts, tools and result schemas) and then + runs the graph to completion. The result of the run is returned. Example: ```python @@ -260,9 +263,9 @@ async def main(): ``` Args: + user_prompt: User input to start/continue the conversation. result_type: Custom result type to use for this run, `result_type` may only be used if the agent has no result validators since result validators would expect an argument that matches the agent's result type. - user_prompt: User input to start/continue the conversation. message_history: History of the conversation so far. model: Optional model to use for this run, required if `model` was not set when creating the agent. deps: Optional dependencies to use for this run. @@ -288,7 +291,9 @@ async def main(): ) as agent_run: async for _ in agent_run: pass - return agent_run.final_result + final_result = agent_run.final_result + assert final_result is not None, 'The graph run should have ended with a final result' + return final_result @contextmanager def iter( @@ -304,17 +309,15 @@ def iter( usage: _usage.Usage | None = None, infer_name: bool = True, ) -> Iterator[AgentRun[AgentDepsT, Any]]: - """Get an AgentRun for the agent with a user prompt which can be iterated over. + """A contextmanager which can be used to iterate over the agent graph's nodes as they are executed. - This method builds an internal agent graph (using system prompts, tools and result schemas) and then - returns an `AgentRun` object. The `AgentRun` functions as a handle that can be used to iterate over the graph - and obtain the final result. The AgentRun also provides methods to access the full message history, - new messages, and usage statistics. + This method builds an internal agent graph (using system prompts, tools and result schemas) and then returns an + `AgentRun` object. The `AgentRun` can be used to (async) iterate over the nodes of the graph as they are + executed. This is the API to use if you want to consume the outputs coming from each LLM model response, or the + stream of events coming from the execution of tools. - The returned `AgentRun` object can be async iterated over to get the nodes of the graph as they are executed.: - This is the API you should use if you want to consume the graph execution in a streaming manner, - or if you want to consume the stream of events coming from individual requests to the LLM, or the stream - of events coming from the execution of tools. + The `AgentRun` also provides methods to access the full message history, new messages, and usage statistics, + and the final result of the run once it has completed. For more details, see the documentation of `AgentRun`. @@ -360,9 +363,9 @@ async def main(): ``` Args: + user_prompt: User input to start/continue the conversation. result_type: Custom result type to use for this run, `result_type` may only be used if the agent has no result validators since result validators would expect an argument that matches the agent's result type. - user_prompt: User input to start/continue the conversation. message_history: History of the conversation so far. model: Optional model to use for this run, required if `model` was not set when creating the agent. deps: Optional dependencies to use for this run. @@ -487,7 +490,7 @@ def run_sync( usage: _usage.Usage | None = None, infer_name: bool = True, ) -> AgentRunResult[AgentDepsT, Any]: - """Run the agent with a user prompt synchronously. + """Synchronously run the agent with a user prompt. This is a convenience method that wraps [`self.run`][pydantic_ai.Agent.run] with `loop.run_until_complete(...)`. You therefore can't use this method inside async code or if there's an active event loop. @@ -504,9 +507,9 @@ def run_sync( ``` Args: + user_prompt: User input to start/continue the conversation. result_type: Custom result type to use for this run, `result_type` may only be used if the agent has no result validators since result validators would expect an argument that matches the agent's result type. - user_prompt: User input to start/continue the conversation. message_history: History of the conversation so far. model: Optional model to use for this run, required if `model` was not set when creating the agent. deps: Optional dependencies to use for this run. @@ -565,7 +568,7 @@ def run_stream( ) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, RunResultDataT]]: ... @asynccontextmanager - async def run_stream( # noqa + async def run_stream( # noqa C901 self, user_prompt: str, *, @@ -593,9 +596,9 @@ async def main(): ``` Args: + user_prompt: User input to start/continue the conversation. result_type: Custom result type to use for this run, `result_type` may only be used if the agent has no result validators since result validators would expect an argument that matches the agent's result type. - user_prompt: User input to start/continue the conversation. message_history: History of the conversation so far. model: Optional model to use for this run, required if `model` was not set when creating the agent. deps: Optional dependencies to use for this run. @@ -607,6 +610,8 @@ async def main(): Returns: The result of the run. """ + # TODO: We need to deprecate this now that we have the `iter` method. + # Before that, though, we should add an event for when we reach the final result of the stream. if infer_name and self.name is None: # f_back because `asynccontextmanager` adds one frame if frame := inspect.currentframe(): # pragma: no branch @@ -630,7 +635,7 @@ async def main(): while True: if isinstance(node, _agent_graph.ModelRequestNode): node = cast(_agent_graph.ModelRequestNode[AgentDepsT, Any], node) - graph_ctx = agent_run.get_graph_ctx() + graph_ctx = agent_run.ctx async with node.stream(graph_ctx) as streamed_response: async def stream_to_final( @@ -1156,15 +1161,28 @@ def _prepare_result_schema( @dataclasses.dataclass(repr=False) class AgentRun(Generic[AgentDepsT, ResultDataT]): - """A stateful, iterable run of an agent.""" + """A stateful, iterable run of an agent. + + TODO: Add API documentation here. + """ _graph_run: GraphRun[ _agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], FinalResult[ResultDataT] ] @property - def final_result(self) -> AgentRunResult[AgentDepsT, ResultDataT]: + def ctx(self) -> GraphRunContext[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any]]: + """The current context of the agent run.""" + return GraphRunContext[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any]]( + self._graph_run.state, self._graph_run.deps + ) + + @property + def final_result(self) -> AgentRunResult[AgentDepsT, ResultDataT] | None: + """The final result of the agent run.""" graph_run_result = self._graph_run.final_result + if graph_run_result is None: + return None return AgentRunResult(graph_run_result) def __aiter__( @@ -1217,20 +1235,10 @@ async def next( def usage(self) -> _usage.Usage: return self._graph_run.state.usage - def get_graph_ctx( - self, - ) -> GraphRunContext[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any]]: - return GraphRunContext[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any]]( - self._graph_run.state, self._graph_run.deps - ) - def __repr__(self): - try: - result_repr = repr(self._graph_run.final_result.result) - except GraphRuntimeError: - result_repr = '' - usage = self.usage() - kws = [f'result={result_repr}', f'usage={usage}'] + final_result = self._graph_run.final_result + result_repr = '' if final_result is None else repr(final_result.result) + kws = [f'result={result_repr}', f'usage={self.usage()}'] return '<{} {}>'.format(type(self).__name__, ' '.join(kws)) diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py index 38b45f7e8..edc9315c7 100644 --- a/pydantic_ai_slim/pydantic_ai/models/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py @@ -27,6 +27,7 @@ if TYPE_CHECKING: from ..tools import ToolDefinition + KnownModelName = Literal[ 'anthropic:claude-3-5-haiku-latest', 'anthropic:claude-3-5-sonnet-latest', @@ -262,16 +263,20 @@ def timestamp(self) -> datetime: raise NotImplementedError() async def stream_events(self) -> AsyncIterator[ModelResponseStreamEvent]: + """TODO: Docstring.""" return self.__aiter__() async def stream_debounced_events( self, *, debounce_by: float | None = 0.1 ) -> AsyncIterator[list[ModelResponseStreamEvent]]: + """TODO: Docstring.""" async with _utils.group_by_temporal(self, debounce_by) as group_iter: async for items in group_iter: yield items async def stream_structured(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[_messages.ModelResponse]: + """TODO: Docstring.""" + async def _stream_structured_ungrouped() -> AsyncIterator[None]: # yield None # TODO: Might want to yield right away to ensure we can eagerly emit a ModelResponse even if we are waiting async for _event in self: @@ -282,6 +287,8 @@ async def _stream_structured_ungrouped() -> AsyncIterator[None]: yield self.get() # current state of the response async def stream_text(self, *, delta: bool = False, debounce_by: float | None = 0.1) -> AsyncIterator[str]: + """TODO: Docstring.""" + # Define a "merged" version of the iterator that will yield items that have already been retrieved # and items that we receive while streaming. We define a dedicated async iterator for this so we can # pass the combined stream to the group_by_temporal function within `_stream_text_deltas` below. diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py index c569d1bb5..42135e4cc 100644 --- a/pydantic_ai_slim/pydantic_ai/result.py +++ b/pydantic_ai_slim/pydantic_ai/result.py @@ -219,7 +219,7 @@ async def stream_structured( msg = self._stream_response.get() yield msg, True - # TODO: Should this now be `final_response` instead of `structured_response`? + lf_span.set_attribute('structured_response', msg) await self._marked_completed(msg) @@ -231,7 +231,6 @@ async def get_data(self) -> ResultDataT: async for _ in usage_checking_stream: pass - message = self._stream_response.get() await self._marked_completed(message) return await self.validate_structured_result(message) @@ -293,17 +292,12 @@ async def _marked_completed(self, message: _messages.ModelResponse) -> None: @dataclass class FinalResult(Generic[ResultDataT]): - """Marker class to indicate that the result is the final result. - - This allows us to use `isinstance`, which wouldn't be possible if we were returning `ResultDataT` directly. - - It also avoids problems in the case where the result type is itself `None`, but is set. - """ + """Marker class storing the final result of an agent run and associated metadata.""" data: ResultDataT """The final result data.""" tool_name: str | None - """Name of the final result tool, None if the result is a string.""" + """Name of the final result tool; `None` if the result came from unstructured text content.""" def _get_usage_checking_stream_response( diff --git a/pydantic_graph/pydantic_graph/graph.py b/pydantic_graph/pydantic_graph/graph.py index 7ea686180..db558cd3a 100644 --- a/pydantic_graph/pydantic_graph/graph.py +++ b/pydantic_graph/pydantic_graph/graph.py @@ -143,10 +143,11 @@ async def run( state: The initial state of the graph. deps: The dependencies of the graph. infer_name: Whether to infer the graph name from the calling frame. - span: The span to use for the graph run. If not provided, a new span will be created. + span: The span to use for the graph run. If not provided, a span will be created depending on the value of + the `_auto_instrument` field. Returns: - The result type from ending the run and the history of the run. + A `GraphRunResult` containing information about the run, including its final result. Here's an example of running the graph from [above][pydantic_graph.graph.Graph]: @@ -155,17 +156,17 @@ async def run( async def main(): state = MyState(1) - graph_run = await never_42_graph.run(Increment(), state=state) + graph_run_result = await never_42_graph.run(Increment(), state=state) print(state) #> MyState(number=2) - print(len(graph_run.history)) + print(len(graph_run_result.history)) #> 3 state = MyState(41) - graph_run = await never_42_graph.run(Increment(), state=state) + graph_run_result = await never_42_graph.run(Increment(), state=state) print(state) #> MyState(number=43) - print(len(graph_run.history)) + print(len(graph_run_result.history)) #> 5 ``` """ @@ -176,7 +177,9 @@ async def main(): async for _node in graph_run: pass - return graph_run.final_result + final_result = graph_run.final_result + assert final_result is not None, 'GraphRun should have a final result' + return final_result @contextmanager def iter( @@ -188,10 +191,19 @@ def iter( infer_name: bool = True, span: LogfireSpan | None = None, ) -> Iterator[GraphRun[StateT, DepsT, T]]: - """A contextmanager that yields a GraphRun that can be async iterated over to drive the graph to completion. + """A contextmanager which can be used to iterate over the graph's nodes as they are executed. + + This method returns a `GraphRun` object which can be used to (async) iterate over the nodes of this `Graph` as + they are executed. This is the API to use if you want to record or interact with the nodes as the graph + execution unfolds. + + The `GraphRun` provides access to the full run history, state, deps, and the final result of the run once + it has completed. + + For more details, see the documentation of `GraphRun`. Args: - start_node: the first node to run, since the graph definition doesn't define the entry point in the graph, + start_node: the first node to run. Since the graph definition doesn't define the entry point in the graph, you need to provide the starting node. state: The initial state of the graph. deps: The dependencies of the graph. @@ -201,8 +213,32 @@ def iter( Yields: A GraphRun that can be async iterated over to drive the graph to completion. - Here's an example of running the graph from [above][pydantic_graph.graph.Graph]: - # TODO: Need to add an example here akin to that from `Graph.run` above + Here's an example of iterating over the graph from [above][pydantic_graph.graph.Graph]: + + ```py {title="run_never_42.py" noqa="I001" py="3.10"} + from never_42 import Increment, MyState, never_42_graph + + async def main(): + state = MyState(1) + nodes = [] + with never_42_graph.iter(Increment(), state=state) as graph_run: + async for node in graph_run: + nodes.append(node) + print(nodes) + #> [Check42(), End(data=2)] + print(state) + #> MyState(number=2) + + state = MyState(41) + nodes = [] + with never_42_graph.iter(Increment(), state=state) as graph_run: + async for node in graph_run: + nodes.append(node) + print(nodes) + #> [Check42(), Increment(), Check42(), End(data=43)] + print(state) + #> MyState(number=43) + ``` """ if infer_name and self.name is None: self._infer_name(inspect.currentframe()) @@ -231,7 +267,7 @@ def run_sync( deps: DepsT = None, infer_name: bool = True, ) -> GraphRunResult[StateT, DepsT, T]: - """Run the graph synchronously. + """Synchronously run the graph. This is a convenience method that wraps [`self.run`][pydantic_graph.Graph.run] with `loop.run_until_complete(...)`. You therefore can't use this method inside async code or if there's an active event loop. @@ -553,7 +589,7 @@ class GraphRun(Generic[StateT, DepsT, RunEndT]): After being entered, can be used like an async generator to listen to / modify nodes as the run is executed. - TODO: this requires some heavy weight API documentation. + TODO: this requires some heavy weight API documentation. Docstring """ def __init__( @@ -577,9 +613,10 @@ def __init__( self._next_node: BaseNode[StateT, DepsT, RunEndT] | End[RunEndT] = start_node @property - def final_result(self) -> GraphRunResult[StateT, DepsT, RunEndT]: + def final_result(self) -> GraphRunResult[StateT, DepsT, RunEndT] | None: + """The final result of the agent run if the run is completed, otherwise `None`.""" if not isinstance(self._next_node, End): - raise exceptions.GraphRuntimeError('This GraphRun has not finished running.') + return None # The GraphRun has not finished running return GraphRunResult( self._next_node.data, graph=self.graph, history=self.history, state=self.state, deps=self.deps ) diff --git a/pydantic_graph/pydantic_graph/nodes.py b/pydantic_graph/pydantic_graph/nodes.py index c50a63c21..f28106c97 100644 --- a/pydantic_graph/pydantic_graph/nodes.py +++ b/pydantic_graph/pydantic_graph/nodes.py @@ -28,7 +28,7 @@ class GraphRunContext(Generic[StateT, DepsT]): """Context for a graph.""" - # TODO: It would be nice to get rid of this struct and just pass both these things around... + # TODO: Can we get rid of this struct and just pass both these things around..? state: StateT """The state of the graph.""" diff --git a/pyproject.toml b/pyproject.toml index fb546914f..6782c9092 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -193,4 +193,4 @@ skip = '.git*,*.svg,*.lock,*.css' check-hidden = true # Ignore "formatting" like **L**anguage ignore-regex = '\*\*[A-Z]\*\*[a-z]+\b' - ignore-words-list = 'asend' +ignore-words-list = 'asend' diff --git a/tests/test_agent.py b/tests/test_agent.py index a478b5633..a72cf8499 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -565,7 +565,6 @@ async def ret_a(x: str) -> str: assert result2.usage() == snapshot( Usage(requests=1, request_tokens=55, response_tokens=13, total_tokens=68, details=None) ) - new_msg_part_kinds = [(m.kind, [p.part_kind for p in m.parts]) for m in result2.all_messages()] assert new_msg_part_kinds == snapshot( [ @@ -1340,15 +1339,17 @@ async def func() -> str: [ ModelRequest( parts=[ - SystemPromptPart(content='Foobar'), - SystemPromptPart(content=dynamic_value), - UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)), + SystemPromptPart(content='Foobar', part_kind='system-prompt'), + SystemPromptPart(content=dynamic_value, part_kind='system-prompt'), + UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt'), ], + kind='request', ), ModelResponse( - parts=[TextPart(content='success (no tool calls)')], + parts=[TextPart(content='success (no tool calls)', part_kind='text')], model_name='test', timestamp=IsNow(tz=timezone.utc), + kind='response', ), ] ) @@ -1361,25 +1362,30 @@ async def func() -> str: [ ModelRequest( parts=[ - SystemPromptPart(content='Foobar'), + SystemPromptPart(content='Foobar', part_kind='system-prompt'), SystemPromptPart( content='A', # Remains the same + part_kind='system-prompt', ), - UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)), + UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt'), ], + kind='request', ), ModelResponse( - parts=[TextPart(content='success (no tool calls)')], + parts=[TextPart(content='success (no tool calls)', part_kind='text')], model_name='test', timestamp=IsNow(tz=timezone.utc), + kind='response', ), ModelRequest( - parts=[UserPromptPart(content='World', timestamp=IsNow(tz=timezone.utc))], + parts=[UserPromptPart(content='World', timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt')], + kind='request', ), ModelResponse( - parts=[TextPart(content='success (no tool calls)')], + parts=[TextPart(content='success (no tool calls)', part_kind='text')], model_name='test', timestamp=IsNow(tz=timezone.utc), + kind='response', ), ] ) @@ -1405,18 +1411,21 @@ async def func(): [ ModelRequest( parts=[ - SystemPromptPart(content='Foobar'), + SystemPromptPart(content='Foobar', part_kind='system-prompt'), SystemPromptPart( content=dynamic_value, + part_kind='system-prompt', dynamic_ref=func.__qualname__, ), - UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)), + UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt'), ], + kind='request', ), ModelResponse( - parts=[TextPart(content='success (no tool calls)')], + parts=[TextPart(content='success (no tool calls)', part_kind='text')], model_name='test', timestamp=IsNow(tz=timezone.utc), + kind='response', ), ] ) @@ -1429,26 +1438,31 @@ async def func(): [ ModelRequest( parts=[ - SystemPromptPart(content='Foobar'), + SystemPromptPart(content='Foobar', part_kind='system-prompt'), SystemPromptPart( content='B', + part_kind='system-prompt', dynamic_ref=func.__qualname__, ), - UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)), + UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt'), ], + kind='request', ), ModelResponse( - parts=[TextPart(content='success (no tool calls)')], + parts=[TextPart(content='success (no tool calls)', part_kind='text')], model_name='test', timestamp=IsNow(tz=timezone.utc), + kind='response', ), ModelRequest( - parts=[UserPromptPart(content='World', timestamp=IsNow(tz=timezone.utc))], + parts=[UserPromptPart(content='World', timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt')], + kind='request', ), ModelResponse( - parts=[TextPart(content='success (no tool calls)')], + parts=[TextPart(content='success (no tool calls)', part_kind='text')], model_name='test', timestamp=IsNow(tz=timezone.utc), + kind='response', ), ] ) diff --git a/tests/test_parts_manager.py b/tests/test_parts_manager.py index 13cb81235..a93e7fb31 100644 --- a/tests/test_parts_manager.py +++ b/tests/test_parts_manager.py @@ -24,32 +24,56 @@ def test_handle_text_deltas(vendor_part_id: str | None): assert manager.get_parts() == [] event = manager.handle_text_delta(vendor_part_id=vendor_part_id, content='hello ') - assert event == snapshot(PartStartEvent(index=0, part=TextPart(content='hello '))) - assert manager.get_parts() == snapshot([TextPart(content='hello ')]) + assert event == snapshot( + PartStartEvent(index=0, part=TextPart(content='hello ', part_kind='text'), event_kind='part_start') + ) + assert manager.get_parts() == snapshot([TextPart(content='hello ', part_kind='text')]) event = manager.handle_text_delta(vendor_part_id=vendor_part_id, content='world') - assert event == snapshot(PartDeltaEvent(index=0, delta=TextPartDelta(content_delta='world'))) - assert manager.get_parts() == snapshot([TextPart(content='hello world')]) + assert event == snapshot( + PartDeltaEvent( + index=0, delta=TextPartDelta(content_delta='world', part_delta_kind='text'), event_kind='part_delta' + ) + ) + assert manager.get_parts() == snapshot([TextPart(content='hello world', part_kind='text')]) def test_handle_dovetailed_text_deltas(): manager = ModelResponsePartsManager() event = manager.handle_text_delta(vendor_part_id='first', content='hello ') - assert event == snapshot(PartStartEvent(index=0, part=TextPart(content='hello '))) - assert manager.get_parts() == snapshot([TextPart(content='hello ')]) + assert event == snapshot( + PartStartEvent(index=0, part=TextPart(content='hello ', part_kind='text'), event_kind='part_start') + ) + assert manager.get_parts() == snapshot([TextPart(content='hello ', part_kind='text')]) event = manager.handle_text_delta(vendor_part_id='second', content='goodbye ') - assert event == snapshot(PartStartEvent(index=1, part=TextPart(content='goodbye '))) - assert manager.get_parts() == snapshot([TextPart(content='hello '), TextPart(content='goodbye ')]) + assert event == snapshot( + PartStartEvent(index=1, part=TextPart(content='goodbye ', part_kind='text'), event_kind='part_start') + ) + assert manager.get_parts() == snapshot( + [TextPart(content='hello ', part_kind='text'), TextPart(content='goodbye ', part_kind='text')] + ) event = manager.handle_text_delta(vendor_part_id='first', content='world') - assert event == snapshot(PartDeltaEvent(index=0, delta=TextPartDelta(content_delta='world'))) - assert manager.get_parts() == snapshot([TextPart(content='hello world'), TextPart(content='goodbye ')]) + assert event == snapshot( + PartDeltaEvent( + index=0, delta=TextPartDelta(content_delta='world', part_delta_kind='text'), event_kind='part_delta' + ) + ) + assert manager.get_parts() == snapshot( + [TextPart(content='hello world', part_kind='text'), TextPart(content='goodbye ', part_kind='text')] + ) event = manager.handle_text_delta(vendor_part_id='second', content='Samuel') - assert event == snapshot(PartDeltaEvent(index=1, delta=TextPartDelta(content_delta='Samuel'))) - assert manager.get_parts() == snapshot([TextPart(content='hello world'), TextPart(content='goodbye Samuel')]) + assert event == snapshot( + PartDeltaEvent( + index=1, delta=TextPartDelta(content_delta='Samuel', part_delta_kind='text'), event_kind='part_delta' + ) + ) + assert manager.get_parts() == snapshot( + [TextPart(content='hello world', part_kind='text'), TextPart(content='goodbye Samuel', part_kind='text')] + ) def test_handle_tool_call_deltas(): @@ -65,33 +89,36 @@ def test_handle_tool_call_deltas(): assert event == snapshot( PartStartEvent( index=0, - part=ToolCallPart(tool_name='tool', args='{"arg1":', tool_call_id=None), + part=ToolCallPart(tool_name='tool', args='{"arg1":', tool_call_id=None, part_kind='tool-call'), + event_kind='part_start', ) ) - assert manager.get_parts() == snapshot([ToolCallPart(tool_name='tool', args='{"arg1":', tool_call_id=None)]) + assert manager.get_parts() == snapshot( + [ToolCallPart(tool_name='tool', args='{"arg1":', tool_call_id=None, part_kind='tool-call')] + ) event = manager.handle_tool_call_delta(vendor_part_id='first', tool_name='1', args=None, tool_call_id=None) assert event == snapshot( PartDeltaEvent( index=0, delta=ToolCallPartDelta( - tool_name_delta='1', - args_delta=None, - tool_call_id=None, + tool_name_delta='1', args_delta=None, tool_call_id=None, part_delta_kind='tool_call' ), + event_kind='part_delta', ) ) - assert manager.get_parts() == snapshot([ToolCallPart(tool_name='tool1', args='{"arg1":', tool_call_id=None)]) + assert manager.get_parts() == snapshot( + [ToolCallPart(tool_name='tool1', args='{"arg1":', tool_call_id=None, part_kind='tool-call')] + ) event = manager.handle_tool_call_delta(vendor_part_id='first', tool_name=None, args='"value1"}', tool_call_id=None) assert event == snapshot( PartDeltaEvent( index=0, delta=ToolCallPartDelta( - tool_name_delta=None, - args_delta='"value1"}', - tool_call_id=None, + tool_name_delta=None, args_delta='"value1"}', tool_call_id=None, part_delta_kind='tool_call' ), + event_kind='part_delta', ) ) assert manager.get_parts() == snapshot( @@ -100,6 +127,7 @@ def test_handle_tool_call_deltas(): tool_name='tool1', args='{"arg1":"value1"}', tool_call_id=None, + part_kind='tool-call', ) ] ) @@ -116,6 +144,7 @@ def test_handle_tool_call_deltas_without_vendor_id(): tool_name='tool1', args='{"arg1":"value1"}', tool_call_id=None, + part_kind='tool-call', ) ] ) @@ -126,8 +155,8 @@ def test_handle_tool_call_deltas_without_vendor_id(): manager.handle_tool_call_delta(vendor_part_id=None, tool_name='tool2', args='"value1"}', tool_call_id=None) assert manager.get_parts() == snapshot( [ - ToolCallPart(tool_name='tool2', args='{"arg1":', tool_call_id=None), - ToolCallPart(tool_name='tool2', args='"value1"}', tool_call_id=None), + ToolCallPart(tool_name='tool2', args='{"arg1":', tool_call_id=None, part_kind='tool-call'), + ToolCallPart(tool_name='tool2', args='"value1"}', tool_call_id=None, part_kind='tool-call'), ] ) @@ -140,20 +169,23 @@ def test_handle_tool_call_part(): assert event == snapshot( PartStartEvent( index=0, - part=ToolCallPart(tool_name='tool1', args='{"arg1":', tool_call_id=None), + part=ToolCallPart(tool_name='tool1', args='{"arg1":', tool_call_id=None, part_kind='tool-call'), + event_kind='part_start', ) ) # Add a delta manager.handle_tool_call_delta(vendor_part_id='second', tool_name='tool1', args=None, tool_call_id=None) - assert manager.get_parts() == snapshot([ToolCallPart(tool_name='tool1', args='{"arg1":', tool_call_id=None)]) + assert manager.get_parts() == snapshot( + [ToolCallPart(tool_name='tool1', args='{"arg1":', tool_call_id=None, part_kind='tool-call')] + ) # Override it with handle_tool_call_part manager.handle_tool_call_part(vendor_part_id='second', tool_name='tool1', args='{}', tool_call_id=None) assert manager.get_parts() == snapshot( [ - ToolCallPart(tool_name='tool1', args='{"arg1":', tool_call_id=None), - ToolCallPart(tool_name='tool1', args='{}', tool_call_id=None), + ToolCallPart(tool_name='tool1', args='{"arg1":', tool_call_id=None, part_kind='tool-call'), + ToolCallPart(tool_name='tool1', args='{}', tool_call_id=None, part_kind='tool-call'), ] ) @@ -162,10 +194,9 @@ def test_handle_tool_call_part(): PartDeltaEvent( index=0, delta=ToolCallPartDelta( - tool_name_delta=None, - args_delta='"value1"}', - tool_call_id=None, + tool_name_delta=None, args_delta='"value1"}', tool_call_id=None, part_delta_kind='tool_call' ), + event_kind='part_delta', ) ) assert manager.get_parts() == snapshot( @@ -174,8 +205,9 @@ def test_handle_tool_call_part(): tool_name='tool1', args='{"arg1":"value1"}', tool_call_id=None, + part_kind='tool-call', ), - ToolCallPart(tool_name='tool1', args='{}', tool_call_id=None), + ToolCallPart(tool_name='tool1', args='{}', tool_call_id=None, part_kind='tool-call'), ] ) @@ -184,7 +216,8 @@ def test_handle_tool_call_part(): assert event == snapshot( PartStartEvent( index=2, - part=ToolCallPart(tool_name='tool1', args='{}', tool_call_id=None), + part=ToolCallPart(tool_name='tool1', args='{}', tool_call_id=None, part_kind='tool-call'), + event_kind='part_start', ) ) assert manager.get_parts() == snapshot( @@ -193,9 +226,10 @@ def test_handle_tool_call_part(): tool_name='tool1', args='{"arg1":"value1"}', tool_call_id=None, + part_kind='tool-call', ), - ToolCallPart(tool_name='tool1', args='{}', tool_call_id=None), - ToolCallPart(tool_name='tool1', args='{}', tool_call_id=None), + ToolCallPart(tool_name='tool1', args='{}', tool_call_id=None, part_kind='tool-call'), + ToolCallPart(tool_name='tool1', args='{}', tool_call_id=None, part_kind='tool-call'), ] ) @@ -206,8 +240,10 @@ def test_handle_mixed_deltas_without_text_part_id(text_vendor_part_id: str | Non manager = ModelResponsePartsManager() event = manager.handle_text_delta(vendor_part_id=text_vendor_part_id, content='hello ') - assert event == snapshot(PartStartEvent(index=0, part=TextPart(content='hello '))) - assert manager.get_parts() == snapshot([TextPart(content='hello ')]) + assert event == snapshot( + PartStartEvent(index=0, part=TextPart(content='hello ', part_kind='text'), event_kind='part_start') + ) + assert manager.get_parts() == snapshot([TextPart(content='hello ', part_kind='text')]) event = manager.handle_tool_call_delta( vendor_part_id=tool_vendor_part_id, tool_name='tool1', args='{"arg1":', tool_call_id='abc' @@ -215,7 +251,8 @@ def test_handle_mixed_deltas_without_text_part_id(text_vendor_part_id: str | Non assert event == snapshot( PartStartEvent( index=1, - part=ToolCallPart(tool_name='tool1', args='{"arg1":', tool_call_id='abc'), + part=ToolCallPart(tool_name='tool1', args='{"arg1":', tool_call_id='abc', part_kind='tool-call'), + event_kind='part_start', ) ) @@ -224,22 +261,27 @@ def test_handle_mixed_deltas_without_text_part_id(text_vendor_part_id: str | Non assert event == snapshot( PartStartEvent( index=2, - part=TextPart(content='world'), + part=TextPart(content='world', part_kind='text'), + event_kind='part_start', ) ) assert manager.get_parts() == snapshot( [ - TextPart(content='hello '), - ToolCallPart(tool_name='tool1', args='{"arg1":', tool_call_id='abc'), - TextPart(content='world'), + TextPart(content='hello ', part_kind='text'), + ToolCallPart(tool_name='tool1', args='{"arg1":', tool_call_id='abc', part_kind='tool-call'), + TextPart(content='world', part_kind='text'), ] ) else: - assert event == snapshot(PartDeltaEvent(index=0, delta=TextPartDelta(content_delta='world'))) + assert event == snapshot( + PartDeltaEvent( + index=0, delta=TextPartDelta(content_delta='world', part_delta_kind='text'), event_kind='part_delta' + ) + ) assert manager.get_parts() == snapshot( [ - TextPart(content='hello world'), - ToolCallPart(tool_name='tool1', args='{"arg1":', tool_call_id='abc'), + TextPart(content='hello world', part_kind='text'), + ToolCallPart(tool_name='tool1', args='{"arg1":', tool_call_id='abc', part_kind='tool-call'), ] ) @@ -271,6 +313,7 @@ def test_tool_call_id_delta(): tool_name='tool1', args='{"arg1":', tool_call_id=None, + part_kind='tool-call', ) ] ) @@ -282,6 +325,7 @@ def test_tool_call_id_delta(): tool_name='tool1', args='{"arg1":"value1"}', tool_call_id='id2', + part_kind='tool-call', ) ] ) @@ -302,6 +346,7 @@ def test_tool_call_id_delta_failure(apply_to_delta: bool): tool_name='tool1', args='{"arg1":', tool_call_id='id1', + part_kind='tool-call', ) ] ) diff --git a/tests/test_streaming.py b/tests/test_streaming.py index 5d07bf4cf..f95be4c13 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -592,8 +592,10 @@ def another_tool(y: int) -> int: # pragma: no cover UserPromptPart( content='test early strategy with final ' 'result in middle', timestamp=IsNow(tz=datetime.timezone.utc), + part_kind='user-prompt', ) ], + kind='request', ), ModelResponse( parts=[ @@ -601,25 +603,30 @@ def another_tool(y: int) -> int: # pragma: no cover tool_name='regular_tool', args='{"x": 1}', tool_call_id=None, + part_kind='tool-call', ), ToolCallPart( tool_name='final_result', args='{"value": "final"}', tool_call_id=None, + part_kind='tool-call', ), ToolCallPart( tool_name='another_tool', args='{"y": 2}', tool_call_id=None, + part_kind='tool-call', ), ToolCallPart( tool_name='unknown_tool', args='{"value": "???"}', tool_call_id=None, + part_kind='tool-call', ), ], model_name='function:sf', timestamp=IsNow(tz=datetime.timezone.utc), + kind='response', ), ModelRequest( parts=[ @@ -628,18 +635,21 @@ def another_tool(y: int) -> int: # pragma: no cover content='Tool not executed - a final ' 'result was already processed.', tool_call_id=None, timestamp=IsNow(tz=datetime.timezone.utc), + part_kind='tool-return', ), ToolReturnPart( tool_name='final_result', content='Final result processed.', tool_call_id=None, timestamp=IsNow(tz=datetime.timezone.utc), + part_kind='tool-return', ), ToolReturnPart( tool_name='another_tool', content='Tool not executed - a final ' 'result was already processed.', tool_call_id=None, timestamp=IsNow(tz=datetime.timezone.utc), + part_kind='tool-return', ), RetryPromptPart( content='Unknown tool name: ' @@ -649,8 +659,10 @@ def another_tool(y: int) -> int: # pragma: no cover tool_name=None, tool_call_id=None, timestamp=IsNow(tz=datetime.timezone.utc), + part_kind='retry-prompt', ), ], + kind='request', ), ] ) From 8d2c74ea5f22af1429139f1805ffa87129460b5d Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Mon, 17 Feb 2025 21:32:41 -0700 Subject: [PATCH 17/28] Add more docstrings --- pydantic_ai_slim/pydantic_ai/_agent_graph.py | 10 ++++------ pydantic_ai_slim/pydantic_ai/_utils.py | 2 +- pydantic_ai_slim/pydantic_ai/agent.py | 7 +++++-- pydantic_ai_slim/pydantic_ai/models/__init__.py | 8 ++++---- pydantic_graph/pydantic_graph/graph.py | 7 ++++--- 5 files changed, 18 insertions(+), 16 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index d1185246c..52c970ddb 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -323,7 +323,6 @@ class HandleResponseNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], N async def run( self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]] ) -> Union[ModelRequestNode[DepsT, NodeRunEndT], End[result.FinalResult[NodeRunEndT]]]: # noqa UP007 - """TODO: Docstring?""" async with self.stream(ctx): pass @@ -334,7 +333,7 @@ async def run( async def stream( self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]] ) -> AsyncIterator[AsyncIterator[_messages.HandleResponseEvent]]: - """TODO: Docstring.""" + """Process the model response and yield events for the start and end of each function tool call.""" with _logfire.span('handle model response', run_step=ctx.state.run_step) as handle_span: stream = self._run_stream(ctx) yield stream @@ -483,7 +482,7 @@ async def _handle_text_response( def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]) -> RunContext[DepsT]: - """TODO: Docstring.""" + """Build a `RunContext` object from the current agent graph run context.""" return RunContext[DepsT]( deps=ctx.deps.user_deps, model=ctx.deps.model, @@ -607,7 +606,7 @@ async def _validate_result( def allow_text_result(result_schema: _result.ResultSchema[Any] | None) -> bool: - """TODO: Docstring.""" + """Check if the result schema allows text results.""" return result_schema is None or result_schema.allow_text_result @@ -662,8 +661,7 @@ def get_captured_run_messages() -> _RunMessages: def build_agent_graph( name: str | None, deps_type: type[DepsT], result_type: type[ResultT] ) -> Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[ResultT]]: - """TODO: Docstring.""" - # We'll define the known node classes: + """Build the execution [Graph][pydantic_graph.Graph] for a given agent.""" nodes = ( UserPromptNode[DepsT], ModelRequestNode[DepsT], diff --git a/pydantic_ai_slim/pydantic_ai/_utils.py b/pydantic_ai_slim/pydantic_ai/_utils.py index b2e01b1af..667727306 100644 --- a/pydantic_ai_slim/pydantic_ai/_utils.py +++ b/pydantic_ai_slim/pydantic_ai/_utils.py @@ -85,7 +85,7 @@ async def group_by_temporal( ) -> AsyncIterator[AsyncIterable[list[T]]]: """Group items from an async iterable into lists based on time interval between them. - Effectively debouncing the iterator. + Effectively, this debounces the iterator. This returns a context manager usable as an iterator so any pending tasks can be cancelled if an error occurs during iteration. diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index fb6922098..18df177aa 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -1161,9 +1161,12 @@ def _prepare_result_schema( @dataclasses.dataclass(repr=False) class AgentRun(Generic[AgentDepsT, ResultDataT]): - """A stateful, iterable run of an agent. + """A stateful, (async) iterable run of an agent. - TODO: Add API documentation here. + You can use `async for` to iterate over the nodes without any modification to the run, or you can use the `next()` + method to iteratively drive the run with the ability to manipulate the node at any point before continuing on. + + TODO: Add API documentation here. Docstring. """ _graph_run: GraphRun[ diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py index edc9315c7..c25f39b7e 100644 --- a/pydantic_ai_slim/pydantic_ai/models/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py @@ -263,19 +263,19 @@ def timestamp(self) -> datetime: raise NotImplementedError() async def stream_events(self) -> AsyncIterator[ModelResponseStreamEvent]: - """TODO: Docstring.""" + """Stream the response as an async iterable of [`ModelResponseStreamEvent`][pydantic_ai.messages.ModelResponseStreamEvent]s.""" return self.__aiter__() async def stream_debounced_events( self, *, debounce_by: float | None = 0.1 ) -> AsyncIterator[list[ModelResponseStreamEvent]]: - """TODO: Docstring.""" + """Stream the response as an async iterable of debounced lists of [`ModelResponseStreamEvent`][pydantic_ai.messages.ModelResponseStreamEvent]s.""" async with _utils.group_by_temporal(self, debounce_by) as group_iter: async for items in group_iter: yield items async def stream_structured(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[_messages.ModelResponse]: - """TODO: Docstring.""" + """Stream the response as an async iterable of [`ModelResponse`][pydantic_ai.messages.ModelResponse]s.""" async def _stream_structured_ungrouped() -> AsyncIterator[None]: # yield None # TODO: Might want to yield right away to ensure we can eagerly emit a ModelResponse even if we are waiting @@ -287,7 +287,7 @@ async def _stream_structured_ungrouped() -> AsyncIterator[None]: yield self.get() # current state of the response async def stream_text(self, *, delta: bool = False, debounce_by: float | None = 0.1) -> AsyncIterator[str]: - """TODO: Docstring.""" + """Stream the response as an async iterable of text.""" # Define a "merged" version of the iterator that will yield items that have already been retrieved # and items that we receive while streaming. We define a dedicated async iterator for this so we can diff --git a/pydantic_graph/pydantic_graph/graph.py b/pydantic_graph/pydantic_graph/graph.py index db558cd3a..59bbe7f90 100644 --- a/pydantic_graph/pydantic_graph/graph.py +++ b/pydantic_graph/pydantic_graph/graph.py @@ -585,11 +585,12 @@ def _infer_name(self, function_frame: types.FrameType | None) -> None: class GraphRun(Generic[StateT, DepsT, RunEndT]): - """A stateful, iterable run of a graph. + """A stateful, (async) iterable run of a graph. - After being entered, can be used like an async generator to listen to / modify nodes as the run is executed. + You can use `async for` to iterate over the nodes without any modification to the run, or you can use the `next()` + method to iteratively drive the run with the ability to manipulate the node at any point before continuing on. - TODO: this requires some heavy weight API documentation. Docstring + TODO: Add API documentation here. Docstring. """ def __init__( From 4bb67a550a2c384ff0e32187bf4416fa5ad8b7fe Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Tue, 18 Feb 2025 00:18:45 -0700 Subject: [PATCH 18/28] Add more docs --- Makefile | 2 +- docs/agents.md | 132 ++++++++++++++++++++++-- docs/graph.md | 136 +++++++++++++++++++++---- pydantic_ai_slim/pydantic_ai/agent.py | 120 +++++++++++++++++++--- pydantic_graph/pydantic_graph/graph.py | 131 +++++++++++++++++------- 5 files changed, 444 insertions(+), 77 deletions(-) diff --git a/Makefile b/Makefile index 7b12e2a9f..4c6c75a4e 100644 --- a/Makefile +++ b/Makefile @@ -64,7 +64,7 @@ testcov: test ## Run tests and generate a coverage report .PHONY: update-examples update-examples: ## Update documentation examples - uv run -m pytest --update-examples + uv run -m pytest --update-examples tests/test_examples.py # `--no-strict` so you can build the docs without insiders packages .PHONY: docs diff --git a/docs/agents.md b/docs/agents.md index 19da9f3a2..8294a4dd6 100644 --- a/docs/agents.md +++ b/docs/agents.md @@ -62,13 +62,14 @@ print(result.data) ## Running Agents -There are three ways to run an agent: +There are four ways to run an agent: -1. [`agent.run()`][pydantic_ai.Agent.run] — a coroutine which returns a [`RunResult`][pydantic_ai.result.RunResult] containing a completed response -2. [`agent.run_sync()`][pydantic_ai.Agent.run_sync] — a plain, synchronous function which returns a [`RunResult`][pydantic_ai.result.RunResult] containing a completed response (internally, this just calls `loop.run_until_complete(self.run())`) -3. [`agent.run_stream()`][pydantic_ai.Agent.run_stream] — a coroutine which returns a [`StreamedRunResult`][pydantic_ai.result.StreamedRunResult], which contains methods to stream a response as an async iterable +1. [`agent.run()`][pydantic_ai.Agent.run] — a coroutine which returns a [`RunResult`][pydantic_ai.result.RunResult] containing a completed response. +2. [`agent.run_sync()`][pydantic_ai.Agent.run_sync] — a plain, synchronous function which returns a [`RunResult`][pydantic_ai.result.RunResult] containing a completed response (internally, this just calls `loop.run_until_complete(self.run())`). +3. [`agent.run_stream()`][pydantic_ai.Agent.run_stream] — a coroutine which returns a [`StreamedRunResult`][pydantic_ai.result.StreamedRunResult], which contains methods to stream a response as an async iterable. +4. [`agent.iter()`][pydantic_ai.Agent.iter] — a context manager which returns an [`AgentRun`][pydantic_ai.agent.AgentRun], an async-iterable over the nodes of the agent's graph. -Here's a simple example demonstrating all three: +Here's a simple example demonstrating the first three: ```python {title="run_agent.py"} from pydantic_ai import Agent @@ -93,6 +94,125 @@ _(This example is complete, it can be run "as is" — you'll need to add `asynci You can also pass messages from previous runs to continue a conversation or provide context, as described in [Messages and Chat History](message-history.md). +--- + +### Iterating Over an Agent's Graph + +In more advanced scenarios, you may want to inspect or manipulate the agent's workflow as it runs. For example, you may want to collect data at each step of the run or manually decide how to proceed based on the node returned. In these situations, you can use the [`Agent.iter`][pydantic_ai.Agent.iter] method, a context manager which returns an [`AgentRun`][pydantic_ai.agent.AgentRun]. + +#### `async for` iteration + +Here's an example of using `async for` with `iter` to record each node the agent executes: + +```python +from pydantic_ai import Agent + +agent = Agent('openai:gpt-4o') + + +async def main(): + nodes = [] + # Begin an AgentRun, which is an async-iterable over the nodes of the agent's graph + with agent.iter('What is the capital of France?') as agent_run: + async for node in agent_run: + # Each node represents a step in the agent's execution + nodes.append(node) + print(nodes) + """ + [ + ModelRequestNode( + request=ModelRequest( + parts=[ + UserPromptPart( + content='What is the capital of France?', + timestamp=datetime.datetime(...), + part_kind='user-prompt', + ) + ], + kind='request', + ) + ), + HandleResponseNode( + model_response=ModelResponse( + parts=[TextPart(content='Paris', part_kind='text')], + model_name='function:model_logic', + timestamp=datetime.datetime(...), + kind='response', + ) + ), + End(data=FinalResult(data='Paris', tool_name=None)), + ] + """ + print(agent_run.final_result.data) + #> Paris +``` + +- The `AgentRun` is an async iterator that yields each node (`BaseNode` or `End`) in the flow. +- The run ends when an `End` node is returned. + +#### Using `.next(...)` manually + +You can also drive the iteration manually by passing the node you want to run next to the `AgentRun.next(...)` method. This allows you to inspect or modify the node before it executes or skip nodes based on your own logic: + +```python +from pydantic_ai import Agent +from pydantic_graph import End + +agent = Agent('openai:gpt-4o') + + +async def main(): + with agent.iter('What is the capital of France?') as agent_run: + # You can get the first node by calling __anext__ once + node = await agent_run.__anext__() + + # Keep track of nodes here + all_nodes = [node] + + # Drive the iteration manually + while not isinstance(node, End): + # You could inspect or mutate the node here as needed + node = await agent_run.next(node) + all_nodes.append(node) + + print(all_nodes) + """ + [ + ModelRequestNode( + request=ModelRequest( + parts=[ + UserPromptPart( + content='What is the capital of France?', + timestamp=datetime.datetime(...), + part_kind='user-prompt', + ) + ], + kind='request', + ) + ), + HandleResponseNode( + model_response=ModelResponse( + parts=[TextPart(content='Paris', part_kind='text')], + model_name='function:model_logic', + timestamp=datetime.datetime(...), + kind='response', + ) + ), + End(data=FinalResult(data='Paris', tool_name=None)), + ] + """ +``` + +- When you call `await agent_run.next(node)`, it executes that node in the agent's graph, updates the run's history, and returns the *next* node to run. +- The agent run is finished once an `End` node has been produced; instances of `End` cannot be passed to `next`. + +#### Accessing usage and the final result + +You can retrieve usage statistics (tokens, requests, etc.) at any time from the [`AgentRun`](pydantic_ai.agent.AgentRun) object via `agent_run.usage()`. This method returns a [`Usage`](pydantic_ai.usage.Usage) object containing the usage data. + +Once the run finishes, `agent_run.final_result` becomes a [`AgentRunResult`](pydantic_ai.agent.AgentRunResult) object containing the final output (and related metadata). + +--- ### Additional Configuration @@ -177,7 +297,7 @@ except UsageLimitExceeded as e: 2. This run will error after 3 requests, preventing the infinite tool calling. !!! note - This is especially relevant if you're registered a lot of tools, `request_limit` can be used to prevent the model from choosing to make too many of these calls. + This is especially relevant if you've registered many tools. The `request_limit` can be used to prevent the model from calling them in a loop too many times. #### Model (Run) Settings diff --git a/docs/graph.md b/docs/graph.md index 9db4cbffe..3ef996b85 100644 --- a/docs/graph.md +++ b/docs/graph.md @@ -16,12 +16,12 @@ Graphs and finite state machines (FSMs) are a powerful abstraction to model, exe Alongside PydanticAI, we've developed `pydantic-graph` — an async graph and state machine library for Python where nodes and edges are defined using type hints. -While this library is developed as part of PydanticAI; it has no dependency on `pydantic-ai` and can be considered as a pure graph-based state machine library. You may find it useful whether or not you're using PydanticAI or even building with GenAI. +While this library is developed as part of PydanticAI; it has no dependency on `pydantic-ai` and can be considered as a pure graph-based state machine library. You may find it useful whether or not you're using PydanticAI or even building with GenAI. -`pydantic-graph` is designed for advanced users and makes heavy use of Python generics and types hints. It is not designed to be as beginner-friendly as PydanticAI. +`pydantic-graph` is designed for advanced users and makes heavy use of Python generics and type hints. It is not designed to be as beginner-friendly as PydanticAI. !!! note "Very Early beta" - Graph support was [introduced](https://github.com/pydantic/pydantic-ai/pull/528) in v0.0.19 and is in very earlier beta. The API is subject to change. The documentation is incomplete. The implementation is incomplete. + Graph support was [introduced](https://github.com/pydantic/pydantic-ai/pull/528) in v0.0.19 and is in a very early beta. The API is subject to change. The documentation is incomplete. The implementation is incomplete. ## Installation @@ -33,7 +33,7 @@ pip/uv-add pydantic-graph ## Graph Types -`pydantic-graph` made up of a few key components: +`pydantic-graph` is made up of a few key components: ### GraphRunContext @@ -167,7 +167,7 @@ print([item.data_snapshot() for item in graph_run.history]) 1. The `DivisibleBy5` node is parameterized with `None` for the state param and `None` for the deps param as this graph doesn't use state or deps, and `int` as it can end the run. 2. The `Increment` node doesn't return `End`, so the `RunEndT` generic parameter is omitted, state can also be omitted as the graph doesn't use state. 3. The graph is created with a sequence of nodes. -4. The graph is run synchronously with [`run_sync`][pydantic_graph.graph.Graph.run_sync] the initial state `None` and the start node `DivisibleBy5(4)` are passed as arguments. +4. The graph is run synchronously with [`run_sync`][pydantic_graph.graph.Graph.run_sync]. The initial node is `DivisibleBy5(4)`. Because the graph doesn't use external state or deps, we don't pass `state` or `deps`. _(This example is complete, it can be run "as is" with Python 3.10+)_ @@ -295,17 +295,17 @@ async def main(): 2. A dictionary of products mapped to prices. 3. The `InsertCoin` node, [`BaseNode`][pydantic_graph.nodes.BaseNode] is parameterized with `MachineState` as that's the state used in this graph. 4. The `InsertCoin` node prompts the user to insert coins. We keep things simple by just entering a monetary amount as a float. Before you start thinking this is a toy too since it's using [rich's `Prompt.ask`][rich.prompt.PromptBase.ask] within nodes, see [below](#custom-control-flow) for how control flow can be managed when nodes require external input. -5. The `CoinsInserted` node; again this is a [`dataclass`][dataclasses.dataclass], in this case with one field `amount`, thus nodes calling `CoinsInserted` must provide an amount. +5. The `CoinsInserted` node; again this is a [`dataclass`][dataclasses.dataclass] with one field `amount`. 6. Update the user's balance with the amount inserted. 7. If the user has already selected a product, go to `Purchase`, otherwise go to `SelectProduct`. 8. In the `Purchase` node, look up the price of the product if the user entered a valid product. 9. If the user did enter a valid product, set the product in the state so we don't revisit `SelectProduct`. 10. If the balance is enough to purchase the product, adjust the balance to reflect the purchase and return [`End`][pydantic_graph.nodes.End] to end the graph. We're not using the run return type, so we call `End` with `None`. -11. If the balance is insufficient, to go `InsertCoin` to prompt the user to insert more coins. +11. If the balance is insufficient, go to `InsertCoin` to prompt the user to insert more coins. 12. If the product is invalid, go to `SelectProduct` to prompt the user to select a product again. -13. The graph is created by passing a list of nodes to [`Graph`][pydantic_graph.graph.Graph]. Order of nodes is not important, but will alter how [diagrams](#mermaid-diagrams) are displayed. +13. The graph is created by passing a list of nodes to [`Graph`][pydantic_graph.graph.Graph]. Order of nodes is not important, but it can affect how [diagrams](#mermaid-diagrams) are displayed. 14. Initialize the state. This will be passed to the graph run and mutated as the graph runs. -15. Run the graph with the initial state. Since the graph can be run from any node, we must pass the start node — in this case, `InsertCoin`. [`Graph.run`][pydantic_graph.graph.Graph.run] returns a tuple of the return value (`None`) in this case, and the [history][pydantic_graph.state.HistoryStep] of the graph run. +15. Run the graph with the initial state. Since the graph can be run from any node, we must pass the start node — in this case, `InsertCoin`. [`Graph.run`][pydantic_graph.graph.Graph.run] returns a [`GraphRunResult`][pydantic_graph.graph.GraphRunResult] that provides the final data and a history of the run. 16. The return type of the node's [`run`][pydantic_graph.nodes.BaseNode.run] method is important as it is used to determine the outgoing edges of the node. This information in turn is used to render [mermaid diagrams](#mermaid-diagrams) and is enforced at runtime to detect misbehavior as soon as possible. 17. The return type of `CoinsInserted`'s [`run`][pydantic_graph.nodes.BaseNode.run] method is a union, meaning multiple outgoing edges are possible. 18. Unlike other nodes, `Purchase` can end the run, so the [`RunEndT`][pydantic_graph.nodes.RunEndT] generic parameter must be set. In this case it's `None` since the graph run return type is `None`. @@ -643,11 +643,107 @@ stateDiagram-v2 Reprimand --> Ask ``` -You maybe have noticed that although this examples transfers control flow out of the graph run, we're still using [rich's `Prompt.ask`][rich.prompt.PromptBase.ask] to get user input, with the process hanging while we wait for the user to enter a response. For an example of genuine out-of-process control flow, see the [question graph example](examples/question-graph.md). +You maybe have noticed that although this example transfers control flow out of the graph run, we're still using [rich's `Prompt.ask`][rich.prompt.PromptBase.ask] to get user input, with the process hanging while we wait for the user to enter a response. For an example of genuine out-of-process control flow, see the [question graph example](examples/question-graph.md). + +## Iterating Over a Graph + +### Using `Graph.iter` for `async for` iteration + +Sometimes you want direct control or insight into each node as the graph executes. The easiest way to do that is with the [`Graph.iter`][pydantic_graph.graph.Graph.iter] method, which returns a **context manager** that yields a [`GraphRun`][pydantic_graph.graph.GraphRun] object. The `GraphRun` is an async-iterable over the nodes of your graph, allowing you to record or modify them as they execute. + +Here's an example: + +```python {title="count_down.py" noqa="I001" py="3.10"} +from __future__ import annotations as _annotations + +from dataclasses import dataclass +from pydantic_graph import Graph, BaseNode, End, GraphRunContext + + +@dataclass +class CountDownState: + counter: int + + +@dataclass +class CountDown(BaseNode[CountDownState]): + async def run(self, ctx: GraphRunContext[CountDownState]) -> CountDown | End[int]: + if ctx.state.counter <= 0: + return End(ctx.state.counter) + ctx.state.counter -= 1 + return CountDown() + + +count_down_graph = Graph(nodes=[CountDown]) + + +async def main(): + state = CountDownState(counter=3) + with count_down_graph.iter(CountDown(), state=state) as run: + async for node in run: + print('Node:', node) + #> Node: CountDown() + #> Node: CountDown() + #> Node: CountDown() + #> Node: End(data=0) + print('Final result:', run.final_result.result) + #> Final result: 0 + print('History snapshots:', [step.data_snapshot() for step in run.history]) + """ + History snapshots: + [CountDown(), CountDown(), CountDown(), CountDown(), End(data=0)] + """ +``` + +- `Graph.iter(...)` returns a `GraphRun`. +- You can `async for node in run` to step through each node as it is executed. +- Once the graph returns an [`End`][pydantic_graph.nodes.End], the loop ends, and `run.final_result` becomes a [`GraphRunResult`][pydantic_graph.graph.GraphRunResult] containing the final outcome (`0` here). + +### Using `GraphRun.next(node)` manually + +Alternatively, you can drive iteration manually with the [`GraphRun.next`][pydantic_graph.graph.GraphRun.next] method, which allows you to pass in whichever node you want to run next. You can modify or selectively skip nodes this way. + +Below is a contrived example that stops whenever the counter is at 2, ignoring any node runs beyond that: + +```python {title="count_down_next.py" noqa="I001" py="3.10"} +from pydantic_graph import End +from count_down import CountDown, CountDownState, count_down_graph + + +async def main(): + state = CountDownState(counter=5) + with count_down_graph.iter(CountDown(), state=state) as run: + node = await run.__anext__() # first node + while not isinstance(node, End): + print('Node:', node) + #> Node: CountDown() + #> Node: CountDown() + #> Node: CountDown() + if state.counter == 2: + # Let's forcibly end the run early + break + # run the node we got, which might produce a new node or End + node = await run.next(node) + + # Because the run ended early, we have no final result: + assert run.final_result is None + + # The run still has partial history though + for step in run.history: + print('History Step:', step.data_snapshot()) + #> History Step: CountDown() + #> History Step: CountDown() + #> History Step: CountDown() +``` + +- We grab the first node with `await run.__anext__()`. +- At each step, we call `await run.next(node)` to run it and get the next node (or an `End`). +- If the user decides to stop early, we break out of the loop. The graph run won't have a real final result in that case (`run.final_result` remains `None`). +- The run's history is still populated with the steps we executed so far. ## Dependency Injection -As with PydanticAI, `pydantic-graph` supports dependency injection via a generic parameter on [`Graph`][pydantic_graph.graph.Graph] and [`BaseNode`][pydantic_graph.nodes.BaseNode], and the [`GraphRunContext.deps`][pydantic_graph.nodes.GraphRunContext.deps] fields. +As with PydanticAI, `pydantic-graph` supports dependency injection via a generic parameter on [`Graph`][pydantic_graph.graph.Graph] and [`BaseNode`][pydantic_graph.nodes.BaseNode], and the [`GraphRunContext.deps`][pydantic_graph.nodes.GraphRunContext.deps] field. As an example of dependency injection, let's modify the `DivisibleBy5` example [above](#graph) to use a [`ProcessPoolExecutor`][concurrent.futures.ProcessPoolExecutor] to run the compute load in a separate process (this is a contrived example, `ProcessPoolExecutor` wouldn't actually improve performance in this example): @@ -667,12 +763,12 @@ class GraphDeps: @dataclass -class DivisibleBy5(BaseNode[None, None, int]): +class DivisibleBy5(BaseNode[None, GraphDeps, int]): foo: int async def run( self, - ctx: GraphRunContext, + ctx: GraphRunContext[None, GraphDeps], ) -> Increment | End[int]: if self.foo % 5 == 0: return End(self.foo) @@ -681,10 +777,10 @@ class DivisibleBy5(BaseNode[None, None, int]): @dataclass -class Increment(BaseNode): +class Increment(BaseNode[None, GraphDeps]): foo: int - async def run(self, ctx: GraphRunContext) -> DivisibleBy5: + async def run(self, ctx: GraphRunContext[None, GraphDeps]) -> DivisibleBy5: loop = asyncio.get_running_loop() compute_result = await loop.run_in_executor( ctx.deps.executor, @@ -702,11 +798,11 @@ fives_graph = Graph(nodes=[DivisibleBy5, Increment]) async def main(): with ProcessPoolExecutor() as executor: deps = GraphDeps(executor) - result, history = await fives_graph.run(DivisibleBy5(3), deps=deps) - print(result) + graph_run = await fives_graph.run(DivisibleBy5(3), deps=deps) + print(graph_run.result) #> 5 # the full history is quite verbose (see below), so we'll just print the summary - print([item.data_snapshot() for item in history]) + print([item.data_snapshot() for item in graph_run.history]) """ [ DivisibleBy5(foo=3), @@ -780,7 +876,7 @@ question_graph.mermaid_save('image.png', highlighted_nodes=[Answer]) _(This example is not complete and cannot be run directly)_ -Would generate and image that looks like this: +This would generate an image that looks like this: ```mermaid --- @@ -810,7 +906,7 @@ You can specify the direction of the state diagram using one of the following va - `'RL'`: Right to left, the diagram flows horizontally from right to left. - `'BT'`: Bottom to top, the diagram flows vertically from bottom to top. -Here is an example of how to do this using 'Left to Right' (LR) instead of the default 'Top to Bottom' (TB) +Here is an example of how to do this using 'Left to Right' (LR) instead of the default 'Top to Bottom' (TB): ```py {title="vending_machine_diagram.py" py="3.10"} from vending_machine import InsertCoin, vending_machine_graph diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 18df177aa..156c4370e 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -312,7 +312,7 @@ def iter( """A contextmanager which can be used to iterate over the agent graph's nodes as they are executed. This method builds an internal agent graph (using system prompts, tools and result schemas) and then returns an - `AgentRun` object. The `AgentRun` can be used to (async) iterate over the nodes of the graph as they are + `AgentRun` object. The `AgentRun` can be used to async-iterate over the nodes of the graph as they are executed. This is the API to use if you want to consume the outputs coming from each LLM model response, or the stream of events coming from the execution of tools. @@ -1161,12 +1161,58 @@ def _prepare_result_schema( @dataclasses.dataclass(repr=False) class AgentRun(Generic[AgentDepsT, ResultDataT]): - """A stateful, (async) iterable run of an agent. + """A stateful, async-iterable run of an [`Agent`][pydantic_ai.agent.Agent]. - You can use `async for` to iterate over the nodes without any modification to the run, or you can use the `next()` - method to iteratively drive the run with the ability to manipulate the node at any point before continuing on. + You generally obtain an `AgentRun` instance by calling `with my_agent.iter(...) as agent_run:`. - TODO: Add API documentation here. Docstring. + Once you have an instance, you can use it to iterate through the run's nodes as they execute. When an + [`End`][pydantic_graph.nodes.End] is reached, the run finishes and + [`final_result`][pydantic_ai.agent.AgentRun.final_result] becomes available. + + Example: + ```python + from pydantic_ai import Agent + + agent = Agent('openai:gpt-4o') + + async def main(): + nodes = [] + # Iterate through the run, recording each node along the way: + with agent.iter('What is the capital of France?') as agent_run: + async for node in agent_run: + nodes.append(node) + print(nodes) + ''' + [ + ModelRequestNode( + request=ModelRequest( + parts=[ + UserPromptPart( + content='What is the capital of France?', + timestamp=datetime.datetime(...), + part_kind='user-prompt', + ) + ], + kind='request', + ) + ), + HandleResponseNode( + model_response=ModelResponse( + parts=[TextPart(content='Paris', part_kind='text')], + model_name='function:model_logic', + timestamp=datetime.datetime(...), + kind='response', + ) + ), + End(data=FinalResult(data='Paris', tool_name=None)), + ] + ''' + print(agent_run.final_result.data) + #> Paris + ``` + + You can also manually drive the iteration using the [`next`][pydantic_ai.agent.AgentRun.next] method for + more granular control. """ _graph_run: GraphRun[ @@ -1182,7 +1228,11 @@ def ctx(self) -> GraphRunContext[_agent_graph.GraphAgentState, _agent_graph.Grap @property def final_result(self) -> AgentRunResult[AgentDepsT, ResultDataT] | None: - """The final result of the agent run.""" + """The final result of the run if it has ended, otherwise `None`. + + Once the run returns an [`End`][pydantic_graph.nodes.End] node, `final_result` is populated + with an [`AgentRunResult`][pydantic_ai.agent.AgentRunResult]. + """ graph_run_result = self._graph_run.final_result if graph_run_result is None: return None @@ -1198,6 +1248,7 @@ def __aiter__( ] | End[FinalResult[ResultDataT]] ]: + """Provide async-iteration over the nodes in the agent run.""" return self async def __anext__( @@ -1210,7 +1261,7 @@ async def __anext__( ] | End[FinalResult[ResultDataT]] ): - """Use the last returned node as the input to `Graph.next`.""" + """Advance to the next node automatically based on the last returned node.""" return await self._graph_run.__anext__() async def next( @@ -1228,14 +1279,59 @@ async def next( ] | End[FinalResult[ResultDataT]] ): - # TODO: It would be nice to expose a synchronous interface for this, to be able to - # synchronously iterate over the agent graph. I don't think this would be hard to do, - # but I'm having a hard time coming up with an API that fits nicely along side the current `run_sync`. - # The use of `await` provides an easy way to signal that you just want the result, but it's less - # clear to me what the analogous thing should be for synchronous code. + """Manually drive the agent run by passing in the node you want to run next. + + This lets you inspect or mutate the node before continuing execution, or skip certain nodes + under dynamic conditions. The agent run should be stopped when you return an [`End`][pydantic_graph.nodes.End] + node. + + Example: + ```python + from pydantic_ai import Agent + from pydantic_graph import End + + agent = Agent('openai:gpt-4o') + + async def main(): + nodes = [] + with agent.iter('What is the capital of France?') as agent_run: + # The first node can be retrieved via __anext__(), or you might already have it. + next_node = await agent_run.__anext__() + while not isinstance(next_node, End): + next_node = await agent_run.next(next_node) + nodes.append(next_node) + # Once `next_node` is an End, we've finished: + print(nodes) + ''' + [ + HandleResponseNode( + model_response=ModelResponse( + parts=[TextPart(content='Paris', part_kind='text')], + model_name='function:model_logic', + timestamp=datetime.datetime(...), + kind='response', + ) + ), + End(data=FinalResult(data='Paris', tool_name=None)), + ] + ''' + print('Final result:', agent_run.final_result.data) + #> Final result: Paris + ``` + + Args: + node: The node to run next in the graph. + + Returns: + The next node returned by the graph logic, or an [`End`][pydantic_graph.nodes.End] node if + the run has completed. + """ + # Note: It might be nice to expose a synchronous interface for iteration, but we shouldn't do it + # on this class or IDEs won't warn you if you accidentally use `for` instead of `async for` to iterate. return await self._graph_run.next(node) def usage(self) -> _usage.Usage: + """Get usage statistics for the run so far, including token usage, model requests, and so on.""" return self._graph_run.state.usage def __repr__(self): diff --git a/pydantic_graph/pydantic_graph/graph.py b/pydantic_graph/pydantic_graph/graph.py index 59bbe7f90..cd6c4f937 100644 --- a/pydantic_graph/pydantic_graph/graph.py +++ b/pydantic_graph/pydantic_graph/graph.py @@ -193,14 +193,17 @@ def iter( ) -> Iterator[GraphRun[StateT, DepsT, T]]: """A contextmanager which can be used to iterate over the graph's nodes as they are executed. - This method returns a `GraphRun` object which can be used to (async) iterate over the nodes of this `Graph` as + This method returns a `GraphRun` object which can be used to async-iterate over the nodes of this `Graph` as they are executed. This is the API to use if you want to record or interact with the nodes as the graph execution unfolds. + The `GraphRun` can also be used to manually drive the graph execution by calling + [`GraphRun.next`][pydantic_graph.graph.GraphRun.next]. + The `GraphRun` provides access to the full run history, state, deps, and the final result of the run once it has completed. - For more details, see the documentation of `GraphRun`. + For more details, see the API documentation of [`GraphRun`][pydantic_graph.graph.GraphRun]. Args: start_node: the first node to run. Since the graph definition doesn't define the entry point in the graph, @@ -212,33 +215,6 @@ def iter( Yields: A GraphRun that can be async iterated over to drive the graph to completion. - - Here's an example of iterating over the graph from [above][pydantic_graph.graph.Graph]: - - ```py {title="run_never_42.py" noqa="I001" py="3.10"} - from never_42 import Increment, MyState, never_42_graph - - async def main(): - state = MyState(1) - nodes = [] - with never_42_graph.iter(Increment(), state=state) as graph_run: - async for node in graph_run: - nodes.append(node) - print(nodes) - #> [Check42(), End(data=2)] - print(state) - #> MyState(number=2) - - state = MyState(41) - nodes = [] - with never_42_graph.iter(Increment(), state=state) as graph_run: - async for node in graph_run: - nodes.append(node) - print(nodes) - #> [Check42(), Increment(), Check42(), End(data=43)] - print(state) - #> MyState(number=43) - ``` """ if infer_name and self.name is None: self._infer_name(inspect.currentframe()) @@ -408,7 +384,7 @@ def mermaid_code( Here's an example of generating a diagram for the graph from [above][pydantic_graph.graph.Graph]: - ```py {title="never_42.py" py="3.10"} + ```py {title="mermaid_never_42.py" py="3.10"} from never_42 import Increment, never_42_graph print(never_42_graph.mermaid_code(start_node=Increment)) @@ -585,12 +561,44 @@ def _infer_name(self, function_frame: types.FrameType | None) -> None: class GraphRun(Generic[StateT, DepsT, RunEndT]): - """A stateful, (async) iterable run of a graph. - - You can use `async for` to iterate over the nodes without any modification to the run, or you can use the `next()` - method to iteratively drive the run with the ability to manipulate the node at any point before continuing on. + """A stateful, async-iterable run of a [`Graph`][pydantic_ai.graph.Graph]. + + You typically get a `GraphRun` instance from calling + `with [my_graph.iter(...)][pydantic_graph.graph.Graph.iter] as graph_run:`. That gives you the ability to iterate + through nodes as they run, either by `async for` iteration or by repeatedly calling `.next(...)`. + + Here's an example of iterating over the graph from [above][pydantic_graph.graph.Graph]: + ```py {title="iter_never_42.py" noqa="I001" py="3.10"} + from copy import deepcopy + from never_42 import Increment, MyState, never_42_graph + + async def main(): + state = MyState(1) + node_states = [] + with never_42_graph.iter(Increment(), state=state) as graph_run: + async for node in graph_run: + node_states.append((node, deepcopy(graph_run.state))) + print(node_states) + #> [(Check42(), MyState(number=2)), (End(data=2), MyState(number=2))] + + state = MyState(41) + node_states = [] + with never_42_graph.iter(Increment(), state=state) as graph_run: + async for node in graph_run: + node_states.append((node, deepcopy(graph_run.state))) + print(node_states) + ''' + [ + (Check42(), MyState(number=42)), + (Increment(), MyState(number=42)), + (Check42(), MyState(number=43)), + (End(data=43), MyState(number=43)), + ] + ''' + ``` - TODO: Add API documentation here. Docstring. + See the [`GraphRun.next` documentation][pydantic_graph.graph.GraphRun.next] for an example of how to manually + drive the graph run. """ def __init__( @@ -604,6 +612,22 @@ def __init__( auto_instrument: bool, span: LogfireSpan | None = None, ): + """Create a new run for a given graph, starting at the specified node. + + Typically, you'll use [`Graph.iter`][pydantic_graph.graph.Graph.iter] rather than calling this directly. + + Args: + graph: The [`Graph`][pydantic_graph.graph.Graph] to run. + start_node: The node where execution will begin. + history: A list of [`HistoryStep`][pydantic_graph.state.HistoryStep] objects that describe + each step of the run. Usually starts empty; can be populated if resuming. + state: A shared state object or primitive (like a counter, dataclass, etc.) that is available + to all nodes via `ctx.state`. + deps: Optional dependencies that each node can access via `ctx.deps`, e.g. database connections, + configuration, or logging clients. + auto_instrument: Whether to automatically create instrumentation spans during the run. + span: An optional existing Logfire span to nest node-level spans under (advanced usage). + """ self.graph = graph self.history = history self.state = state @@ -615,7 +639,7 @@ def __init__( @property def final_result(self) -> GraphRunResult[StateT, DepsT, RunEndT] | None: - """The final result of the agent run if the run is completed, otherwise `None`.""" + """The final result of the graph run if the run is completed, otherwise `None`.""" if not isinstance(self._next_node, End): return None # The GraphRun has not finished running return GraphRunResult( @@ -625,9 +649,40 @@ def final_result(self) -> GraphRunResult[StateT, DepsT, RunEndT] | None: async def next( self: GraphRun[StateT, DepsT, T], node: BaseNode[StateT, DepsT, T] ) -> BaseNode[StateT, DepsT, T] | End[T]: - """Note: this method behaves very similarly to an async generator's `asend` method.""" - # TODO: replace the End[T] return with a RunResult[T] type which includes extra data. + """Manually drive the graph run by passing in the node you want to run next. + + This lets you inspect or mutate the node before continuing execution, or skip certain nodes + under dynamic conditions. The graph run should stop when you return an [`End`][pydantic_graph.nodes.End] node. + + Here's an example of using `next` to drive the graph from [above][pydantic_graph.graph.Graph]: + ```py {title="next_never_42.py" noqa="I001" py="3.10"} + from copy import deepcopy + from pydantic_graph import End + from never_42 import Increment, MyState, never_42_graph + + async def main(): + node_states = [] + state = MyState(48) + with never_42_graph.iter(Increment(), state=state) as graph_run: + next_node = await graph_run.__anext__() + while not isinstance(next_node, End): + if graph_run.state.number == 50: + graph_run.state.number = 42 + node_states.append((next_node, deepcopy(graph_run.state))) + next_node = await graph_run.next(next_node) + node_states.append((next_node, deepcopy(graph_run.state))) + + print(node_states) + #> [(Check42(), MyState(number=49)), (End(data=49), MyState(number=49))] + ``` + Args: + node: The node to run next in the graph. + + Returns: + The next node returned by the graph logic, or an [`End`][pydantic_graph.nodes.End] node if + the run has completed. + """ history = self.history state = self.state deps = self.deps From a6e6445fe0c6818e7d295b0ab65bc323bb81ac4f Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Tue, 18 Feb 2025 11:14:57 -0700 Subject: [PATCH 19/28] Fix various docs references --- docs/agents.md | 8 +++---- docs/api/agent.md | 2 ++ docs/api/result.md | 4 +++- docs/message-history.md | 8 +++---- docs/multi-agent-applications.md | 4 ++-- docs/results.md | 2 +- pydantic_ai_slim/pydantic_ai/agent.py | 31 +++++++++++++++++++++++--- pydantic_ai_slim/pydantic_ai/result.py | 4 ++-- pydantic_graph/pydantic_graph/graph.py | 2 +- 9 files changed, 47 insertions(+), 18 deletions(-) diff --git a/docs/agents.md b/docs/agents.md index 8294a4dd6..5c7b1cbf3 100644 --- a/docs/agents.md +++ b/docs/agents.md @@ -64,8 +64,8 @@ print(result.data) There are four ways to run an agent: -1. [`agent.run()`][pydantic_ai.Agent.run] — a coroutine which returns a [`RunResult`][pydantic_ai.result.RunResult] containing a completed response. -2. [`agent.run_sync()`][pydantic_ai.Agent.run_sync] — a plain, synchronous function which returns a [`RunResult`][pydantic_ai.result.RunResult] containing a completed response (internally, this just calls `loop.run_until_complete(self.run())`). +1. [`agent.run()`][pydantic_ai.Agent.run] — a coroutine which returns a [`RunResult`][pydantic_ai.agent.AgentRunResult] containing a completed response. +2. [`agent.run_sync()`][pydantic_ai.Agent.run_sync] — a plain, synchronous function which returns a [`RunResult`][pydantic_ai.agent.AgentRunResult] containing a completed response (internally, this just calls `loop.run_until_complete(self.run())`). 3. [`agent.run_stream()`][pydantic_ai.Agent.run_stream] — a coroutine which returns a [`StreamedRunResult`][pydantic_ai.result.StreamedRunResult], which contains methods to stream a response as an async iterable. 4. [`agent.iter()`][pydantic_ai.Agent.iter] — a context manager which returns an [`AgentRun`][pydantic_ai.agent.AgentRun], an async-iterable over the nodes of the agent's graph. @@ -208,9 +208,9 @@ async def main(): #### Accessing usage and the final result -You can retrieve usage statistics (tokens, requests, etc.) at any time from the [`AgentRun`](pydantic_ai.agent.AgentRun) object via `agent_run.usage()`. This method returns a [`Usage`](pydantic_ai.usage.Usage) object containing the usage data. +You can retrieve usage statistics (tokens, requests, etc.) at any time from the [`AgentRun`][pydantic_ai.agent.AgentRun] object via `agent_run.usage()`. This method returns a [`Usage`][pydantic_ai.usage.Usage] object containing the usage data. -Once the run finishes, `agent_run.final_result` becomes a [`AgentRunResult`](pydantic_ai.agent.AgentRunResult) object containing the final output (and related metadata). +Once the run finishes, `agent_run.final_result` becomes a [`AgentRunResult`][pydantic_ai.agent.AgentRunResult] object containing the final output (and related metadata). --- diff --git a/docs/api/agent.md b/docs/api/agent.md index 890c418ee..134d4dd24 100644 --- a/docs/api/agent.md +++ b/docs/api/agent.md @@ -4,6 +4,8 @@ options: members: - Agent + - AgentRun + - AgentRunResult - EndStrategy - RunResultData - capture_run_messages diff --git a/docs/api/result.md b/docs/api/result.md index c22a52e24..d07778e95 100644 --- a/docs/api/result.md +++ b/docs/api/result.md @@ -2,4 +2,6 @@ ::: pydantic_ai.result options: - inherited_members: true + inherited_members: true + members: + - StreamedRunResult diff --git a/docs/message-history.md b/docs/message-history.md index d538112f8..1fad6f54c 100644 --- a/docs/message-history.md +++ b/docs/message-history.md @@ -6,12 +6,12 @@ PydanticAI provides access to messages exchanged during an agent run. These mess After running an agent, you can access the messages exchanged during that run from the `result` object. -Both [`RunResult`][pydantic_ai.result.RunResult] +Both [`RunResult`][pydantic_ai.agent.AgentRunResult] (returned by [`Agent.run`][pydantic_ai.Agent.run], [`Agent.run_sync`][pydantic_ai.Agent.run_sync]) and [`StreamedRunResult`][pydantic_ai.result.StreamedRunResult] (returned by [`Agent.run_stream`][pydantic_ai.Agent.run_stream]) have the following methods: -* [`all_messages()`][pydantic_ai.result.RunResult.all_messages]: returns all messages, including messages from prior runs. There's also a variant that returns JSON bytes, [`all_messages_json()`][pydantic_ai.result.RunResult.all_messages_json]. -* [`new_messages()`][pydantic_ai.result.RunResult.new_messages]: returns only the messages from the current run. There's also a variant that returns JSON bytes, [`new_messages_json()`][pydantic_ai.result.RunResult.new_messages_json]. +* [`all_messages()`][pydantic_ai.agent.AgentRunResult.all_messages]: returns all messages, including messages from prior runs. There's also a variant that returns JSON bytes, [`all_messages_json()`][pydantic_ai.agent.AgentRunResult.all_messages_json]. +* [`new_messages()`][pydantic_ai.agent.AgentRunResult.new_messages]: returns only the messages from the current run. There's also a variant that returns JSON bytes, [`new_messages_json()`][pydantic_ai.agent.AgentRunResult.new_messages_json]. !!! info "StreamedRunResult and complete messages" On [`StreamedRunResult`][pydantic_ai.result.StreamedRunResult], the messages returned from these methods will only include the final result message once the stream has finished. @@ -25,7 +25,7 @@ and [`StreamedRunResult`][pydantic_ai.result.StreamedRunResult] (returned by [`A **Note:** The final result message will NOT be added to result messages if you use [`.stream_text(delta=True)`][pydantic_ai.result.StreamedRunResult.stream_text] since in this case the result content is never built as one string. -Example of accessing methods on a [`RunResult`][pydantic_ai.result.RunResult] : +Example of accessing methods on a [`RunResult`][pydantic_ai.agent.AgentRunResult] : ```python {title="run_result_messages.py" hl_lines="10 28"} from pydantic_ai import Agent diff --git a/docs/multi-agent-applications.md b/docs/multi-agent-applications.md index 94913d1c4..002dd3c55 100644 --- a/docs/multi-agent-applications.md +++ b/docs/multi-agent-applications.md @@ -18,7 +18,7 @@ Since agents are stateless and designed to be global, you do not need to include You'll generally want to pass [`ctx.usage`][pydantic_ai.RunContext.usage] to the [`usage`][pydantic_ai.Agent.run] keyword argument of the delegate agent run so usage within that run counts towards the total usage of the parent agent run. !!! note "Multiple models" - Agent delegation doesn't need to use the same model for each agent. If you choose to use different models within a run, calculating the monetary cost from the final [`result.usage()`][pydantic_ai.result.RunResult.usage] of the run will not be possible, but you can still use [`UsageLimits`][pydantic_ai.usage.UsageLimits] to avoid unexpected costs. + Agent delegation doesn't need to use the same model for each agent. If you choose to use different models within a run, calculating the monetary cost from the final [`result.usage()`][pydantic_ai.agent.AgentRunResult.usage] of the run will not be possible, but you can still use [`UsageLimits`][pydantic_ai.usage.UsageLimits] to avoid unexpected costs. ```python {title="agent_delegation_simple.py"} from pydantic_ai import Agent, RunContext @@ -62,7 +62,7 @@ Usage( 1. The "parent" or controlling agent. 2. The "delegate" agent, which is called from within a tool of the parent agent. 3. Call the delegate agent from within a tool of the parent agent. -4. Pass the usage from the parent agent to the delegate agent so the final [`result.usage()`][pydantic_ai.result.RunResult.usage] includes the usage from both agents. +4. Pass the usage from the parent agent to the delegate agent so the final [`result.usage()`][pydantic_ai.agent.AgentRunResult.usage] includes the usage from both agents. 5. Since the function returns `#!python list[str]`, and the `result_type` of `joke_generation_agent` is also `#!python list[str]`, we can simply return `#!python r.data` from the tool. _(This example is complete, it can be run "as is")_ diff --git a/docs/results.md b/docs/results.md index e4e8a8c63..678048014 100644 --- a/docs/results.md +++ b/docs/results.md @@ -1,5 +1,5 @@ Results are the final values returned from [running an agent](agents.md#running-agents). -The result values are wrapped in [`RunResult`][pydantic_ai.result.RunResult] and [`StreamedRunResult`][pydantic_ai.result.StreamedRunResult] so you can access other data like [usage][pydantic_ai.usage.Usage] of the run and [message history](message-history.md#accessing-messages-from-results) +The result values are wrapped in [`AgentRunResult`][pydantic_ai.agent.AgentRunResult] and [`StreamedRunResult`][pydantic_ai.result.StreamedRunResult] so you can access other data like [usage][pydantic_ai.usage.Usage] of the run and [message history](message-history.md#accessing-messages-from-results) Both `RunResult` and `StreamedRunResult` are generic in the data they wrap, so typing information about the data returned by the agent is preserved. diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 156c4370e..2e8ad53e5 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -41,7 +41,7 @@ ToolPrepareFunc, ) -__all__ = 'Agent', 'capture_run_messages', 'EndStrategy' +__all__ = 'Agent', 'AgentRun', 'AgentRunResult', 'capture_run_messages', 'EndStrategy' _logfire = logfire_api.Logfire(otel_scope='pydantic-ai') @@ -1377,13 +1377,24 @@ def _set_result_tool_return(self, return_content: str) -> list[_messages.ModelMe raise LookupError(f'No tool call found with tool name {self._result.tool_name!r}.') def all_messages(self, *, result_tool_return_content: str | None = None) -> list[_messages.ModelMessage]: + """Return the history of _messages. + + Args: + result_tool_return_content: The return content of the tool call to set in the last message. + This provides a convenient way to modify the content of the result tool call if you want to continue + the conversation and want to set the response to the result tool call. If `None`, the last message will + not be modified. + + Returns: + List of messages. + """ if result_tool_return_content is not None: return self._set_result_tool_return(result_tool_return_content) else: return self.graph_run_result.state.message_history def all_messages_json(self, *, result_tool_return_content: str | None = None) -> bytes: - """Return all messages from [`all_messages`][pydantic_ai.result._BaseRunResult.all_messages] as JSON bytes. + """Return all messages from [`all_messages`][pydantic_ai.agent.AgentRunResult.all_messages] as JSON bytes. Args: result_tool_return_content: The return content of the tool call to set in the last message. @@ -1403,10 +1414,23 @@ def _new_message_index(self) -> int: return self.graph_run_result.deps.new_message_index def new_messages(self, *, result_tool_return_content: str | None = None) -> list[_messages.ModelMessage]: + """Return new messages associated with this run. + + Messages from older runs are excluded. + + Args: + result_tool_return_content: The return content of the tool call to set in the last message. + This provides a convenient way to modify the content of the result tool call if you want to continue + the conversation and want to set the response to the result tool call. If `None`, the last message will + not be modified. + + Returns: + List of new messages. + """ return self.all_messages(result_tool_return_content=result_tool_return_content)[self._new_message_index :] def new_messages_json(self, *, result_tool_return_content: str | None = None) -> bytes: - """Return new messages from [`new_messages`][pydantic_ai.result._BaseRunResult.new_messages] as JSON bytes. + """Return new messages from [`new_messages`][pydantic_ai.agent.AgentRunResult.new_messages] as JSON bytes. Args: result_tool_return_content: The return content of the tool call to set in the last message. @@ -1422,4 +1446,5 @@ def new_messages_json(self, *, result_tool_return_content: str | None = None) -> ) def usage(self) -> _usage.Usage: + """Return the usage of the whole run.""" return self.graph_run_result.state.usage diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py index 42135e4cc..41cc20b80 100644 --- a/pydantic_ai_slim/pydantic_ai/result.py +++ b/pydantic_ai_slim/pydantic_ai/result.py @@ -98,7 +98,7 @@ def all_messages(self, *, result_tool_return_content: str | None = None) -> list return self._all_messages def all_messages_json(self, *, result_tool_return_content: str | None = None) -> bytes: - """Return all messages from [`all_messages`][pydantic_ai.result._BaseRunResult.all_messages] as JSON bytes. + """Return all messages from [`all_messages`][pydantic_ai.result.StreamedRunResult.all_messages] as JSON bytes. Args: result_tool_return_content: The return content of the tool call to set in the last message. @@ -130,7 +130,7 @@ def new_messages(self, *, result_tool_return_content: str | None = None) -> list return self.all_messages(result_tool_return_content=result_tool_return_content)[self._new_message_index :] def new_messages_json(self, *, result_tool_return_content: str | None = None) -> bytes: - """Return new messages from [`new_messages`][pydantic_ai.result._BaseRunResult.new_messages] as JSON bytes. + """Return new messages from [`new_messages`][pydantic_ai.result.StreamedRunResult.new_messages] as JSON bytes. Args: result_tool_return_content: The return content of the tool call to set in the last message. diff --git a/pydantic_graph/pydantic_graph/graph.py b/pydantic_graph/pydantic_graph/graph.py index cd6c4f937..8e172e127 100644 --- a/pydantic_graph/pydantic_graph/graph.py +++ b/pydantic_graph/pydantic_graph/graph.py @@ -561,7 +561,7 @@ def _infer_name(self, function_frame: types.FrameType | None) -> None: class GraphRun(Generic[StateT, DepsT, RunEndT]): - """A stateful, async-iterable run of a [`Graph`][pydantic_ai.graph.Graph]. + """A stateful, async-iterable run of a [`Graph`][pydantic_graph.graph.Graph]. You typically get a `GraphRun` instance from calling `with [my_graph.iter(...)][pydantic_graph.graph.Graph.iter] as graph_run:`. That gives you the ability to iterate From 007d8ca9cc5949971d664b3b5eb2432a0953a253 Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Tue, 18 Feb 2025 11:17:16 -0700 Subject: [PATCH 20/28] Fix final docs references --- docs/api/agent.md | 2 +- docs/api/result.md | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/api/agent.md b/docs/api/agent.md index 134d4dd24..b26cfb58e 100644 --- a/docs/api/agent.md +++ b/docs/api/agent.md @@ -7,5 +7,5 @@ - AgentRun - AgentRunResult - EndStrategy - - RunResultData + - RunResultDataT - capture_run_messages diff --git a/docs/api/result.md b/docs/api/result.md index d07778e95..8e6cef79e 100644 --- a/docs/api/result.md +++ b/docs/api/result.md @@ -4,4 +4,5 @@ options: inherited_members: true members: + - ResultDataT - StreamedRunResult From 6d532c1f5fec55e890f90c4a9ca2e80d71890d58 Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Tue, 18 Feb 2025 13:07:06 -0700 Subject: [PATCH 21/28] Address some feedback --- docs/agents.md | 24 +++++--- docs/graph.md | 5 +- pydantic_ai_slim/pydantic_ai/agent.py | 40 +++++++++++-- pydantic_graph/pydantic_graph/graph.py | 77 +++++++++++++++++++------- 4 files changed, 109 insertions(+), 37 deletions(-) diff --git a/docs/agents.md b/docs/agents.md index 5c7b1cbf3..53f9861c3 100644 --- a/docs/agents.md +++ b/docs/agents.md @@ -67,7 +67,7 @@ There are four ways to run an agent: 1. [`agent.run()`][pydantic_ai.Agent.run] — a coroutine which returns a [`RunResult`][pydantic_ai.agent.AgentRunResult] containing a completed response. 2. [`agent.run_sync()`][pydantic_ai.Agent.run_sync] — a plain, synchronous function which returns a [`RunResult`][pydantic_ai.agent.AgentRunResult] containing a completed response (internally, this just calls `loop.run_until_complete(self.run())`). 3. [`agent.run_stream()`][pydantic_ai.Agent.run_stream] — a coroutine which returns a [`StreamedRunResult`][pydantic_ai.result.StreamedRunResult], which contains methods to stream a response as an async iterable. -4. [`agent.iter()`][pydantic_ai.Agent.iter] — a context manager which returns an [`AgentRun`][pydantic_ai.agent.AgentRun], an async-iterable over the nodes of the agent's graph. +4. [`agent.iter()`][pydantic_ai.Agent.iter] — a context manager which returns an [`AgentRun`][pydantic_ai.agent.AgentRun], an async-iterable over the nodes of the agent's underlying [`Graph`][pydantic_graph.graph.Graph]. Here's a simple example demonstrating the first three: @@ -94,17 +94,18 @@ _(This example is complete, it can be run "as is" — you'll need to add `asynci You can also pass messages from previous runs to continue a conversation or provide context, as described in [Messages and Chat History](message-history.md). ---- ### Iterating Over an Agent's Graph -In more advanced scenarios, you may want to inspect or manipulate the agent's workflow as it runs. For example, you may want to collect data at each step of the run or manually decide how to proceed based on the node returned. In these situations, you can use the [`Agent.iter`][pydantic_ai.Agent.iter] method, a context manager which returns an [`AgentRun`][pydantic_ai.agent.AgentRun]. +Under the hood, each `Agent` in PydanticAI uses [pydantic-graph][pydantic_graph] to manage its execution flow. **pydantic-graph** is a generic, type-centric library for building and running finite state machines in Python. It doesn't actually depend on PydanticAI — you can use it standalone for workflows that have nothing to do with GenAI — but PydanticAI makes use of it to orchestrate the handling of model requests and model responses in an agent's run. + +In many scenarios, you don't need to worry about pydantic-graph at all; calling `agent.run(...)` simply traverses the underlying graph from start to finish. However, if you need deeper insight or control — for example to capture each tool invocation, or to inject your own logic at specific stages — PydanticAI exposes the lower-level iteration process via [`Agent.iter`][pydantic_ai.Agent.iter]. This method returns an [`AgentRun`][pydantic_ai.agent.AgentRun], which you can async-iterate over, or manually drive node-by-node via the [`next`][pydantic_ai.agent.AgentRun.next] method. Once the agent's graph returns an [`End`][pydantic_graph.nodes.End], you have the final result along with a detailed history of all steps. #### `async for` iteration Here's an example of using `async for` with `iter` to record each node the agent executes: -```python +```python {title="agent_iter_async_for.py"} from pydantic_ai import Agent agent = Agent('openai:gpt-4o') @@ -152,9 +153,9 @@ async def main(): #### Using `.next(...)` manually -You can also drive the iteration manually by passing the node you want to run next to the `AgentRun.next(...)` method. This allows you to inspect or modify the node before it executes or skip nodes based on your own logic: +You can also drive the iteration manually by passing the node you want to run next to the `AgentRun.next(...)` method. This allows you to inspect or modify the node before it executes or skip nodes based on your own logic, and to catch errors in `next()` more easily: -```python +```python {title="agent_iter_next.py"} from pydantic_ai import Agent from pydantic_graph import End @@ -163,8 +164,7 @@ agent = Agent('openai:gpt-4o') async def main(): with agent.iter('What is the capital of France?') as agent_run: - # You can get the first node by calling __anext__ once - node = await agent_run.__anext__() + node = agent_run.next_node # start with the first node # Keep track of nodes here all_nodes = [node] @@ -178,6 +178,12 @@ async def main(): print(all_nodes) """ [ + UserPromptNode( + user_prompt='What is the capital of France?', + system_prompts=(), + system_prompt_functions=[], + system_prompt_dynamic_functions={}, + ), ModelRequestNode( request=ModelRequest( parts=[ @@ -561,7 +567,7 @@ If models behave unexpectedly (e.g., the retry limit is exceeded, or their API r In these cases, [`capture_run_messages`][pydantic_ai.capture_run_messages] can be used to access the messages exchanged during the run to help diagnose the issue. -```python +```python {title="agent_model_errors.py"} from pydantic_ai import Agent, ModelRetry, UnexpectedModelBehavior, capture_run_messages agent = Agent('openai:gpt-4o') diff --git a/docs/graph.md b/docs/graph.md index 3ef996b85..85312d30e 100644 --- a/docs/graph.md +++ b/docs/graph.md @@ -713,12 +713,13 @@ from count_down import CountDown, CountDownState, count_down_graph async def main(): state = CountDownState(counter=5) with count_down_graph.iter(CountDown(), state=state) as run: - node = await run.__anext__() # first node + node = run.next_node # start with the first node while not isinstance(node, End): print('Node:', node) #> Node: CountDown() #> Node: CountDown() #> Node: CountDown() + #> Node: CountDown() if state.counter == 2: # Let's forcibly end the run early break @@ -736,7 +737,7 @@ async def main(): #> History Step: CountDown() ``` -- We grab the first node with `await run.__anext__()`. +- We grab the first node via `run.next_node`. - At each step, we call `await run.next(node)` to run it and get the next node (or an `End`). - If the user decides to stop early, we break out of the loop. The graph run won't have a real final result in that case (`run.final_result` remains `None`). - The run's history is still populated with the steps we executed so far. diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 2e8ad53e5..a3f38e851 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -629,8 +629,8 @@ async def main(): usage=usage, infer_name=False, ) as agent_run: - first_node = await agent_run.__anext__() - assert isinstance(first_node, _agent_graph.ModelRequestNode) # the first node should be a request node + first_node = agent_run.next_node # start with the first node + assert isinstance(first_node, _agent_graph.UserPromptNode) # the first node should be a user prompt node node: BaseNode[Any, Any, Any] = cast(BaseNode[Any, Any, Any], first_node) while True: if isinstance(node, _agent_graph.ModelRequestNode): @@ -1226,6 +1226,19 @@ def ctx(self) -> GraphRunContext[_agent_graph.GraphAgentState, _agent_graph.Grap self._graph_run.state, self._graph_run.deps ) + @property + def next_node( + self, + ) -> ( + BaseNode[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], FinalResult[ResultDataT]] + | End[FinalResult[ResultDataT]] + ): + """The next node that will be run in the agent graph. + + This is the next node that will be used during async iteration, or if a node is not passed to `self.next(...)`. + """ + return self._graph_run.next_node + @property def final_result(self) -> AgentRunResult[AgentDepsT, ResultDataT] | None: """The final result of the run if it has ended, otherwise `None`. @@ -1293,10 +1306,9 @@ async def next( agent = Agent('openai:gpt-4o') async def main(): - nodes = [] with agent.iter('What is the capital of France?') as agent_run: - # The first node can be retrieved via __anext__(), or you might already have it. - next_node = await agent_run.__anext__() + next_node = agent_run.next_node # start with the first node + nodes = [next_node] while not isinstance(next_node, End): next_node = await agent_run.next(next_node) nodes.append(next_node) @@ -1304,6 +1316,24 @@ async def main(): print(nodes) ''' [ + UserPromptNode( + user_prompt='What is the capital of France?', + system_prompts=(), + system_prompt_functions=[], + system_prompt_dynamic_functions={}, + ), + ModelRequestNode( + request=ModelRequest( + parts=[ + UserPromptPart( + content='What is the capital of France?', + timestamp=datetime.datetime(...), + part_kind='user-prompt', + ) + ], + kind='request', + ) + ), HandleResponseNode( model_response=ModelResponse( parts=[TextPart(content='Paris', part_kind='text')], diff --git a/pydantic_graph/pydantic_graph/graph.py b/pydantic_graph/pydantic_graph/graph.py index 8e172e127..e08d59767 100644 --- a/pydantic_graph/pydantic_graph/graph.py +++ b/pydantic_graph/pydantic_graph/graph.py @@ -287,6 +287,12 @@ async def next( """ if infer_name and self.name is None: self._infer_name(inspect.currentframe()) + + if isinstance(node, End): + # While technically this is not compatible with the documented method signature, it's an easy mistake to + # make, and we should eagerly provide a more helpful error message than you'd get otherwise. + raise exceptions.GraphRuntimeError(f'Cannot call `next` with an `End` node: {node!r}.') + node_id = node.get_id() if node_id not in self.node_defs: raise exceptions.GraphRuntimeError(f'Node `{node}` is not in the graph.') @@ -574,27 +580,34 @@ class GraphRun(Generic[StateT, DepsT, RunEndT]): async def main(): state = MyState(1) - node_states = [] with never_42_graph.iter(Increment(), state=state) as graph_run: + node_states = [(graph_run.next_node, deepcopy(graph_run.state))] async for node in graph_run: node_states.append((node, deepcopy(graph_run.state))) - print(node_states) - #> [(Check42(), MyState(number=2)), (End(data=2), MyState(number=2))] + print(node_states) + ''' + [ + (Increment(), MyState(number=1)), + (Check42(), MyState(number=2)), + (End(data=2), MyState(number=2)), + ] + ''' state = MyState(41) - node_states = [] with never_42_graph.iter(Increment(), state=state) as graph_run: + node_states = [(graph_run.next_node, deepcopy(graph_run.state))] async for node in graph_run: node_states.append((node, deepcopy(graph_run.state))) - print(node_states) - ''' - [ - (Check42(), MyState(number=42)), - (Increment(), MyState(number=42)), - (Check42(), MyState(number=43)), - (End(data=43), MyState(number=43)), - ] - ''' + print(node_states) + ''' + [ + (Increment(), MyState(number=41)), + (Check42(), MyState(number=42)), + (Increment(), MyState(number=42)), + (Check42(), MyState(number=43)), + (End(data=43), MyState(number=43)), + ] + ''' ``` See the [`GraphRun.next` documentation][pydantic_graph.graph.GraphRun.next] for an example of how to manually @@ -637,6 +650,14 @@ def __init__( self._next_node: BaseNode[StateT, DepsT, RunEndT] | End[RunEndT] = start_node + @property + def next_node(self) -> BaseNode[StateT, DepsT, RunEndT] | End[RunEndT]: + """The next node that will be run in the graph. + + This is the next node that will be used during async iteration, or if a node is not passed to `self.next(...)`. + """ + return self._next_node + @property def final_result(self) -> GraphRunResult[StateT, DepsT, RunEndT] | None: """The final result of the graph run if the run is completed, otherwise `None`.""" @@ -647,7 +668,7 @@ def final_result(self) -> GraphRunResult[StateT, DepsT, RunEndT] | None: ) async def next( - self: GraphRun[StateT, DepsT, T], node: BaseNode[StateT, DepsT, T] + self: GraphRun[StateT, DepsT, T], node: BaseNode[StateT, DepsT, T] | None = None ) -> BaseNode[StateT, DepsT, T] | End[T]: """Manually drive the graph run by passing in the node you want to run next. @@ -661,28 +682,42 @@ async def next( from never_42 import Increment, MyState, never_42_graph async def main(): - node_states = [] state = MyState(48) with never_42_graph.iter(Increment(), state=state) as graph_run: - next_node = await graph_run.__anext__() + next_node = graph_run.next_node # start with the first node + node_states = [(next_node, deepcopy(graph_run.state))] + while not isinstance(next_node, End): if graph_run.state.number == 50: graph_run.state.number = 42 - node_states.append((next_node, deepcopy(graph_run.state))) next_node = await graph_run.next(next_node) - node_states.append((next_node, deepcopy(graph_run.state))) + node_states.append((next_node, deepcopy(graph_run.state))) - print(node_states) - #> [(Check42(), MyState(number=49)), (End(data=49), MyState(number=49))] + print(node_states) + ''' + [ + (Increment(), MyState(number=48)), + (Check42(), MyState(number=49)), + (End(data=49), MyState(number=49)), + ] + ''' ``` Args: - node: The node to run next in the graph. + node: The node to run next in the graph. If not specified, uses `self.next_node`, which is initialized to + the `start_node` of the run and updated each time a new node is returned. Returns: The next node returned by the graph logic, or an [`End`][pydantic_graph.nodes.End] node if the run has completed. """ + if node is None: + if isinstance(self._next_node, End): + # Note: we could alternatively just return `self._next_node` here, but it's easier to start with an + # error and relax the behavior later, than vice versa. + raise exceptions.GraphRuntimeError('This graph run has already ended.') + node = self._next_node + history = self.history state = self.state deps = self.deps From 0745ba99bae84c0f35601b5c05cca46c769bf8f2 Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Tue, 18 Feb 2025 13:22:17 -0700 Subject: [PATCH 22/28] Update docs --- docs/agents.md | 18 +++++----- docs/graph.md | 51 +++++++++++++-------------- pydantic_ai_slim/pydantic_ai/agent.py | 2 +- 3 files changed, 35 insertions(+), 36 deletions(-) diff --git a/docs/agents.md b/docs/agents.md index 53f9861c3..3707806f2 100644 --- a/docs/agents.md +++ b/docs/agents.md @@ -164,16 +164,14 @@ agent = Agent('openai:gpt-4o') async def main(): with agent.iter('What is the capital of France?') as agent_run: - node = agent_run.next_node # start with the first node + node = agent_run.next_node # (1)! - # Keep track of nodes here all_nodes = [node] - # Drive the iteration manually - while not isinstance(node, End): - # You could inspect or mutate the node here as needed - node = await agent_run.next(node) - all_nodes.append(node) + # Drive the iteration manually: + while not isinstance(node, End): # (2)! + node = await agent_run.next(node) # (3)! + all_nodes.append(node) # (4)! print(all_nodes) """ @@ -209,8 +207,10 @@ async def main(): """ ``` -- When you call `await agent_run.next(node)`, it executes that node in the agent's graph, updates the run's history, and returns the *next* node to run. -- The agent run is finished once an `End` node has been produced; instances of `End` cannot be passed to `next`. +1. We start by grabbing the first node that will be run in the agent's graph. +2. The agent run is finished once an `End` node has been produced; instances of `End` cannot be passed to `next`. +3. When you call `await agent_run.next(node)`, it executes that node in the agent's graph, updates the run's history, and returns the *next* node to run. +4. You could also inspect or mutate the new `node` here as needed. #### Accessing usage and the final result diff --git a/docs/graph.md b/docs/graph.md index 85312d30e..513e3f480 100644 --- a/docs/graph.md +++ b/docs/graph.md @@ -679,14 +679,14 @@ count_down_graph = Graph(nodes=[CountDown]) async def main(): state = CountDownState(counter=3) - with count_down_graph.iter(CountDown(), state=state) as run: - async for node in run: + with count_down_graph.iter(CountDown(), state=state) as run: # (1)! + async for node in run: # (2)! print('Node:', node) #> Node: CountDown() #> Node: CountDown() #> Node: CountDown() #> Node: End(data=0) - print('Final result:', run.final_result.result) + print('Final result:', run.final_result.result) # (3)! #> Final result: 0 print('History snapshots:', [step.data_snapshot() for step in run.history]) """ @@ -695,9 +695,9 @@ async def main(): """ ``` -- `Graph.iter(...)` returns a `GraphRun`. -- You can `async for node in run` to step through each node as it is executed. -- Once the graph returns an [`End`][pydantic_graph.nodes.End], the loop ends, and `run.final_result` becomes a [`GraphRunResult`][pydantic_graph.graph.GraphRunResult] containing the final outcome (`0` here). +1. `Graph.iter(...)` returns a [`GraphRun`][pydantic_graph.graph.GraphRun]. +2. Here, we step through each node as it is executed. +3. Once the graph returns an [`End`][pydantic_graph.nodes.End], the loop ends, and `run.final_result` becomes a [`GraphRunResult`][pydantic_graph.graph.GraphRunResult] containing the final outcome (`0` here). ### Using `GraphRun.next(node)` manually @@ -713,34 +713,33 @@ from count_down import CountDown, CountDownState, count_down_graph async def main(): state = CountDownState(counter=5) with count_down_graph.iter(CountDown(), state=state) as run: - node = run.next_node # start with the first node - while not isinstance(node, End): + node = run.next_node # (1)! + while not isinstance(node, End): # (2)! print('Node:', node) #> Node: CountDown() #> Node: CountDown() #> Node: CountDown() #> Node: CountDown() if state.counter == 2: - # Let's forcibly end the run early - break - # run the node we got, which might produce a new node or End - node = await run.next(node) - - # Because the run ended early, we have no final result: - assert run.final_result is None - - # The run still has partial history though - for step in run.history: - print('History Step:', step.data_snapshot()) - #> History Step: CountDown() - #> History Step: CountDown() - #> History Step: CountDown() + break # (3)! + node = await run.next(node) # (4)! + + print(run.final_result) # (5)! + #> None + + for step in run.history: # (6)! + print('History Step:', step.data_snapshot(), step.state) + #> History Step: CountDown() CountDownState(counter=4) + #> History Step: CountDown() CountDownState(counter=3) + #> History Step: CountDown() CountDownState(counter=2) ``` -- We grab the first node via `run.next_node`. -- At each step, we call `await run.next(node)` to run it and get the next node (or an `End`). -- If the user decides to stop early, we break out of the loop. The graph run won't have a real final result in that case (`run.final_result` remains `None`). -- The run's history is still populated with the steps we executed so far. +1. We start by grabbing the first node that will be run in the agent's graph. +2. The agent run is finished once an `End` node has been produced; instances of `End` cannot be passed to `next`. +3. If the user decides to stop early, we break out of the loop. The graph run won't have a real final result in that case (`run.final_result` remains `None`). +4. At each step, we call `await run.next(node)` to run it and get the next node (or an `End`). +5. Because the run was ended early, we have no final result: +6. The run's history is still populated with the steps we executed so far. ## Dependency Injection diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index a3f38e851..e7894ab78 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -1357,7 +1357,7 @@ async def main(): the run has completed. """ # Note: It might be nice to expose a synchronous interface for iteration, but we shouldn't do it - # on this class or IDEs won't warn you if you accidentally use `for` instead of `async for` to iterate. + # on this class, or else IDEs won't warn you if you accidentally use `for` instead of `async for` to iterate. return await self._graph_run.next(node) def usage(self) -> _usage.Usage: From 8d86b3a52a9de4ba7b5d74f4fa5c26786601c223 Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Tue, 18 Feb 2025 13:46:24 -0700 Subject: [PATCH 23/28] Fix docs build --- docs/agents.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/agents.md b/docs/agents.md index 3707806f2..dd1c7187a 100644 --- a/docs/agents.md +++ b/docs/agents.md @@ -97,7 +97,7 @@ You can also pass messages from previous runs to continue a conversation or prov ### Iterating Over an Agent's Graph -Under the hood, each `Agent` in PydanticAI uses [pydantic-graph][pydantic_graph] to manage its execution flow. **pydantic-graph** is a generic, type-centric library for building and running finite state machines in Python. It doesn't actually depend on PydanticAI — you can use it standalone for workflows that have nothing to do with GenAI — but PydanticAI makes use of it to orchestrate the handling of model requests and model responses in an agent's run. +Under the hood, each `Agent` in PydanticAI uses **pydantic-graph** to manage its execution flow. **pydantic-graph** is a generic, type-centric library for building and running finite state machines in Python. It doesn't actually depend on PydanticAI — you can use it standalone for workflows that have nothing to do with GenAI — but PydanticAI makes use of it to orchestrate the handling of model requests and model responses in an agent's run. In many scenarios, you don't need to worry about pydantic-graph at all; calling `agent.run(...)` simply traverses the underlying graph from start to finish. However, if you need deeper insight or control — for example to capture each tool invocation, or to inject your own logic at specific stages — PydanticAI exposes the lower-level iteration process via [`Agent.iter`][pydantic_ai.Agent.iter]. This method returns an [`AgentRun`][pydantic_ai.agent.AgentRun], which you can async-iterate over, or manually drive node-by-node via the [`next`][pydantic_ai.agent.AgentRun.next] method. Once the agent's graph returns an [`End`][pydantic_graph.nodes.End], you have the final result along with a detailed history of all steps. From bdb5f77327731608ebc6eeec13831adae1b1e0ae Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Tue, 18 Feb 2025 14:25:02 -0700 Subject: [PATCH 24/28] Make the graph_run_result private on AgentRunResult --- pydantic_ai_slim/pydantic_ai/agent.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index e7894ab78..693bd9f39 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -1375,13 +1375,13 @@ def __repr__(self): class AgentRunResult(Generic[AgentDepsT, ResultDataT]): """The final result of an agent run.""" - graph_run_result: GraphRunResult[ + _graph_run_result: GraphRunResult[ _agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], FinalResult[ResultDataT] ] @property def _result(self) -> FinalResult[ResultDataT]: - return self.graph_run_result.result + return self._graph_run_result.result @property def data(self) -> ResultDataT: @@ -1398,7 +1398,7 @@ def _set_result_tool_return(self, return_content: str) -> list[_messages.ModelMe """ if not self._result.tool_name: raise ValueError('Cannot set result tool return content when the return type is `str`.') - messages = deepcopy(self.graph_run_result.state.message_history) + messages = deepcopy(self._graph_run_result.state.message_history) last_message = messages[-1] for part in last_message.parts: if isinstance(part, _messages.ToolReturnPart) and part.tool_name == self._result.tool_name: @@ -1421,7 +1421,7 @@ def all_messages(self, *, result_tool_return_content: str | None = None) -> list if result_tool_return_content is not None: return self._set_result_tool_return(result_tool_return_content) else: - return self.graph_run_result.state.message_history + return self._graph_run_result.state.message_history def all_messages_json(self, *, result_tool_return_content: str | None = None) -> bytes: """Return all messages from [`all_messages`][pydantic_ai.agent.AgentRunResult.all_messages] as JSON bytes. @@ -1441,7 +1441,7 @@ def all_messages_json(self, *, result_tool_return_content: str | None = None) -> @property def _new_message_index(self) -> int: - return self.graph_run_result.deps.new_message_index + return self._graph_run_result.deps.new_message_index def new_messages(self, *, result_tool_return_content: str | None = None) -> list[_messages.ModelMessage]: """Return new messages associated with this run. @@ -1477,4 +1477,4 @@ def new_messages_json(self, *, result_tool_return_content: str | None = None) -> def usage(self) -> _usage.Usage: """Return the usage of the whole run.""" - return self.graph_run_result.state.usage + return self._graph_run_result.state.usage From 0d36dbfb4f835eb69c7f67b5ab1d6d4388844fc5 Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Tue, 18 Feb 2025 14:29:44 -0700 Subject: [PATCH 25/28] Some minor cleanup of reprs --- pydantic_ai_slim/pydantic_ai/agent.py | 4 ++++ pydantic_graph/pydantic_graph/graph.py | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 693bd9f39..ecf3a5bc6 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -1478,3 +1478,7 @@ def new_messages_json(self, *, result_tool_return_content: str | None = None) -> def usage(self) -> _usage.Usage: """Return the usage of the whole run.""" return self._graph_run_result.state.usage + + def __repr__(self): + kws = [f'data={self.data!r}', f'usage={self.usage()}'] + return '<{} {}>'.format(type(self).__name__, ' '.join(kws)) diff --git a/pydantic_graph/pydantic_graph/graph.py b/pydantic_graph/pydantic_graph/graph.py index e08d59767..ec98c38e4 100644 --- a/pydantic_graph/pydantic_graph/graph.py +++ b/pydantic_graph/pydantic_graph/graph.py @@ -741,7 +741,7 @@ class GraphRunResult(Generic[StateT, DepsT, RunEndT]): """The final result of running a graph.""" result: RunEndT - graph: Graph[StateT, DepsT, RunEndT] + graph: Graph[StateT, DepsT, RunEndT] = field(repr=False) history: list[HistoryStep[StateT, RunEndT]] state: StateT deps: DepsT From 9a676d25a475eb6941e87b63f931d7650357ab9b Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Tue, 18 Feb 2025 20:50:31 -0700 Subject: [PATCH 26/28] Tweak some APIs --- pydantic_ai_slim/pydantic_ai/_agent_graph.py | 11 ++++++----- pydantic_ai_slim/pydantic_ai/agent.py | 2 +- pydantic_ai_slim/pydantic_ai/models/__init__.py | 4 ---- 3 files changed, 7 insertions(+), 10 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index 52c970ddb..a509fa399 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -227,10 +227,11 @@ async def run( return await self._make_request(ctx) @asynccontextmanager - async def stream( + async def _stream( self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]], ) -> AsyncIterator[models.StreamedResponse]: + # TODO: Consider changing this to return something more similar to a `StreamedRunResult`, then make it public if self._did_stream: raise exceptions.AgentRunError('stream() can only be called once') @@ -314,7 +315,7 @@ class HandleResponseNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], N model_response: _messages.ModelResponse - _stream: AsyncIterator[_messages.HandleResponseEvent] | None = field(default=None, repr=False) + _events_iterator: AsyncIterator[_messages.HandleResponseEvent] | None = field(default=None, repr=False) _next_node: ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]] | None = field( default=None, repr=False ) @@ -358,7 +359,7 @@ async def stream( async def _run_stream( self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]] ) -> AsyncIterator[_messages.HandleResponseEvent]: - if self._stream is None: + if self._events_iterator is None: # Ensure that the stream is only run once async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]: @@ -387,9 +388,9 @@ async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]: else: raise exceptions.UnexpectedModelBehavior('Received empty model response') - self._stream = _run_stream() + self._events_iterator = _run_stream() - async for event in self._stream: + async for event in self._events_iterator: yield event async def _handle_tool_calls( diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index ecf3a5bc6..27bab11d0 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -636,7 +636,7 @@ async def main(): if isinstance(node, _agent_graph.ModelRequestNode): node = cast(_agent_graph.ModelRequestNode[AgentDepsT, Any], node) graph_ctx = agent_run.ctx - async with node.stream(graph_ctx) as streamed_response: + async with node._stream(graph_ctx) as streamed_response: # pyright: ignore[reportPrivateUsage] async def stream_to_final( s: models.StreamedResponse, diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py index c25f39b7e..e0492d9eb 100644 --- a/pydantic_ai_slim/pydantic_ai/models/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py @@ -262,10 +262,6 @@ def timestamp(self) -> datetime: """Get the timestamp of the response.""" raise NotImplementedError() - async def stream_events(self) -> AsyncIterator[ModelResponseStreamEvent]: - """Stream the response as an async iterable of [`ModelResponseStreamEvent`][pydantic_ai.messages.ModelResponseStreamEvent]s.""" - return self.__aiter__() - async def stream_debounced_events( self, *, debounce_by: float | None = 0.1 ) -> AsyncIterator[list[ModelResponseStreamEvent]]: From e7990246feac07ca83d25174b344f98ad1271f53 Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Wed, 19 Feb 2025 18:07:11 -0700 Subject: [PATCH 27/28] Rename final_result to result and drop DepsT in some places --- docs/agents.md | 2 +- docs/graph.md | 22 +++---- pydantic_ai_slim/pydantic_ai/agent.py | 89 +++++++++++--------------- pydantic_graph/README.md | 6 +- pydantic_graph/pydantic_graph/graph.py | 23 ++++--- tests/graph/test_graph.py | 20 +++--- tests/graph/test_history.py | 11 ++-- tests/graph/test_mermaid.py | 6 +- tests/graph/test_state.py | 6 +- tests/test_agent.py | 4 +- tests/typed_agent.py | 4 +- tests/typed_graph.py | 6 +- 12 files changed, 94 insertions(+), 105 deletions(-) diff --git a/docs/agents.md b/docs/agents.md index dd1c7187a..818e2db88 100644 --- a/docs/agents.md +++ b/docs/agents.md @@ -144,7 +144,7 @@ async def main(): End(data=FinalResult(data='Paris', tool_name=None)), ] """ - print(agent_run.final_result.data) + print(agent_run.result.data) #> Paris ``` diff --git a/docs/graph.md b/docs/graph.md index 513e3f480..95a6c0637 100644 --- a/docs/graph.md +++ b/docs/graph.md @@ -156,11 +156,11 @@ class Increment(BaseNode): # (2)! fives_graph = Graph(nodes=[DivisibleBy5, Increment]) # (3)! -graph_run = fives_graph.run_sync(DivisibleBy5(4)) # (4)! -print(graph_run.result) +result = fives_graph.run_sync(DivisibleBy5(4)) # (4)! +print(result.output) #> 5 # the full history is quite verbose (see below), so we'll just print the summary -print([item.data_snapshot() for item in graph_run.history]) +print([item.data_snapshot() for item in result.history]) #> [DivisibleBy5(foo=4), Increment(foo=4), DivisibleBy5(foo=5), End(data=5)] ``` @@ -464,8 +464,8 @@ async def main(): ) state = State(user) feedback_graph = Graph(nodes=(WriteEmail, Feedback)) - graph_run = await feedback_graph.run(WriteEmail(), state=state) - print(graph_run.result) + result = await feedback_graph.run(WriteEmail(), state=state) + print(result.output) """ Email( subject='Welcome to our tech blog!', @@ -686,7 +686,7 @@ async def main(): #> Node: CountDown() #> Node: CountDown() #> Node: End(data=0) - print('Final result:', run.final_result.result) # (3)! + print('Final result:', run.result.output) # (3)! #> Final result: 0 print('History snapshots:', [step.data_snapshot() for step in run.history]) """ @@ -724,7 +724,7 @@ async def main(): break # (3)! node = await run.next(node) # (4)! - print(run.final_result) # (5)! + print(run.result) # (5)! #> None for step in run.history: # (6)! @@ -738,7 +738,7 @@ async def main(): 2. The agent run is finished once an `End` node has been produced; instances of `End` cannot be passed to `next`. 3. If the user decides to stop early, we break out of the loop. The graph run won't have a real final result in that case (`run.final_result` remains `None`). 4. At each step, we call `await run.next(node)` to run it and get the next node (or an `End`). -5. Because the run was ended early, we have no final result: +5. Because we did not continue the run until it finished, the `result` is not set. 6. The run's history is still populated with the steps we executed so far. ## Dependency Injection @@ -798,11 +798,11 @@ fives_graph = Graph(nodes=[DivisibleBy5, Increment]) async def main(): with ProcessPoolExecutor() as executor: deps = GraphDeps(executor) - graph_run = await fives_graph.run(DivisibleBy5(3), deps=deps) - print(graph_run.result) + result = await fives_graph.run(DivisibleBy5(3), deps=deps) + print(result.output) #> 5 # the full history is quite verbose (see below), so we'll just print the summary - print([item.data_snapshot() for item in graph_run.history]) + print([item.data_snapshot() for item in result.history]) """ [ DivisibleBy5(foo=3), diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 27bab11d0..f67885f6b 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -13,7 +13,6 @@ from typing_extensions import TypeVar, deprecated from pydantic_graph import BaseNode, End, Graph, GraphRun, GraphRunContext -from pydantic_graph.graph import GraphRunResult from . import ( _agent_graph, @@ -215,7 +214,7 @@ async def run( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, - ) -> AgentRunResult[AgentDepsT, ResultDataT]: ... + ) -> AgentRunResult[ResultDataT]: ... @overload async def run( @@ -230,7 +229,7 @@ async def run( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, - ) -> AgentRunResult[AgentDepsT, RunResultDataT]: ... + ) -> AgentRunResult[RunResultDataT]: ... async def run( self, @@ -244,7 +243,7 @@ async def run( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, - ) -> AgentRunResult[AgentDepsT, Any]: + ) -> AgentRunResult[Any]: """Run the agent with a user prompt in async mode. This method builds an internal agent graph (using system prompts, tools and result schemas) and then @@ -291,8 +290,8 @@ async def main(): ) as agent_run: async for _ in agent_run: pass - final_result = agent_run.final_result - assert final_result is not None, 'The graph run should have ended with a final result' + + assert (final_result := agent_run.result) is not None, 'The graph run did not finish properly' return final_result @contextmanager @@ -358,7 +357,7 @@ async def main(): End(data=FinalResult(data='Paris', tool_name=None)), ] ''' - print(agent_run.final_result.data) + print(agent_run.result.data) #> Paris ``` @@ -460,7 +459,7 @@ def run_sync( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, - ) -> AgentRunResult[AgentDepsT, ResultDataT]: ... + ) -> AgentRunResult[ResultDataT]: ... @overload def run_sync( @@ -475,7 +474,7 @@ def run_sync( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, - ) -> AgentRunResult[AgentDepsT, RunResultDataT]: ... + ) -> AgentRunResult[RunResultDataT]: ... def run_sync( self, @@ -489,7 +488,7 @@ def run_sync( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, - ) -> AgentRunResult[AgentDepsT, Any]: + ) -> AgentRunResult[Any]: """Synchronously run the agent with a user prompt. This is a convenience method that wraps [`self.run`][pydantic_ai.Agent.run] with `loop.run_until_complete(...)`. @@ -1166,8 +1165,8 @@ class AgentRun(Generic[AgentDepsT, ResultDataT]): You generally obtain an `AgentRun` instance by calling `with my_agent.iter(...) as agent_run:`. Once you have an instance, you can use it to iterate through the run's nodes as they execute. When an - [`End`][pydantic_graph.nodes.End] is reached, the run finishes and - [`final_result`][pydantic_ai.agent.AgentRun.final_result] becomes available. + [`End`][pydantic_graph.nodes.End] is reached, the run finishes and [`result`][pydantic_ai.agent.AgentRun.result] + becomes available. Example: ```python @@ -1207,7 +1206,7 @@ async def main(): End(data=FinalResult(data='Paris', tool_name=None)), ] ''' - print(agent_run.final_result.data) + print(agent_run.result.data) #> Paris ``` @@ -1240,16 +1239,21 @@ def next_node( return self._graph_run.next_node @property - def final_result(self) -> AgentRunResult[AgentDepsT, ResultDataT] | None: + def result(self) -> AgentRunResult[ResultDataT] | None: """The final result of the run if it has ended, otherwise `None`. - Once the run returns an [`End`][pydantic_graph.nodes.End] node, `final_result` is populated + Once the run returns an [`End`][pydantic_graph.nodes.End] node, `result` is populated with an [`AgentRunResult`][pydantic_ai.agent.AgentRunResult]. """ - graph_run_result = self._graph_run.final_result + graph_run_result = self._graph_run.result if graph_run_result is None: return None - return AgentRunResult(graph_run_result) + return AgentRunResult( + graph_run_result.output.data, + graph_run_result.output.tool_name, + graph_run_result.state, + self._graph_run.deps.new_message_index, + ) def __aiter__( self, @@ -1345,7 +1349,7 @@ async def main(): End(data=FinalResult(data='Paris', tool_name=None)), ] ''' - print('Final result:', agent_run.final_result.data) + print('Final result:', agent_run.result.data) #> Final result: Paris ``` @@ -1364,47 +1368,36 @@ def usage(self) -> _usage.Usage: """Get usage statistics for the run so far, including token usage, model requests, and so on.""" return self._graph_run.state.usage - def __repr__(self): - final_result = self._graph_run.final_result - result_repr = '' if final_result is None else repr(final_result.result) - kws = [f'result={result_repr}', f'usage={self.usage()}'] - return '<{} {}>'.format(type(self).__name__, ' '.join(kws)) + def __repr__(self) -> str: + result = self._graph_run.result + result_repr = '' if result is None else repr(result.output) + return f'<{type(self).__name__} result={result_repr} usage={self.usage()}>' @dataclasses.dataclass -class AgentRunResult(Generic[AgentDepsT, ResultDataT]): +class AgentRunResult(Generic[ResultDataT]): """The final result of an agent run.""" - _graph_run_result: GraphRunResult[ - _agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], FinalResult[ResultDataT] - ] - - @property - def _result(self) -> FinalResult[ResultDataT]: - return self._graph_run_result.result - - @property - def data(self) -> ResultDataT: - return self._result.data + data: ResultDataT # TODO: rename this to output. I'm putting this off for now mostly to reduce the size of the diff - @property - def _result_tool_name(self) -> str | None: - return self._result.tool_name + _result_tool_name: str | None = dataclasses.field(repr=False) + _state: _agent_graph.GraphAgentState = dataclasses.field(repr=False) + _new_message_index: int = dataclasses.field(repr=False) def _set_result_tool_return(self, return_content: str) -> list[_messages.ModelMessage]: """Set return content for the result tool. Useful if you want to continue the conversation and want to set the response to the result tool call. """ - if not self._result.tool_name: + if not self._result_tool_name: raise ValueError('Cannot set result tool return content when the return type is `str`.') - messages = deepcopy(self._graph_run_result.state.message_history) + messages = deepcopy(self._state.message_history) last_message = messages[-1] for part in last_message.parts: - if isinstance(part, _messages.ToolReturnPart) and part.tool_name == self._result.tool_name: + if isinstance(part, _messages.ToolReturnPart) and part.tool_name == self._result_tool_name: part.content = return_content return messages - raise LookupError(f'No tool call found with tool name {self._result.tool_name!r}.') + raise LookupError(f'No tool call found with tool name {self._result_tool_name!r}.') def all_messages(self, *, result_tool_return_content: str | None = None) -> list[_messages.ModelMessage]: """Return the history of _messages. @@ -1421,7 +1414,7 @@ def all_messages(self, *, result_tool_return_content: str | None = None) -> list if result_tool_return_content is not None: return self._set_result_tool_return(result_tool_return_content) else: - return self._graph_run_result.state.message_history + return self._state.message_history def all_messages_json(self, *, result_tool_return_content: str | None = None) -> bytes: """Return all messages from [`all_messages`][pydantic_ai.agent.AgentRunResult.all_messages] as JSON bytes. @@ -1439,10 +1432,6 @@ def all_messages_json(self, *, result_tool_return_content: str | None = None) -> self.all_messages(result_tool_return_content=result_tool_return_content) ) - @property - def _new_message_index(self) -> int: - return self._graph_run_result.deps.new_message_index - def new_messages(self, *, result_tool_return_content: str | None = None) -> list[_messages.ModelMessage]: """Return new messages associated with this run. @@ -1477,8 +1466,4 @@ def new_messages_json(self, *, result_tool_return_content: str | None = None) -> def usage(self) -> _usage.Usage: """Return the usage of the whole run.""" - return self._graph_run_result.state.usage - - def __repr__(self): - kws = [f'data={self.data!r}', f'usage={self.usage()}'] - return '<{} {}>'.format(type(self).__name__, ' '.join(kws)) + return self._state.usage diff --git a/pydantic_graph/README.md b/pydantic_graph/README.md index 3e4ffb24d..29b43cca9 100644 --- a/pydantic_graph/README.md +++ b/pydantic_graph/README.md @@ -50,10 +50,10 @@ class Increment(BaseNode): fives_graph = Graph(nodes=[DivisibleBy5, Increment]) -graph_run = fives_graph.run_sync(DivisibleBy5(4)) -print(graph_run.result) +result = fives_graph.run_sync(DivisibleBy5(4)) +print(result.output) #> 5 # the full history is quite verbose (see below), so we'll just print the summary -print([item.data_snapshot() for item in graph_run.history]) +print([item.data_snapshot() for item in result.history]) #> [DivisibleBy5(foo=4), Increment(foo=4), DivisibleBy5(foo=5), End(data=5)] ``` diff --git a/pydantic_graph/pydantic_graph/graph.py b/pydantic_graph/pydantic_graph/graph.py index ec98c38e4..a8f3897d4 100644 --- a/pydantic_graph/pydantic_graph/graph.py +++ b/pydantic_graph/pydantic_graph/graph.py @@ -134,7 +134,7 @@ async def run( deps: DepsT = None, infer_name: bool = True, span: LogfireSpan | None = None, - ) -> GraphRunResult[StateT, DepsT, T]: + ) -> GraphRunResult[StateT, T]: """Run the graph from a starting node until it ends. Args: @@ -177,7 +177,7 @@ async def main(): async for _node in graph_run: pass - final_result = graph_run.final_result + final_result = graph_run.result assert final_result is not None, 'GraphRun should have a final result' return final_result @@ -242,7 +242,7 @@ def run_sync( state: StateT = None, deps: DepsT = None, infer_name: bool = True, - ) -> GraphRunResult[StateT, DepsT, T]: + ) -> GraphRunResult[StateT, T]: """Synchronously run the graph. This is a convenience method that wraps [`self.run`][pydantic_graph.Graph.run] with `loop.run_until_complete(...)`. @@ -659,12 +659,14 @@ def next_node(self) -> BaseNode[StateT, DepsT, RunEndT] | End[RunEndT]: return self._next_node @property - def final_result(self) -> GraphRunResult[StateT, DepsT, RunEndT] | None: + def result(self) -> GraphRunResult[StateT, RunEndT] | None: """The final result of the graph run if the run is completed, otherwise `None`.""" if not isinstance(self._next_node, End): return None # The GraphRun has not finished running return GraphRunResult( - self._next_node.data, graph=self.graph, history=self.history, state=self.state, deps=self.deps + self._next_node.data, + state=self.state, + history=self.history, ) async def next( @@ -735,13 +737,14 @@ async def __anext__(self) -> BaseNode[StateT, DepsT, RunEndT] | End[RunEndT]: raise StopAsyncIteration return await self.next(self._next_node) + def __repr__(self) -> str: + return f'"} step={len(self.history) + 1}>' + @dataclass -class GraphRunResult(Generic[StateT, DepsT, RunEndT]): +class GraphRunResult(Generic[StateT, RunEndT]): """The final result of running a graph.""" - result: RunEndT - graph: Graph[StateT, DepsT, RunEndT] = field(repr=False) - history: list[HistoryStep[StateT, RunEndT]] + output: RunEndT state: StateT - deps: DepsT + history: list[HistoryStep[StateT, RunEndT]] = field(repr=False) diff --git a/tests/graph/test_graph.py b/tests/graph/test_graph.py index 4668bc2c6..91b7a4400 100644 --- a/tests/graph/test_graph.py +++ b/tests/graph/test_graph.py @@ -57,11 +57,11 @@ async def run(self, ctx: GraphRunContext) -> Union[String2Length, End[int]]: # assert my_graph.name is None assert my_graph._get_state_type() is type(None) assert my_graph._get_run_end_type() is int - graph_run = await my_graph.run(Float2String(3.14)) + result = await my_graph.run(Float2String(3.14)) # len('3.14') * 2 == 8 - assert graph_run.result == 8 + assert result.output == 8 assert my_graph.name == 'my_graph' - assert graph_run.history == snapshot( + assert result.history == snapshot( [ NodeStep( state=None, @@ -84,10 +84,10 @@ async def run(self, ctx: GraphRunContext) -> Union[String2Length, End[int]]: # EndStep(result=End(data=8), ts=IsNow(tz=timezone.utc)), ] ) - graph_run = await my_graph.run(Float2String(3.14159)) + result = await my_graph.run(Float2String(3.14159)) # len('3.14159') == 7, 21 * 2 == 42 - assert graph_run.result == 42 - assert graph_run.history == snapshot( + assert result.output == 42 + assert result.history == snapshot( [ NodeStep( state=None, @@ -122,7 +122,7 @@ async def run(self, ctx: GraphRunContext) -> Union[String2Length, End[int]]: # EndStep(result=End(data=42), ts=IsNow(tz=timezone.utc)), ] ) - assert [e.data_snapshot() for e in graph_run.history] == snapshot( + assert [e.data_snapshot() for e in result.history] == snapshot( [ Float2String(input_data=3.14159), String2Length(input_data='3.14159'), @@ -320,10 +320,10 @@ async def run(self, ctx: GraphRunContext[None, Deps]) -> End[int]: return End(123) g = Graph(nodes=(Foo, Bar)) - graph_run = await g.run(Foo(), deps=Deps(1, 2)) + result = await g.run(Foo(), deps=Deps(1, 2)) - assert graph_run.result == 123 - assert graph_run.history == snapshot( + assert result.output == 123 + assert result.history == snapshot( [ NodeStep(state=None, node=Foo(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()), NodeStep(state=None, node=Bar(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()), diff --git a/tests/graph/test_history.py b/tests/graph/test_history.py index bcd8dca19..da4bcd0d7 100644 --- a/tests/graph/test_history.py +++ b/tests/graph/test_history.py @@ -46,16 +46,17 @@ async def run(self, ctx: GraphRunContext[MyState]) -> End[int]: ], ) async def test_dump_load_history(graph: Graph[MyState, None, int]): - graph_run = await graph.run(Foo(), state=MyState(1, '')) - assert graph_run.result == snapshot(4) - assert graph_run.history == snapshot( + result = await graph.run(Foo(), state=MyState(1, '')) + assert result.output == snapshot(4) + assert result.state == snapshot(MyState(x=2, y='y')) + assert result.history == snapshot( [ NodeStep(state=MyState(x=2, y=''), node=Foo(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()), NodeStep(state=MyState(x=2, y='y'), node=Bar(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()), EndStep(result=End(4), ts=IsNow(tz=timezone.utc)), ] ) - history_json = graph.dump_history(graph_run.history) + history_json = graph.dump_history(result.history) assert json.loads(history_json) == snapshot( [ { @@ -76,7 +77,7 @@ async def test_dump_load_history(graph: Graph[MyState, None, int]): ] ) history_loaded = graph.load_history(history_json) - assert graph_run.history == history_loaded + assert result.history == history_loaded custom_history = [ { diff --git a/tests/graph/test_mermaid.py b/tests/graph/test_mermaid.py index 041fe6027..46fb88992 100644 --- a/tests/graph/test_mermaid.py +++ b/tests/graph/test_mermaid.py @@ -58,9 +58,9 @@ async def run(self, ctx: GraphRunContext) -> Annotated[End[None], Edge(label='eg async def test_run_graph(): - graph_run = await graph1.run(Foo()) - assert graph_run.result is None - assert graph_run.history == snapshot( + result = await graph1.run(Foo()) + assert result.output is None + assert result.history == snapshot( [ NodeStep( state=None, diff --git a/tests/graph/test_state.py b/tests/graph/test_state.py index 8c59667ae..77435a1b8 100644 --- a/tests/graph/test_state.py +++ b/tests/graph/test_state.py @@ -36,9 +36,9 @@ async def run(self, ctx: GraphRunContext[MyState]) -> End[str]: assert graph._get_state_type() is MyState assert graph._get_run_end_type() is str state = MyState(1, '') - graph_run = await graph.run(Foo(), state=state) - assert graph_run.result == snapshot('x=2 y=y') - assert graph_run.history == snapshot( + result = await graph.run(Foo(), state=state) + assert result.output == snapshot('x=2 y=y') + assert result.history == snapshot( [ NodeStep( state=MyState(x=2, y=''), diff --git a/tests/test_agent.py b/tests/test_agent.py index a72cf8499..7d4b41c7f 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -279,7 +279,7 @@ def return_tuple(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: ), ] ) - assert result._result.tool_name == 'final_result' # pyright: ignore[reportPrivateUsage] + assert result._result_tool_name == 'final_result' # pyright: ignore[reportPrivateUsage] assert result.all_messages(result_tool_return_content='foobar')[-1] == snapshot( ModelRequest( parts=[ToolReturnPart(tool_name='final_result', content='foobar', timestamp=IsNow(tz=timezone.utc))] @@ -312,7 +312,7 @@ def test_result_tool_return_content_no_tool(): result = agent.run_sync('Hello') assert result.data == 0 - result._result.tool_name = 'wrong' # pyright: ignore[reportPrivateUsage] + result._result_tool_name = 'wrong' # pyright: ignore[reportPrivateUsage] with pytest.raises(LookupError, match=re.escape("No tool call found with tool name 'wrong'.")): result.all_messages(result_tool_return_content='foobar') diff --git a/tests/typed_agent.py b/tests/typed_agent.py index 6e6607325..280d0795a 100644 --- a/tests/typed_agent.py +++ b/tests/typed_agent.py @@ -139,7 +139,7 @@ async def result_validator_wrong(ctx: RunContext[int], result: str) -> str: def run_sync() -> None: result = typed_agent.run_sync('testing', deps=MyDeps(foo=1, bar=2)) - assert_type(result, AgentRunResult[MyDeps, str]) + assert_type(result, AgentRunResult[str]) assert_type(result.data, str) @@ -176,7 +176,7 @@ class Bar: def run_sync3() -> None: result = union_agent.run_sync('testing') - assert_type(result, AgentRunResult[None, Union[Foo, Bar]]) + assert_type(result, AgentRunResult[Union[Foo, Bar]]) assert_type(result.data, Union[Foo, Bar]) diff --git a/tests/typed_graph.py b/tests/typed_graph.py index deba4dd45..4540ac608 100644 --- a/tests/typed_graph.py +++ b/tests/typed_graph.py @@ -109,6 +109,6 @@ def run_g5() -> None: g5.run_sync(A()) # pyright: ignore[reportArgumentType] g5.run_sync(A(), state=MyState(x=1)) # pyright: ignore[reportArgumentType] g5.run_sync(A(), deps=MyDeps(y='y')) # pyright: ignore[reportArgumentType] - graph_run = g5.run_sync(A(), state=MyState(x=1), deps=MyDeps(y='y')) - assert_type(graph_run.result, int) - assert_type(graph_run.history, list[HistoryStep[MyState, int]]) + result = g5.run_sync(A(), state=MyState(x=1), deps=MyDeps(y='y')) + assert_type(result.output, int) + assert_type(result.history, list[HistoryStep[MyState, int]]) From c7ab89f3690a704d256c4c867abef8ffc443187a Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Wed, 19 Feb 2025 20:20:13 -0700 Subject: [PATCH 28/28] More cleanup --- pydantic_ai_slim/pydantic_ai/__init__.py | 6 +- pydantic_ai_slim/pydantic_ai/_agent_graph.py | 30 ++++---- pydantic_ai_slim/pydantic_ai/agent.py | 20 +++++- .../pydantic_ai/models/__init__.py | 69 +------------------ pydantic_ai_slim/pydantic_ai/result.py | 64 +++++++++++++++-- 5 files changed, 102 insertions(+), 87 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/__init__.py b/pydantic_ai_slim/pydantic_ai/__init__.py index 6f28e3047..1b77a420b 100644 --- a/pydantic_ai_slim/pydantic_ai/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/__init__.py @@ -1,11 +1,15 @@ from importlib.metadata import version -from .agent import Agent, capture_run_messages +from .agent import Agent, EndStrategy, HandleResponseNode, ModelRequestNode, UserPromptNode, capture_run_messages from .exceptions import AgentRunError, ModelRetry, UnexpectedModelBehavior, UsageLimitExceeded, UserError from .tools import RunContext, Tool __all__ = ( 'Agent', + 'EndStrategy', + 'HandleResponseNode', + 'ModelRequestNode', + 'UserPromptNode', 'capture_run_messages', 'RunContext', 'Tool', diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index a509fa399..b080acfc0 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -32,6 +32,16 @@ ToolDefinition, ) +__all__ = ( + 'GraphAgentState', + 'GraphAgentDeps', + 'UserPromptNode', + 'ModelRequestNode', + 'HandleResponseNode', + 'build_run_context', + 'capture_run_messages', +) + _logfire = logfire_api.Logfire(otel_scope='pydantic-ai') # while waiting for https://github.com/pydantic/logfire/issues/745 @@ -98,13 +108,18 @@ class GraphAgentDeps(Generic[DepsT, ResultDataT]): @dataclasses.dataclass -class BaseUserPromptNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT], ABC): +class UserPromptNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT], ABC): user_prompt: str system_prompts: tuple[str, ...] system_prompt_functions: list[_system_prompt.SystemPromptRunner[DepsT]] system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[DepsT]] + async def run( + self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]] + ) -> ModelRequestNode[DepsT, NodeRunEndT]: + return ModelRequestNode[DepsT, NodeRunEndT](request=await self._get_first_message(ctx)) + async def _get_first_message( self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]] ) -> _messages.ModelRequest: @@ -173,14 +188,6 @@ async def _sys_parts(self, run_context: RunContext[DepsT]) -> list[_messages.Mod return messages -@dataclasses.dataclass -class UserPromptNode(BaseUserPromptNode[DepsT, NodeRunEndT]): - async def run( - self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]] - ) -> ModelRequestNode[DepsT, NodeRunEndT]: - return ModelRequestNode[DepsT, NodeRunEndT](request=await self._get_first_message(ctx)) - - async def _prepare_request_parameters( ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], ) -> models.ModelRequestParameters: @@ -229,11 +236,10 @@ async def run( @asynccontextmanager async def _stream( self, - ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]], + ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, T]], ) -> AsyncIterator[models.StreamedResponse]: # TODO: Consider changing this to return something more similar to a `StreamedRunResult`, then make it public - if self._did_stream: - raise exceptions.AgentRunError('stream() can only be called once') + assert not self._did_stream, 'stream() should only be called once per node' model_settings, model_request_parameters = await self._prepare_request(ctx) with _logfire.span('model request', run_step=ctx.state.run_step) as span: diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index f67885f6b..392443d73 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -25,7 +25,6 @@ result, usage as _usage, ) -from ._agent_graph import EndStrategy, capture_run_messages # imported for re-export from .result import FinalResult, ResultDataT, StreamedRunResult from .settings import ModelSettings, merge_model_settings from .tools import ( @@ -40,7 +39,24 @@ ToolPrepareFunc, ) -__all__ = 'Agent', 'AgentRun', 'AgentRunResult', 'capture_run_messages', 'EndStrategy' +# Re-exporting like this improves auto-import behavior in PyCharm +capture_run_messages = _agent_graph.capture_run_messages +EndStrategy = _agent_graph.EndStrategy +HandleResponseNode = _agent_graph.HandleResponseNode +ModelRequestNode = _agent_graph.ModelRequestNode +UserPromptNode = _agent_graph.UserPromptNode + + +__all__ = ( + 'Agent', + 'AgentRun', + 'AgentRunResult', + 'capture_run_messages', + 'EndStrategy', + 'HandleResponseNode', + 'ModelRequestNode', + 'UserPromptNode', +) _logfire = logfire_api.Logfire(otel_scope='pydantic-ai') diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py index e0492d9eb..eef023c97 100644 --- a/pydantic_ai_slim/pydantic_ai/models/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py @@ -17,7 +17,6 @@ import httpx from typing_extensions import Literal -from .. import _utils, messages as _messages from .._parts_manager import ModelResponsePartsManager from ..exceptions import UserError from ..messages import ModelMessage, ModelResponse, ModelResponseStreamEvent @@ -235,6 +234,8 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: This method should be implemented by subclasses to translate the vendor-specific stream of events into pydantic_ai-format events. + + It should use the `_parts_manager` to handle deltas, and should update the `_usage` attributes as it goes. """ raise NotImplementedError() # noinspection PyUnreachableCode @@ -262,72 +263,6 @@ def timestamp(self) -> datetime: """Get the timestamp of the response.""" raise NotImplementedError() - async def stream_debounced_events( - self, *, debounce_by: float | None = 0.1 - ) -> AsyncIterator[list[ModelResponseStreamEvent]]: - """Stream the response as an async iterable of debounced lists of [`ModelResponseStreamEvent`][pydantic_ai.messages.ModelResponseStreamEvent]s.""" - async with _utils.group_by_temporal(self, debounce_by) as group_iter: - async for items in group_iter: - yield items - - async def stream_structured(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[_messages.ModelResponse]: - """Stream the response as an async iterable of [`ModelResponse`][pydantic_ai.messages.ModelResponse]s.""" - - async def _stream_structured_ungrouped() -> AsyncIterator[None]: - # yield None # TODO: Might want to yield right away to ensure we can eagerly emit a ModelResponse even if we are waiting - async for _event in self: - yield None - - async with _utils.group_by_temporal(_stream_structured_ungrouped(), debounce_by) as group_iter: - async for _items in group_iter: - yield self.get() # current state of the response - - async def stream_text(self, *, delta: bool = False, debounce_by: float | None = 0.1) -> AsyncIterator[str]: - """Stream the response as an async iterable of text.""" - - # Define a "merged" version of the iterator that will yield items that have already been retrieved - # and items that we receive while streaming. We define a dedicated async iterator for this so we can - # pass the combined stream to the group_by_temporal function within `_stream_text_deltas` below. - async def _stream_text_deltas_ungrouped() -> AsyncIterator[tuple[str, int]]: - # yields tuples of (text_content, part_index) - # we don't currently make use of the part_index, but in principle this may be useful - # so we retain it here for now to make possible future refactors simpler - msg = self.get() - for i, part in enumerate(msg.parts): - if isinstance(part, _messages.TextPart) and part.content: - yield part.content, i - - async for event in self: - if ( - isinstance(event, _messages.PartStartEvent) - and isinstance(event.part, _messages.TextPart) - and event.part.content - ): - yield event.part.content, event.index - elif ( - isinstance(event, _messages.PartDeltaEvent) - and isinstance(event.delta, _messages.TextPartDelta) - and event.delta.content_delta - ): - yield event.delta.content_delta, event.index - - async def _stream_text_deltas() -> AsyncIterator[str]: - async with _utils.group_by_temporal(_stream_text_deltas_ungrouped(), debounce_by) as group_iter: - async for items in group_iter: - # Note: we are currently just dropping the part index on the group here - yield ''.join([content for content, _ in items]) - - if delta: - async for text in _stream_text_deltas(): - yield text - else: - # a quick benchmark shows it's faster to build up a string with concat when we're - # yielding at each step - deltas: list[str] = [] - async for text in _stream_text_deltas(): - deltas.append(text) - yield ''.join(deltas) - ALLOW_MODEL_REQUESTS = True """Whether to allow requests to models. diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py index 41cc20b80..7646de5bf 100644 --- a/pydantic_ai_slim/pydantic_ai/result.py +++ b/pydantic_ai_slim/pydantic_ai/result.py @@ -9,7 +9,7 @@ import logfire_api from typing_extensions import TypeVar -from . import _result, exceptions, messages as _messages, models +from . import _result, _utils, exceptions, messages as _messages, models from .tools import AgentDepsT, RunContext from .usage import Usage, UsageLimits @@ -160,7 +160,6 @@ async def stream(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[Resu Returns: An async iterable of the response data. """ - self._stream_response.stream_structured(debounce_by=debounce_by) async for structured_message, is_last in self.stream_structured(debounce_by=debounce_by): result = await self.validate_structured_result(structured_message, allow_partial=not is_last) yield result @@ -183,11 +182,11 @@ async def stream_text(self, *, delta: bool = False, debounce_by: float | None = with _logfire.span('response stream text') as lf_span: if delta: - async for text in self._stream_response.stream_text(delta=delta, debounce_by=debounce_by): + async for text in self._stream_response_text(delta=delta, debounce_by=debounce_by): yield text else: combined_validated_text = '' - async for text in self._stream_response.stream_text(delta=delta, debounce_by=debounce_by): + async for text in self._stream_response_text(delta=delta, debounce_by=debounce_by): combined_validated_text = await self._validate_text_result(text) yield combined_validated_text lf_span.set_attribute('combined_text', combined_validated_text) @@ -214,7 +213,7 @@ async def stream_structured( yield msg, False break - async for msg in self._stream_response.stream_structured(debounce_by=debounce_by): + async for msg in self._stream_response_structured(debounce_by=debounce_by): yield msg, False msg = self._stream_response.get() @@ -289,6 +288,61 @@ async def _marked_completed(self, message: _messages.ModelResponse) -> None: self._all_messages.append(message) await self._on_complete() + async def _stream_response_structured( + self, *, debounce_by: float | None = 0.1 + ) -> AsyncIterator[_messages.ModelResponse]: + async with _utils.group_by_temporal(self._stream_response, debounce_by) as group_iter: + async for _items in group_iter: + yield self._stream_response.get() + + async def _stream_response_text( + self, *, delta: bool = False, debounce_by: float | None = 0.1 + ) -> AsyncIterator[str]: + """Stream the response as an async iterable of text.""" + + # Define a "merged" version of the iterator that will yield items that have already been retrieved + # and items that we receive while streaming. We define a dedicated async iterator for this so we can + # pass the combined stream to the group_by_temporal function within `_stream_text_deltas` below. + async def _stream_text_deltas_ungrouped() -> AsyncIterator[tuple[str, int]]: + # yields tuples of (text_content, part_index) + # we don't currently make use of the part_index, but in principle this may be useful + # so we retain it here for now to make possible future refactors simpler + msg = self._stream_response.get() + for i, part in enumerate(msg.parts): + if isinstance(part, _messages.TextPart) and part.content: + yield part.content, i + + async for event in self._stream_response: + if ( + isinstance(event, _messages.PartStartEvent) + and isinstance(event.part, _messages.TextPart) + and event.part.content + ): + yield event.part.content, event.index + elif ( + isinstance(event, _messages.PartDeltaEvent) + and isinstance(event.delta, _messages.TextPartDelta) + and event.delta.content_delta + ): + yield event.delta.content_delta, event.index + + async def _stream_text_deltas() -> AsyncIterator[str]: + async with _utils.group_by_temporal(_stream_text_deltas_ungrouped(), debounce_by) as group_iter: + async for items in group_iter: + # Note: we are currently just dropping the part index on the group here + yield ''.join([content for content, _ in items]) + + if delta: + async for text in _stream_text_deltas(): + yield text + else: + # a quick benchmark shows it's faster to build up a string with concat when we're + # yielding at each step + deltas: list[str] = [] + async for text in _stream_text_deltas(): + deltas.append(text) + yield ''.join(deltas) + @dataclass class FinalResult(Generic[ResultDataT]):