From bb419871bad435df24998e7715ca5bd4bc54fc7c Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Tue, 25 Feb 2025 00:01:04 -0800 Subject: [PATCH] Fix agent graph types (#983) --- pydantic_ai_slim/pydantic_ai/_agent_graph.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index 87e2620c..ec9f1656 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -108,7 +108,7 @@ class GraphAgentDeps(Generic[DepsT, ResultDataT]): @dataclasses.dataclass -class UserPromptNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT], ABC): +class UserPromptNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]], ABC): user_prompt: str | Sequence[_messages.UserContent] system_prompts: tuple[str, ...] @@ -116,12 +116,12 @@ class UserPromptNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeR system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[DepsT]] async def run( - self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]] + self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]] ) -> ModelRequestNode[DepsT, NodeRunEndT]: return ModelRequestNode[DepsT, NodeRunEndT](request=await self._get_first_message(ctx)) async def _get_first_message( - self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]] + self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]] ) -> _messages.ModelRequest: run_context = build_run_context(ctx) history, next_message = await self._prepare_messages(self.user_prompt, ctx.state.message_history, run_context) @@ -215,7 +215,7 @@ async def add_tool(tool: Tool[DepsT]) -> None: @dataclasses.dataclass -class ModelRequestNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT]): +class ModelRequestNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]]): """Make a request to the model using the last message in state.message_history.""" request: _messages.ModelRequest @@ -319,7 +319,7 @@ def _finish_handling( @dataclasses.dataclass -class HandleResponseNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT]): +class HandleResponseNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]]): """Process a model response, and decide whether to end the run or make a new request.""" model_response: _messages.ModelResponse @@ -341,7 +341,7 @@ async def run( @asynccontextmanager async def stream( - self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]] + self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]] ) -> AsyncIterator[AsyncIterator[_messages.HandleResponseEvent]]: """Process the model response and yield events for the start and end of each function tool call.""" with _logfire.span('handle model response', run_step=ctx.state.run_step) as handle_span: @@ -366,7 +366,7 @@ async def stream( handle_span.message = f'handle model response -> {tool_responses_str}' async def _run_stream( - self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]] + self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]] ) -> AsyncIterator[_messages.HandleResponseEvent]: if self._events_iterator is None: # Ensure that the stream is only run once @@ -670,7 +670,7 @@ def get_captured_run_messages() -> _RunMessages: def build_agent_graph( name: str | None, deps_type: type[DepsT], result_type: type[ResultT] -) -> Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[ResultT]]: +) -> Graph[GraphAgentState, GraphAgentDeps[DepsT, result.FinalResult[ResultT]], result.FinalResult[ResultT]]: """Build the execution [Graph][pydantic_graph.Graph] for a given agent.""" nodes = ( UserPromptNode[DepsT],