Add GraphRun object to make use of next more ergonomic #833

Merged
31 commits merged into main from dmontagu/graph-run-object, Feb 20, 2025
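
This PR replaces the GraphRunner/AgentRunner helpers with GraphRun and AgentRun objects behind a new .iter() API (see commit ff6f699), so a run can be driven node by node instead of juggling bare next() calls. A minimal usage sketch, inferred from the commit messages and the diffs below rather than copied from the docs; the model name and prompt are illustrative:

import asyncio

from pydantic_ai import Agent

agent = Agent('openai:gpt-4o')


async def main() -> None:
    # agent.iter() yields an AgentRun, which is an async iterator over the
    # nodes of the underlying graph; this loop replaces manual next() calls.
    async with agent.iter('What is the capital of France?') as agent_run:
        async for node in agent_run:
            print(type(node).__name__)  # UserPromptNode, ModelRequestNode, HandleResponseNode, ...
    print(agent_run.result)  # the AgentRunResult, available once the run has ended


asyncio.run(main())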

Commits (31)
57905c0  Add GraphRun class (dmontagu, Feb 11, 2025)
1a09378  Minor improvements (dmontagu, Feb 11, 2025)
0c3b48d  Move definition of MarkFinalResult to result module (dmontagu, Feb 11, 2025)
9280adf  A bit more clean-up (dmontagu, Feb 11, 2025)
7d55afd  Update call_id logic (dmontagu, Feb 11, 2025)
63df8ba  Minor fixes (dmontagu, Feb 11, 2025)
6c65095  Update some things (dmontagu, Feb 11, 2025)
db56e31  Update some comments etc. (dmontagu, Feb 12, 2025)
9af98e8  Undo kind changes (dmontagu, Feb 12, 2025)
2100a1a  Merge branch 'main' into dmontagu/graph-run-object (dmontagu, Feb 12, 2025)
78e85d6  Introduce auxiliary types (dmontagu, Feb 12, 2025)
e0c716b  Merge main (dmontagu, Feb 17, 2025)
ef8895a  Address some feedback (dmontagu, Feb 18, 2025)
13e3b86  result -> node (dmontagu, Feb 18, 2025)
a08aafa  Rename MarkFinalResult to FinalResult (dmontagu, Feb 18, 2025)
ff6f699  Remove GraphRunner/AgentRunner and add .iter() API (dmontagu, Feb 18, 2025)
41bb069  Make result private (dmontagu, Feb 18, 2025)
b565088  Reduce diff to main and add some docstrings (dmontagu, Feb 18, 2025)
8d2c74e  Add more docstrings (dmontagu, Feb 18, 2025)
4bb67a5  Add more docs (dmontagu, Feb 18, 2025)
a6e6445  Fix various docs references (dmontagu, Feb 18, 2025)
007d8ca  Fix final docs references (dmontagu, Feb 18, 2025)
6d532c1  Address some feedback (dmontagu, Feb 18, 2025)
0745ba9  Update docs (dmontagu, Feb 18, 2025)
8d86b3a  Fix docs build (dmontagu, Feb 18, 2025)
bdb5f77  Make the graph_run_result private on AgentRunResult (dmontagu, Feb 18, 2025)
0d36dbf  Some minor cleanup of reprs (dmontagu, Feb 18, 2025)
aa8b36a  Merge branch 'main' into dmontagu/graph-run-object (dmontagu, Feb 19, 2025)
9a676d2  Tweak some APIs (dmontagu, Feb 19, 2025)
e799024  Rename final_result to result and drop DepsT in some places (dmontagu, Feb 20, 2025)
c7ab89f  More cleanup (dmontagu, Feb 20, 2025)
6 changes: 5 additions & 1 deletion pydantic_ai_slim/pydantic_ai/__init__.py
@@ -1,11 +1,15 @@
from importlib.metadata import version

-from .agent import Agent, capture_run_messages
+from .agent import Agent, EndStrategy, HandleResponseNode, ModelRequestNode, UserPromptNode, capture_run_messages
from .exceptions import AgentRunError, ModelRetry, UnexpectedModelBehavior, UsageLimitExceeded, UserError
from .tools import RunContext, Tool

__all__ = (
    'Agent',
+    'EndStrategy',
+    'HandleResponseNode',
+    'ModelRequestNode',
+    'UserPromptNode',
    'capture_run_messages',
    'RunContext',
    'Tool',
30 changes: 18 additions & 12 deletions pydantic_ai_slim/pydantic_ai/_agent_graph.py
Member:
should this be public since the nodes you get from iterating are now kind of public?

Contributor Author:
I think it's enough to leave the module private and re-export the node classes, but I'm also okay with making the module public.

Contributor Author (dmontagu, Feb 20, 2025):
I think I've done the re-export thing now, and it works for me in PyCharm, but it might be worth double-checking on your end.
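
For context, a rough sketch of what the re-export gives end users: the node classes can be imported from the top-level package and used to dispatch on the nodes yielded while iterating a run. Class and attribute names are taken from the diffs in this PR; the iteration pattern is assumed from the new .iter() API:

from pydantic_ai import Agent, HandleResponseNode, ModelRequestNode, UserPromptNode

agent = Agent('openai:gpt-4o')


async def run_verbose(prompt: str) -> None:
    async with agent.iter(prompt) as agent_run:
        async for node in agent_run:
            # No import from the private _agent_graph module is needed here.
            if isinstance(node, UserPromptNode):
                print('prompt:', node.user_prompt)
            elif isinstance(node, ModelRequestNode):
                print('request:', node.request)
            elif isinstance(node, HandleResponseNode):
                print('handling model response')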

@@ -32,6 +32,16 @@
    ToolDefinition,
)

+__all__ = (
+    'GraphAgentState',
+    'GraphAgentDeps',
+    'UserPromptNode',
+    'ModelRequestNode',
+    'HandleResponseNode',
+    'build_run_context',
+    'capture_run_messages',
+)
+
_logfire = logfire_api.Logfire(otel_scope='pydantic-ai')

# while waiting for https://github.com/pydantic/logfire/issues/745
@@ -98,13 +108,18 @@ class GraphAgentDeps(Generic[DepsT, ResultDataT]):


@dataclasses.dataclass
-class BaseUserPromptNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT], ABC):
+class UserPromptNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT], ABC):
    user_prompt: str

    system_prompts: tuple[str, ...]
    system_prompt_functions: list[_system_prompt.SystemPromptRunner[DepsT]]
    system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[DepsT]]

+    async def run(
+        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]
+    ) -> ModelRequestNode[DepsT, NodeRunEndT]:
+        return ModelRequestNode[DepsT, NodeRunEndT](request=await self._get_first_message(ctx))
+
    async def _get_first_message(
        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]
    ) -> _messages.ModelRequest:
@@ -173,14 +188,6 @@ async def _sys_parts(self, run_context: RunContext[DepsT]) -> list[_messages.Mod
        return messages


-@dataclasses.dataclass
-class UserPromptNode(BaseUserPromptNode[DepsT, NodeRunEndT]):
-    async def run(
-        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]
-    ) -> ModelRequestNode[DepsT, NodeRunEndT]:
-        return ModelRequestNode[DepsT, NodeRunEndT](request=await self._get_first_message(ctx))
-
-
async def _prepare_request_parameters(
    ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
) -> models.ModelRequestParameters:
@@ -229,11 +236,10 @@ async def run(
    @asynccontextmanager
    async def _stream(
        self,
-        ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]],
+        ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, T]],
    ) -> AsyncIterator[models.StreamedResponse]:
        # TODO: Consider changing this to return something more similar to a `StreamedRunResult`, then make it public
-        if self._did_stream:
-            raise exceptions.AgentRunError('stream() can only be called once')
+        assert not self._did_stream, 'stream() should only be called once per node'

        model_settings, model_request_parameters = await self._prepare_request(ctx)
        with _logfire.span('model request', run_step=ctx.state.run_step) as span:
20 changes: 18 additions & 2 deletions pydantic_ai_slim/pydantic_ai/agent.py
@@ -25,7 +25,6 @@
    result,
    usage as _usage,
)
-from ._agent_graph import EndStrategy, capture_run_messages  # imported for re-export
from .result import FinalResult, ResultDataT, StreamedRunResult
from .settings import ModelSettings, merge_model_settings
from .tools import (
@@ -40,7 +39,24 @@
    ToolPrepareFunc,
)

-__all__ = 'Agent', 'AgentRun', 'AgentRunResult', 'capture_run_messages', 'EndStrategy'
+# Re-exporting like this improves auto-import behavior in PyCharm
+capture_run_messages = _agent_graph.capture_run_messages
+EndStrategy = _agent_graph.EndStrategy
+HandleResponseNode = _agent_graph.HandleResponseNode
+ModelRequestNode = _agent_graph.ModelRequestNode
+UserPromptNode = _agent_graph.UserPromptNode
Contributor Author (on lines +42 to +47):
@samuelcolvin let me know if you think there's a better way to do this. (I tried other things I could think of and this worked best.)



+__all__ = (
+    'Agent',
+    'AgentRun',
+    'AgentRunResult',
+    'capture_run_messages',
+    'EndStrategy',
+    'HandleResponseNode',
+    'ModelRequestNode',
+    'UserPromptNode',
+)

_logfire = logfire_api.Logfire(otel_scope='pydantic-ai')

69 changes: 2 additions & 67 deletions pydantic_ai_slim/pydantic_ai/models/__init__.py
@@ -17,7 +17,6 @@
import httpx
from typing_extensions import Literal

-from .. import _utils, messages as _messages
from .._parts_manager import ModelResponsePartsManager
from ..exceptions import UserError
from ..messages import ModelMessage, ModelResponse, ModelResponseStreamEvent
@@ -235,6 +234,8 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:

        This method should be implemented by subclasses to translate the vendor-specific stream of events into
        pydantic_ai-format events.
+
+        It should use the `_parts_manager` to handle deltas, and should update the `_usage` attributes as it goes.
        """
        raise NotImplementedError()
        # noinspection PyUnreachableCode
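
The added docstring lines state the contract for subclasses. A hypothetical vendor integration might look like the sketch below; the chunk shape and the _map_usage helper are invented for illustration, and the handle_text_delta call is assumed from how the parts manager is used by the built-in models:

class MyVendorStreamedResponse(StreamedResponse):
    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
        # `self._raw_stream` stands in for a vendor-specific stream of chunks.
        async for chunk in self._raw_stream:
            if chunk.usage is not None:
                self._usage += _map_usage(chunk.usage)  # keep `_usage` up to date as we go
            for delta in chunk.text_deltas:
                # The parts manager turns raw text deltas into pydantic_ai-format
                # events, emitting a part-start or part-delta event as appropriate.
                event = self._parts_manager.handle_text_delta(
                    vendor_part_id=delta.part_id, content=delta.text
                )
                if event is not None:
                    yield event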
@@ -262,72 +263,6 @@ def timestamp(self) -> datetime:
"""Get the timestamp of the response."""
raise NotImplementedError()

-    async def stream_debounced_events(
-        self, *, debounce_by: float | None = 0.1
-    ) -> AsyncIterator[list[ModelResponseStreamEvent]]:
-        """Stream the response as an async iterable of debounced lists of [`ModelResponseStreamEvent`][pydantic_ai.messages.ModelResponseStreamEvent]s."""
-        async with _utils.group_by_temporal(self, debounce_by) as group_iter:
-            async for items in group_iter:
-                yield items
-
-    async def stream_structured(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[_messages.ModelResponse]:
-        """Stream the response as an async iterable of [`ModelResponse`][pydantic_ai.messages.ModelResponse]s."""
-
-        async def _stream_structured_ungrouped() -> AsyncIterator[None]:
-            # yield None # TODO: Might want to yield right away to ensure we can eagerly emit a ModelResponse even if we are waiting
-            async for _event in self:
-                yield None
-
-        async with _utils.group_by_temporal(_stream_structured_ungrouped(), debounce_by) as group_iter:
-            async for _items in group_iter:
-                yield self.get()  # current state of the response
-
-    async def stream_text(self, *, delta: bool = False, debounce_by: float | None = 0.1) -> AsyncIterator[str]:
-        """Stream the response as an async iterable of text."""
-
-        # Define a "merged" version of the iterator that will yield items that have already been retrieved
-        # and items that we receive while streaming. We define a dedicated async iterator for this so we can
-        # pass the combined stream to the group_by_temporal function within `_stream_text_deltas` below.
-        async def _stream_text_deltas_ungrouped() -> AsyncIterator[tuple[str, int]]:
-            # yields tuples of (text_content, part_index)
-            # we don't currently make use of the part_index, but in principle this may be useful
-            # so we retain it here for now to make possible future refactors simpler
-            msg = self.get()
-            for i, part in enumerate(msg.parts):
-                if isinstance(part, _messages.TextPart) and part.content:
-                    yield part.content, i
-
-            async for event in self:
-                if (
-                    isinstance(event, _messages.PartStartEvent)
-                    and isinstance(event.part, _messages.TextPart)
-                    and event.part.content
-                ):
-                    yield event.part.content, event.index
-                elif (
-                    isinstance(event, _messages.PartDeltaEvent)
-                    and isinstance(event.delta, _messages.TextPartDelta)
-                    and event.delta.content_delta
-                ):
-                    yield event.delta.content_delta, event.index
-
-        async def _stream_text_deltas() -> AsyncIterator[str]:
-            async with _utils.group_by_temporal(_stream_text_deltas_ungrouped(), debounce_by) as group_iter:
-                async for items in group_iter:
-                    # Note: we are currently just dropping the part index on the group here
-                    yield ''.join([content for content, _ in items])
-
-        if delta:
-            async for text in _stream_text_deltas():
-                yield text
-        else:
-            # a quick benchmark shows it's faster to build up a string with concat when we're
-            # yielding at each step
-            deltas: list[str] = []
-            async for text in _stream_text_deltas():
-                deltas.append(text)
-            yield ''.join(deltas)


ALLOW_MODEL_REQUESTS = True
"""Whether to allow requests to models.
64 changes: 59 additions & 5 deletions pydantic_ai_slim/pydantic_ai/result.py
@@ -9,7 +9,7 @@
import logfire_api
from typing_extensions import TypeVar

-from . import _result, exceptions, messages as _messages, models
+from . import _result, _utils, exceptions, messages as _messages, models
from .tools import AgentDepsT, RunContext
from .usage import Usage, UsageLimits

@@ -160,7 +160,6 @@ async def stream(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[Resu
        Returns:
            An async iterable of the response data.
        """
-        self._stream_response.stream_structured(debounce_by=debounce_by)
        async for structured_message, is_last in self.stream_structured(debounce_by=debounce_by):
            result = await self.validate_structured_result(structured_message, allow_partial=not is_last)
            yield result
@@ -183,11 +182,11 @@ async def stream_text(self, *, delta: bool = False, debounce_by: float | None =

        with _logfire.span('response stream text') as lf_span:
            if delta:
-                async for text in self._stream_response.stream_text(delta=delta, debounce_by=debounce_by):
+                async for text in self._stream_response_text(delta=delta, debounce_by=debounce_by):
                    yield text
            else:
                combined_validated_text = ''
-                async for text in self._stream_response.stream_text(delta=delta, debounce_by=debounce_by):
+                async for text in self._stream_response_text(delta=delta, debounce_by=debounce_by):
                    combined_validated_text = await self._validate_text_result(text)
                    yield combined_validated_text
                lf_span.set_attribute('combined_text', combined_validated_text)
Expand All @@ -214,7 +213,7 @@ async def stream_structured(
                    yield msg, False
                    break

-            async for msg in self._stream_response.stream_structured(debounce_by=debounce_by):
+            async for msg in self._stream_response_structured(debounce_by=debounce_by):
                yield msg, False

            msg = self._stream_response.get()
@@ -289,6 +288,61 @@ async def _marked_completed(self, message: _messages.ModelResponse) -> None:
        self._all_messages.append(message)
        await self._on_complete()

+    async def _stream_response_structured(
+        self, *, debounce_by: float | None = 0.1
+    ) -> AsyncIterator[_messages.ModelResponse]:
+        async with _utils.group_by_temporal(self._stream_response, debounce_by) as group_iter:
+            async for _items in group_iter:
+                yield self._stream_response.get()
+
+    async def _stream_response_text(
+        self, *, delta: bool = False, debounce_by: float | None = 0.1
+    ) -> AsyncIterator[str]:
+        """Stream the response as an async iterable of text."""
+
+        # Define a "merged" version of the iterator that will yield items that have already been retrieved
+        # and items that we receive while streaming. We define a dedicated async iterator for this so we can
+        # pass the combined stream to the group_by_temporal function within `_stream_text_deltas` below.
+        async def _stream_text_deltas_ungrouped() -> AsyncIterator[tuple[str, int]]:
+            # yields tuples of (text_content, part_index)
+            # we don't currently make use of the part_index, but in principle this may be useful
+            # so we retain it here for now to make possible future refactors simpler
+            msg = self._stream_response.get()
+            for i, part in enumerate(msg.parts):
+                if isinstance(part, _messages.TextPart) and part.content:
+                    yield part.content, i
+
+            async for event in self._stream_response:
+                if (
+                    isinstance(event, _messages.PartStartEvent)
+                    and isinstance(event.part, _messages.TextPart)
+                    and event.part.content
+                ):
+                    yield event.part.content, event.index
+                elif (
+                    isinstance(event, _messages.PartDeltaEvent)
+                    and isinstance(event.delta, _messages.TextPartDelta)
+                    and event.delta.content_delta
+                ):
+                    yield event.delta.content_delta, event.index
+
+        async def _stream_text_deltas() -> AsyncIterator[str]:
+            async with _utils.group_by_temporal(_stream_text_deltas_ungrouped(), debounce_by) as group_iter:
+                async for items in group_iter:
+                    # Note: we are currently just dropping the part index on the group here
+                    yield ''.join([content for content, _ in items])
+
+        if delta:
+            async for text in _stream_text_deltas():
+                yield text
+        else:
+            # a quick benchmark shows it's faster to build up a string with concat when we're
+            # yielding at each step
+            deltas: list[str] = []
+            async for text in _stream_text_deltas():
+                deltas.append(text)
+            yield ''.join(deltas)


@dataclass
class FinalResult(Generic[ResultDataT]):