Skip to content

Commit

Permalink
Merge branch 'main' into dmontagu/graph-run-streaming
Browse files Browse the repository at this point in the history
  • Loading branch information
dmontagu committed Feb 25, 2025
2 parents 5008c9e + bb41987 commit 2cfa693
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions pydantic_ai_slim/pydantic_ai/_agent_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,20 +110,20 @@ class GraphAgentDeps(Generic[DepsT, ResultDataT]):


@dataclasses.dataclass
class UserPromptNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT], ABC):
class UserPromptNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]], ABC):
user_prompt: str | Sequence[_messages.UserContent]

system_prompts: tuple[str, ...]
system_prompt_functions: list[_system_prompt.SystemPromptRunner[DepsT]]
system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[DepsT]]

async def run(
self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]
self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
) -> ModelRequestNode[DepsT, NodeRunEndT]:
return ModelRequestNode[DepsT, NodeRunEndT](request=await self._get_first_message(ctx))

async def _get_first_message(
self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]
self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
) -> _messages.ModelRequest:
run_context = build_run_context(ctx)
history, next_message = await self._prepare_messages(self.user_prompt, ctx.state.message_history, run_context)
Expand Down Expand Up @@ -217,7 +217,7 @@ async def add_tool(tool: Tool[DepsT]) -> None:


@dataclasses.dataclass
class ModelRequestNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT]):
class ModelRequestNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]]):
"""Make a request to the model using the last message in state.message_history."""

request: _messages.ModelRequest
Expand Down Expand Up @@ -339,7 +339,7 @@ def _finish_handling(


@dataclasses.dataclass
class HandleResponseNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT]):
class HandleResponseNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]]):
"""Process a model response, and decide whether to end the run or make a new request."""

model_response: _messages.ModelResponse
Expand All @@ -361,7 +361,7 @@ async def run(

@asynccontextmanager
async def stream(
self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]
self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
) -> AsyncIterator[AsyncIterator[_messages.HandleResponseEvent]]:
"""Process the model response and yield events for the start and end of each function tool call."""
with _logfire.span('handle model response', run_step=ctx.state.run_step) as handle_span:
Expand All @@ -386,7 +386,7 @@ async def stream(
handle_span.message = f'handle model response -> {tool_responses_str}'

async def _run_stream(
self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]
self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
) -> AsyncIterator[_messages.HandleResponseEvent]:
if self._events_iterator is None:
# Ensure that the stream is only run once
Expand Down Expand Up @@ -690,7 +690,7 @@ def get_captured_run_messages() -> _RunMessages:

def build_agent_graph(
name: str | None, deps_type: type[DepsT], result_type: type[ResultT]
) -> Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[ResultT]]:
) -> Graph[GraphAgentState, GraphAgentDeps[DepsT, result.FinalResult[ResultT]], result.FinalResult[ResultT]]:
"""Build the execution [Graph][pydantic_graph.Graph] for a given agent."""
nodes = (
UserPromptNode[DepsT],
Expand Down

0 comments on commit 2cfa693

Please sign in to comment.