diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py
index e4b8b0ca..859cf2a8 100644
--- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py
+++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py
@@ -110,7 +110,7 @@ class GraphAgentDeps(Generic[DepsT, ResultDataT]):
 
 
 @dataclasses.dataclass
-class UserPromptNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT], ABC):
+class UserPromptNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]], ABC):
     user_prompt: str | Sequence[_messages.UserContent]
 
     system_prompts: tuple[str, ...]
@@ -118,12 +118,12 @@ class UserPromptNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeR
     system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[DepsT]]
 
     async def run(
-        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]
+        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
     ) -> ModelRequestNode[DepsT, NodeRunEndT]:
         return ModelRequestNode[DepsT, NodeRunEndT](request=await self._get_first_message(ctx))
 
     async def _get_first_message(
-        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]
+        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
     ) -> _messages.ModelRequest:
         run_context = build_run_context(ctx)
         history, next_message = await self._prepare_messages(self.user_prompt, ctx.state.message_history, run_context)
@@ -217,7 +217,7 @@ async def add_tool(tool: Tool[DepsT]) -> None:
 
 
 @dataclasses.dataclass
-class ModelRequestNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT]):
+class ModelRequestNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]]):
     """Make a request to the model using the last message in state.message_history."""
 
     request: _messages.ModelRequest
@@ -339,7 +339,7 @@ def _finish_handling(
 
 
 @dataclasses.dataclass
-class HandleResponseNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT]):
+class HandleResponseNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]]):
    """Process a model response, and decide whether to end the run or make a new request."""
 
     model_response: _messages.ModelResponse
@@ -361,7 +361,7 @@ async def run(
 
     @asynccontextmanager
     async def stream(
-        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]
+        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
     ) -> AsyncIterator[AsyncIterator[_messages.HandleResponseEvent]]:
         """Process the model response and yield events for the start and end of each function tool call."""
         with _logfire.span('handle model response', run_step=ctx.state.run_step) as handle_span:
@@ -386,7 +386,7 @@ async def stream(
             handle_span.message = f'handle model response -> {tool_responses_str}'
 
     async def _run_stream(
-        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]
+        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
     ) -> AsyncIterator[_messages.HandleResponseEvent]:
         if self._events_iterator is None:
             # Ensure that the stream is only run once
@@ -690,7 +690,7 @@ def get_captured_run_messages() -> _RunMessages:
 
 def build_agent_graph(
     name: str | None, deps_type: type[DepsT], result_type: type[ResultT]
-) -> Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[ResultT]]:
+) -> Graph[GraphAgentState, GraphAgentDeps[DepsT, result.FinalResult[ResultT]], result.FinalResult[ResultT]]:
     """Build the execution [Graph][pydantic_graph.Graph] for a given agent."""
     nodes = (
         UserPromptNode[DepsT],