Replace RunResult with AgentRun
dmontagu committed Feb 5, 2025
1 parent d46ea2d commit e421ba0
Showing 4 changed files with 317 additions and 174 deletions.
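The headline change: `Agent.run` now returns an `AgentRun` that wraps the underlying graph run, rather than eagerly building a `result.RunResult`. A minimal usage sketch of the new surface (not part of the commit; the model name and prompt are illustrative):

import asyncio

from pydantic_ai import Agent

agent = Agent('openai:gpt-4o')  # illustrative model name, not from this commit

async def main():
    # Awaiting agent.run(...) drives the graph to completion and returns the
    # AgentRun itself, so the old RunResult accessors map onto it directly:
    agent_run = await agent.run('What is the capital of France?')
    print(agent_run.data)                 # was RunResult.data
    print(agent_run.usage())              # was RunResult.usage()
    print(len(agent_run.all_messages()))  # was RunResult.all_messages()

asyncio.run(main())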
227 changes: 179 additions & 48 deletions pydantic_ai_slim/pydantic_ai/agent.py
@@ -3,15 +3,16 @@
import asyncio
import dataclasses
import inspect
-from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
+from collections.abc import AsyncIterator, Awaitable, Generator, Iterator, Sequence
from contextlib import AbstractAsyncContextManager, asynccontextmanager, contextmanager
from copy import deepcopy
from types import FrameType
from typing import Any, Callable, Generic, cast, final, overload

import logfire_api
-from typing_extensions import TypeVar, deprecated
+from typing_extensions import Self, TypeVar, deprecated

-from pydantic_graph import Graph, GraphRunContext, HistoryStep
+from pydantic_graph import BaseNode, Graph, GraphRun, GraphRunContext, HistoryStep
+from pydantic_graph.nodes import End

from . import (
@@ -25,7 +26,7 @@
result,
usage as _usage,
)
-from ._agent_graph import EndStrategy, capture_run_messages  # imported for re-export
+from ._agent_graph import EndStrategy, MarkFinalResult, capture_run_messages  # imported for re-export
from .result import ResultDataT
from .settings import ModelSettings, merge_model_settings
from .tools import (
@@ -214,7 +215,7 @@ async def run(
usage_limits: _usage.UsageLimits | None = None,
usage: _usage.Usage | None = None,
infer_name: bool = True,
-    ) -> result.RunResult[ResultDataT]: ...
+    ) -> AgentRun[AgentDepsT, ResultDataT]: ...

@overload
async def run(
@@ -229,7 +230,7 @@ async def run(
usage_limits: _usage.UsageLimits | None = None,
usage: _usage.Usage | None = None,
infer_name: bool = True,
-    ) -> result.RunResult[RunResultDataT]: ...
+    ) -> AgentRun[AgentDepsT, ResultDataT]: ...

async def run(
self,
@@ -243,7 +244,7 @@ async def run(
usage: _usage.Usage | None = None,
result_type: type[RunResultDataT] | None = None,
infer_name: bool = True,
-    ) -> result.RunResult[Any]:
+    ) -> AgentRun[AgentDepsT, ResultDataT]:
"""Run the agent with a user prompt in async mode.
Example:
Expand Down Expand Up @@ -305,54 +306,47 @@ async def main():
model_settings = merge_model_settings(self.model_settings, model_settings)
usage_limits = usage_limits or _usage.UsageLimits()

-        with _logfire.span(
+        # Build the deps object for the graph
+        run_span = _logfire.span(
            '{agent_name} run {prompt=}',
            prompt=user_prompt,
            agent=self,
            model_name=model_used.name() if model_used else 'no-model',
            agent_name=self.name or 'agent',
-        ) as run_span:
-            # Build the deps object for the graph
-            graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT, RunResultDataT](
-                user_deps=deps,
-                prompt=user_prompt,
-                new_message_index=new_message_index,
-                model=model_used,
-                model_settings=model_settings,
-                usage_limits=usage_limits,
-                max_result_retries=self._max_result_retries,
-                end_strategy=self.end_strategy,
-                result_schema=result_schema,
-                result_tools=self._result_schema.tool_defs() if self._result_schema else [],
-                result_validators=result_validators,
-                function_tools=self._function_tools,
-                run_span=run_span,
-            )
-
-            start_node = _agent_graph.UserPromptNode[AgentDepsT](
-                user_prompt=user_prompt,
-                system_prompts=self._system_prompts,
-                system_prompt_functions=self._system_prompt_functions,
-                system_prompt_dynamic_functions=self._system_prompt_dynamic_functions,
-            )
+        )
+        graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT, RunResultDataT](
+            user_deps=deps,
+            prompt=user_prompt,
+            new_message_index=new_message_index,
+            model=model_used,
+            model_settings=model_settings,
+            usage_limits=usage_limits,
+            max_result_retries=self._max_result_retries,
+            end_strategy=self.end_strategy,
+            result_schema=result_schema,
+            result_tools=self._result_schema.tool_defs() if self._result_schema else [],
+            result_validators=result_validators,
+            function_tools=self._function_tools,
+            run_span=run_span,
+        )
+        start_node = _agent_graph.UserPromptNode[AgentDepsT](
+            user_prompt=user_prompt,
+            system_prompts=self._system_prompts,
+            system_prompt_functions=self._system_prompt_functions,
+            system_prompt_dynamic_functions=self._system_prompt_dynamic_functions,
+        )

-            # Actually run
-            graph_run = await graph.run(
+        # Actually run
+        # TODO: Make this method non-async and remove the next await
+        # That way, users can decide whether to "await" the run, or iterate over it
+        return await AgentRun(
+            graph.run(
                start_node,
                state=state,
                deps=graph_deps,
                infer_name=False,
+                span=run_span,
            )
-            end_result = graph_run.result
-
-            # Build final run result
-            # We don't do any advanced checking if the data is actually from a final result or not
-            return result.RunResult(
-                state.message_history,
-                new_message_index,
-                end_result.data,
-                end_result.tool_name,
-                state.usage,
-            )
+        )
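For orientation, the fields the removed `result.RunResult(...)` was built from map onto accessors of the new `AgentRun` class defined later in this file. A sketch, assuming `agent_run` is the value returned by `await agent.run(...)`:

messages = agent_run.all_messages()      # was: state.message_history
new = agent_run.new_messages()           # was: message_history[new_message_index:]
data = agent_run.data                    # was: end_result.data
tool_name = agent_run.result.tool_name   # was: end_result.tool_name
usage = agent_run.usage()                # was: state.usage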

@overload
Expand All @@ -367,7 +361,7 @@ def run_sync(
usage_limits: _usage.UsageLimits | None = None,
usage: _usage.Usage | None = None,
infer_name: bool = True,
-    ) -> result.RunResult[ResultDataT]: ...
+    ) -> AgentRun[AgentDepsT, ResultDataT]: ...

@overload
def run_sync(
@@ -382,7 +376,7 @@ def run_sync(
usage_limits: _usage.UsageLimits | None = None,
usage: _usage.Usage | None = None,
infer_name: bool = True,
-    ) -> result.RunResult[RunResultDataT]: ...
+    ) -> AgentRun[AgentDepsT, ResultDataT]: ...

def run_sync(
self,
@@ -396,7 +390,7 @@ def run_sync(
usage_limits: _usage.UsageLimits | None = None,
usage: _usage.Usage | None = None,
infer_name: bool = True,
-    ) -> result.RunResult[Any]:
+    ) -> AgentRun[AgentDepsT, ResultDataT]:
"""Run the agent with a user prompt synchronously.
This is a convenience method that wraps [`self.run`][pydantic_ai.Agent.run] with `loop.run_until_complete(...)`.
@@ -1040,7 +1034,7 @@ def last_run_messages(self) -> list[_messages.ModelMessage]:

def _build_graph(
self, result_type: type[RunResultDataT] | None
-    ) -> Graph[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], Any]:
+    ) -> Graph[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], MarkFinalResult[Any]]:
return _agent_graph.build_agent_graph(self.name, self._deps_type, result_type or self.result_type)

def _build_stream_graph(
@@ -1059,3 +1053,140 @@ def _prepare_result_schema(
)
else:
return self._result_schema # pyright: ignore[reportReturnType]


@dataclasses.dataclass
class AgentRun(Generic[AgentDepsT, ResultDataT]):
# TODO: Should this go into the `result` module? And replace `result.RunResult`?
graph_run: GraphRun[
_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], MarkFinalResult[ResultDataT]
]

def all_messages(self, *, result_tool_return_content: str | None = None) -> list[_messages.ModelMessage]:
if result_tool_return_content is not None:
return self._set_result_tool_return(result_tool_return_content)
else:
return self.graph_run.state.message_history

def all_messages_json(self, *, result_tool_return_content: str | None = None) -> bytes:
"""Return all messages from [`all_messages`][pydantic_ai.result._BaseRunResult.all_messages] as JSON bytes.
Args:
result_tool_return_content: The return content of the tool call to set in the last message.
This provides a convenient way to modify the content of the result tool call if you want to continue
the conversation and want to set the response to the result tool call. If `None`, the last message will
not be modified.
Returns:
JSON bytes representing the messages.
"""
return _messages.ModelMessagesTypeAdapter.dump_json(
self.all_messages(result_tool_return_content=result_tool_return_content)
)

@property
def _new_message_index(self) -> int:
return self.graph_run.deps.new_message_index

def new_messages(self, *, result_tool_return_content: str | None = None) -> list[_messages.ModelMessage]:
return self.all_messages(result_tool_return_content=result_tool_return_content)[self._new_message_index :]

def new_messages_json(self, *, result_tool_return_content: str | None = None) -> bytes:
"""Return new messages from [`new_messages`][pydantic_ai.result._BaseRunResult.new_messages] as JSON bytes.
Args:
result_tool_return_content: The return content of the tool call to set in the last message.
This provides a convenient way to modify the content of the result tool call if you want to continue
the conversation and want to set the response to the result tool call. If `None`, the last message will
not be modified.
Returns:
JSON bytes representing the new messages.
"""
return _messages.ModelMessagesTypeAdapter.dump_json(
self.new_messages(result_tool_return_content=result_tool_return_content)
)

def usage(self) -> _usage.Usage:
return self.graph_run.state.usage

def _set_result_tool_return(self, return_content: str) -> list[_messages.ModelMessage]:
"""Set return content for the result tool.
Useful if you want to continue the conversation and want to set the response to the result tool call.
"""
if not self.result.tool_name:
raise ValueError('Cannot set result tool return content when the return type is `str`.')
messages = deepcopy(self.graph_run.state.message_history)
last_message = messages[-1]
for part in last_message.parts:
if isinstance(part, _messages.ToolReturnPart) and part.tool_name == self.result.tool_name:
part.content = return_content
return messages
raise LookupError(f'No tool call found with tool name {self.result.tool_name!r}.')
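Usage sketch (not part of this commit): `result_tool_return_content` lets a caller rewrite the recorded result-tool return before feeding the history into a follow-up run; `message_history` here is the pre-existing `Agent.run` parameter for passing prior messages:

history = agent_run.all_messages(result_tool_return_content='Final answer accepted.')
next_run = await agent.run('And a follow-up question?', message_history=history)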

@property
def is_ended(self):
return self.graph_run.is_ended

@property
def result(self) -> MarkFinalResult[ResultDataT]:
return self.graph_run.result

@property
def _result_tool_name(self) -> str | None:
return self.graph_run.result.tool_name

@property
def data(self) -> ResultDataT:
return self.result.data

async def next(
self,
node: BaseNode[
_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], MarkFinalResult[ResultDataT]
],
) -> (
BaseNode[
_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], MarkFinalResult[ResultDataT]
]
| End[MarkFinalResult[ResultDataT]]
):
return await self.graph_run.next(node)

def __await__(self) -> Generator[Any, Any, Self]:
"""Run the graph until it ends, and return the final result."""

async def _run():
await self.graph_run
return self

return _run().__await__()

def __enter__(self) -> Self:
self.graph_run.__enter__()
return self

def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
self.graph_run.__exit__(exc_type, exc_val, exc_tb)

def __aiter__(
self,
) -> AsyncIterator[
BaseNode[
_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], MarkFinalResult[ResultDataT]
]
| End[MarkFinalResult[ResultDataT]]
]:
return self

async def __anext__(
self,
) -> (
BaseNode[
_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], MarkFinalResult[ResultDataT]
]
| End[MarkFinalResult[ResultDataT]]
):
"""Use the last returned node as the input to `Graph.next`."""
return await self.graph_run.__anext__()
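Together, the context-manager and async-iterator hooks above suggest a step-by-step driving pattern. A sketch under assumptions — whether iteration requires entering the context manager first depends on `GraphRun` internals not shown in this diff:

from pydantic_graph.nodes import End

async def drive(agent_run: AgentRun):
    with agent_run:  # enters the underlying GraphRun (and its span)
        async for node in agent_run:  # each step advances the graph by one node
            if isinstance(node, End):
                print('final data:', agent_run.data)
            else:
                print('ran node:', type(node).__name__)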
17 changes: 14 additions & 3 deletions pydantic_graph/pydantic_graph/graph.py
@@ -133,6 +133,7 @@ def run(
state: StateT = None,
deps: DepsT = None,
infer_name: bool = True,
+        span: LogfireSpan | None = None,
) -> GraphRun[StateT, DepsT, T]:
"""Run the graph from a starting node until it ends.
@@ -142,6 +143,7 @@
state: The initial state of the graph.
deps: The dependencies of the graph.
infer_name: Whether to infer the graph name from the calling frame.
+            span: The span to use for the graph run. If not provided, a new span will be created.
Returns:
The result type from ending the run and the history of the run.
@@ -171,7 +173,13 @@ async def main():
self._infer_name(inspect.currentframe())

return GraphRun[StateT, DepsT, T](
-            self, start_node, history=[], state=state, deps=deps, auto_instrument=self._auto_instrument
+            self,
+            start_node,
+            history=[],
+            state=state,
+            deps=deps,
+            auto_instrument=self._auto_instrument,
+            span=span,
)

def run_sync(
@@ -514,17 +522,18 @@ def __init__(
state: StateT,
deps: DepsT,
auto_instrument: bool,
+        span: LogfireSpan | None = None,
):
self.graph = graph
self.history = history
self.state = state
self.deps = deps
self._auto_instrument = auto_instrument
+        self._span = span

self._next_node = first_node
self._started: bool = False
self._result: End[RunEndT] | None = None
-        self._span: LogfireSpan | None = None

@property
def is_ended(self):
Expand Down Expand Up @@ -568,8 +577,10 @@ def __enter__(self) -> typing_extensions.Self:
if self._started:
raise exceptions.GraphRuntimeError('A GraphRun can only be started once.')

-        if self._auto_instrument:
+        if self._auto_instrument and self._span is None:
            self._span = logfire_api.span('run graph {graph.name}', graph=self.graph)
+
+        if self._span is not None:
self._span.__enter__()

self._started = True
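Net effect of the new `span` parameter: a caller can pre-create a span (as `Agent.run` now does with its run span) and `GraphRun.__enter__` enters it instead of auto-instrumenting a second one. A sketch with illustrative names (`my_graph`, `start_node`, `state`, and `deps` are assumptions, not from this diff):

import logfire_api

span = logfire_api.span('run graph {graph.name}', graph=my_graph)  # pre-created span
graph_run = my_graph.run(start_node, state=state, deps=deps, span=span)
with graph_run:  # enters the provided span; no new span is created
    ...  # step the run via `async for` / `next(...)`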