More cleanup
dmontagu committed Feb 20, 2025
1 parent e799024 commit 241f179
Showing 5 changed files with 103 additions and 88 deletions.
6 changes: 5 additions & 1 deletion pydantic_ai_slim/pydantic_ai/__init__.py
@@ -1,11 +1,15 @@
 from importlib.metadata import version

-from .agent import Agent, capture_run_messages
+from .agent import Agent, EndStrategy, HandleResponseNode, ModelRequestNode, UserPromptNode, capture_run_messages
 from .exceptions import AgentRunError, ModelRetry, UnexpectedModelBehavior, UsageLimitExceeded, UserError
 from .tools import RunContext, Tool

 __all__ = (
     'Agent',
+    'EndStrategy',
+    'HandleResponseNode',
+    'ModelRequestNode',
+    'UserPromptNode',
     'capture_run_messages',
     'RunContext',
     'Tool',
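With this change the node types and `capture_run_messages` are importable straight from the package root. A quick sketch of the widened import surface (not part of the diff; the `capture_run_messages` pattern follows the library's documented usage, and the model name is illustrative):

from pydantic_ai import (
    Agent,
    EndStrategy,
    HandleResponseNode,
    ModelRequestNode,
    UserPromptNode,
    capture_run_messages,
)

agent = Agent('openai:gpt-4o')  # illustrative model name

# Record the messages exchanged during a run, e.g. to debug a failure.
with capture_run_messages() as messages:
    try:
        result = agent.run_sync('What is 2 + 2?')
    except Exception:
        print(messages)  # everything sent/received before the failure
        raise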
32 changes: 19 additions & 13 deletions pydantic_ai_slim/pydantic_ai/_agent_graph.py
@@ -32,6 +32,16 @@
     ToolDefinition,
 )

+__all__ = (
+    'GraphAgentState',
+    'GraphAgentDeps',
+    'UserPromptNode',
+    'ModelRequestNode',
+    'HandleResponseNode',
+    'build_run_context',
+    'capture_run_messages',
+)
+
 _logfire = logfire_api.Logfire(otel_scope='pydantic-ai')

 # while waiting for https://github.com/pydantic/logfire/issues/745
@@ -98,13 +108,18 @@ class GraphAgentDeps(Generic[DepsT, ResultDataT]):


 @dataclasses.dataclass
-class BaseUserPromptNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT], ABC):
+class UserPromptNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT], ABC):
     user_prompt: str

     system_prompts: tuple[str, ...]
     system_prompt_functions: list[_system_prompt.SystemPromptRunner[DepsT]]
     system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[DepsT]]

+    async def run(
+        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]
+    ) -> ModelRequestNode[DepsT, NodeRunEndT]:
+        return ModelRequestNode[DepsT, NodeRunEndT](request=await self._get_first_message(ctx))
+
     async def _get_first_message(
         self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]
     ) -> _messages.ModelRequest:
@@ -173,14 +188,6 @@ async def _sys_parts(self, run_context: RunContext[DepsT]) -> list[_messages.Mod
         return messages


-@dataclasses.dataclass
-class UserPromptNode(BaseUserPromptNode[DepsT, NodeRunEndT]):
-    async def run(
-        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]
-    ) -> ModelRequestNode[DepsT, NodeRunEndT]:
-        return ModelRequestNode[DepsT, NodeRunEndT](request=await self._get_first_message(ctx))
-
-
 async def _prepare_request_parameters(
     ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
 ) -> models.ModelRequestParameters:
@@ -229,11 +236,10 @@ async def run(
     @asynccontextmanager
     async def _stream(
         self,
-        ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]],
-    ) -> AsyncIterator[models.StreamedResponse]:
+        ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, T]],
+    ) -> AsyncIterator[result.AgentStream[DepsT, T]]:
         # TODO: Consider changing this to return something more similar to a `StreamedRunResult`, then make it public
-        if self._did_stream:
-            raise exceptions.AgentRunError('stream() can only be called once')
+        assert not self._did_stream, 'stream() should only be called once per node'

         model_settings, model_request_parameters = await self._prepare_request(ctx)
         with _logfire.span('model request', run_step=ctx.state.run_step) as span:
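Making `UserPromptNode`, `ModelRequestNode`, and `HandleResponseNode` public lets callers match on node types while driving a run themselves. A minimal sketch, assuming the `Agent.iter`/`AgentRun` API that `agent.py` now exports; the `node.request` attribute name and the model name are assumptions, not confirmed by this diff:

from pydantic_ai import Agent, HandleResponseNode, ModelRequestNode

agent = Agent('openai:gpt-4o')  # illustrative model name

async def main() -> None:
    # Step through the run one graph node at a time instead of calling agent.run().
    async with agent.iter('What is the capital of France?') as agent_run:
        async for node in agent_run:
            if isinstance(node, ModelRequestNode):
                print('about to call the model:', node.request)  # attribute assumed
            elif isinstance(node, HandleResponseNode):
                print('processing the model response')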
20 changes: 18 additions & 2 deletions pydantic_ai_slim/pydantic_ai/agent.py
@@ -25,7 +25,6 @@
     result,
     usage as _usage,
 )
-from ._agent_graph import EndStrategy, capture_run_messages  # imported for re-export
 from .result import FinalResult, ResultDataT, StreamedRunResult
 from .settings import ModelSettings, merge_model_settings
 from .tools import (
@@ -40,7 +39,24 @@
     ToolPrepareFunc,
 )

-__all__ = 'Agent', 'AgentRun', 'AgentRunResult', 'capture_run_messages', 'EndStrategy'
+# Re-exporting like this improves auto-import behavior in PyCharm
+capture_run_messages = _agent_graph.capture_run_messages
+EndStrategy = _agent_graph.EndStrategy
+HandleResponseNode = _agent_graph.HandleResponseNode
+ModelRequestNode = _agent_graph.ModelRequestNode
+UserPromptNode = _agent_graph.UserPromptNode
+
+
+__all__ = (
+    'Agent',
+    'AgentRun',
+    'AgentRunResult',
+    'capture_run_messages',
+    'EndStrategy',
+    'HandleResponseNode',
+    'ModelRequestNode',
+    'UserPromptNode',
+)

 _logfire = logfire_api.Logfire(otel_scope='pydantic-ai')

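Because the re-exports are plain assignments rather than `from ... import` statements, both import paths resolve to the very same objects. A quick check (a sketch, not part of the diff):

from pydantic_ai import EndStrategy as from_root
from pydantic_ai.agent import EndStrategy as from_agent

assert from_root is from_agent  # one object, two import paths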
69 changes: 2 additions & 67 deletions pydantic_ai_slim/pydantic_ai/models/__init__.py
@@ -17,7 +17,6 @@
 import httpx
 from typing_extensions import Literal

-from .. import _utils, messages as _messages
 from .._parts_manager import ModelResponsePartsManager
 from ..exceptions import UserError
 from ..messages import ModelMessage, ModelResponse, ModelResponseStreamEvent
@@ -235,6 +234,8 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
         This method should be implemented by subclasses to translate the vendor-specific stream of events into
         pydantic_ai-format events.
+
+        It should use the `_parts_manager` to handle deltas, and should update the `_usage` attributes as it goes.
         """
         raise NotImplementedError()
         # noinspection PyUnreachableCode
         yield
@@ -262,72 +263,6 @@ def timestamp(self) -> datetime:
         """Get the timestamp of the response."""
         raise NotImplementedError()

-    async def stream_debounced_events(
-        self, *, debounce_by: float | None = 0.1
-    ) -> AsyncIterator[list[ModelResponseStreamEvent]]:
-        """Stream the response as an async iterable of debounced lists of [`ModelResponseStreamEvent`][pydantic_ai.messages.ModelResponseStreamEvent]s."""
-        async with _utils.group_by_temporal(self, debounce_by) as group_iter:
-            async for items in group_iter:
-                yield items
-
-    async def stream_structured(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[_messages.ModelResponse]:
-        """Stream the response as an async iterable of [`ModelResponse`][pydantic_ai.messages.ModelResponse]s."""
-
-        async def _stream_structured_ungrouped() -> AsyncIterator[None]:
-            # yield None  # TODO: Might want to yield right away to ensure we can eagerly emit a ModelResponse even if we are waiting
-            async for _event in self:
-                yield None
-
-        async with _utils.group_by_temporal(_stream_structured_ungrouped(), debounce_by) as group_iter:
-            async for _items in group_iter:
-                yield self.get()  # current state of the response
-
-    async def stream_text(self, *, delta: bool = False, debounce_by: float | None = 0.1) -> AsyncIterator[str]:
-        """Stream the response as an async iterable of text."""
-
-        # Define a "merged" version of the iterator that will yield items that have already been retrieved
-        # and items that we receive while streaming. We define a dedicated async iterator for this so we can
-        # pass the combined stream to the group_by_temporal function within `_stream_text_deltas` below.
-        async def _stream_text_deltas_ungrouped() -> AsyncIterator[tuple[str, int]]:
-            # yields tuples of (text_content, part_index)
-            # we don't currently make use of the part_index, but in principle this may be useful
-            # so we retain it here for now to make possible future refactors simpler
-            msg = self.get()
-            for i, part in enumerate(msg.parts):
-                if isinstance(part, _messages.TextPart) and part.content:
-                    yield part.content, i
-
-            async for event in self:
-                if (
-                    isinstance(event, _messages.PartStartEvent)
-                    and isinstance(event.part, _messages.TextPart)
-                    and event.part.content
-                ):
-                    yield event.part.content, event.index
-                elif (
-                    isinstance(event, _messages.PartDeltaEvent)
-                    and isinstance(event.delta, _messages.TextPartDelta)
-                    and event.delta.content_delta
-                ):
-                    yield event.delta.content_delta, event.index
-
-        async def _stream_text_deltas() -> AsyncIterator[str]:
-            async with _utils.group_by_temporal(_stream_text_deltas_ungrouped(), debounce_by) as group_iter:
-                async for items in group_iter:
-                    # Note: we are currently just dropping the part index on the group here
-                    yield ''.join([content for content, _ in items])
-
-        if delta:
-            async for text in _stream_text_deltas():
-                yield text
-        else:
-            # a quick benchmark shows it's faster to build up a string with concat when we're
-            # yielding at each step
-            deltas: list[str] = []
-            async for text in _stream_text_deltas():
-                deltas.append(text)
-            yield ''.join(deltas)
-

 ALLOW_MODEL_REQUESTS = True
 """Whether to allow requests to models.
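With these helpers removed, `StreamedResponse` is reduced to a low-level async iterable of `ModelResponseStreamEvent`s; the debounced text and structured streaming now live on the run-result layer (see `result.py` below). A sketch of the consumer-facing replacement for direct `StreamedResponse.stream_text` calls, assuming the public `run_stream` API and an illustrative model name:

from pydantic_ai import Agent

agent = Agent('openai:gpt-4o')  # illustrative model name

async def main() -> None:
    async with agent.run_stream('Tell me a short joke.') as result:
        # Text deltas, debounced into ~100ms groups, as before the move.
        async for text in result.stream_text(delta=True, debounce_by=0.1):
            print(text, end='', flush=True)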
64 changes: 59 additions & 5 deletions pydantic_ai_slim/pydantic_ai/result.py
@@ -9,7 +9,7 @@
 import logfire_api
 from typing_extensions import TypeVar

-from . import _result, exceptions, messages as _messages, models
+from . import _result, _utils, exceptions, messages as _messages, models
 from .tools import AgentDepsT, RunContext
 from .usage import Usage, UsageLimits
@@ -160,7 +160,6 @@ async def stream(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[Resu
         Returns:
             An async iterable of the response data.
         """
-        self._stream_response.stream_structured(debounce_by=debounce_by)
         async for structured_message, is_last in self.stream_structured(debounce_by=debounce_by):
             result = await self.validate_structured_result(structured_message, allow_partial=not is_last)
             yield result
@@ -183,11 +182,11 @@ async def stream_text(self, *, delta: bool = False, debounce_by: float | None =

         with _logfire.span('response stream text') as lf_span:
             if delta:
-                async for text in self._stream_response.stream_text(delta=delta, debounce_by=debounce_by):
+                async for text in self._stream_response_text(delta=delta, debounce_by=debounce_by):
                     yield text
             else:
                 combined_validated_text = ''
-                async for text in self._stream_response.stream_text(delta=delta, debounce_by=debounce_by):
+                async for text in self._stream_response_text(delta=delta, debounce_by=debounce_by):
                     combined_validated_text = await self._validate_text_result(text)
                     yield combined_validated_text
                 lf_span.set_attribute('combined_text', combined_validated_text)
Expand All @@ -214,7 +213,7 @@ async def stream_structured(
yield msg, False
break

async for msg in self._stream_response.stream_structured(debounce_by=debounce_by):
async for msg in self._stream_response_structured(debounce_by=debounce_by):
yield msg, False

msg = self._stream_response.get()
@@ -289,6 +288,61 @@ async def _marked_completed(self, message: _messages.ModelResponse) -> None:
         self._all_messages.append(message)
         await self._on_complete()

+    async def _stream_response_structured(
+        self, *, debounce_by: float | None = 0.1
+    ) -> AsyncIterator[_messages.ModelResponse]:
+        async with _utils.group_by_temporal(self._stream_response, debounce_by) as group_iter:
+            async for _items in group_iter:
+                yield self._stream_response.get()
+
+    async def _stream_response_text(
+        self, *, delta: bool = False, debounce_by: float | None = 0.1
+    ) -> AsyncIterator[str]:
+        """Stream the response as an async iterable of text."""
+
+        # Define a "merged" version of the iterator that will yield items that have already been retrieved
+        # and items that we receive while streaming. We define a dedicated async iterator for this so we can
+        # pass the combined stream to the group_by_temporal function within `_stream_text_deltas` below.
+        async def _stream_text_deltas_ungrouped() -> AsyncIterator[tuple[str, int]]:
+            # yields tuples of (text_content, part_index)
+            # we don't currently make use of the part_index, but in principle this may be useful
+            # so we retain it here for now to make possible future refactors simpler
+            msg = self._stream_response.get()
+            for i, part in enumerate(msg.parts):
+                if isinstance(part, _messages.TextPart) and part.content:
+                    yield part.content, i
+
+            async for event in self._stream_response:
+                if (
+                    isinstance(event, _messages.PartStartEvent)
+                    and isinstance(event.part, _messages.TextPart)
+                    and event.part.content
+                ):
+                    yield event.part.content, event.index
+                elif (
+                    isinstance(event, _messages.PartDeltaEvent)
+                    and isinstance(event.delta, _messages.TextPartDelta)
+                    and event.delta.content_delta
+                ):
+                    yield event.delta.content_delta, event.index
+
+        async def _stream_text_deltas() -> AsyncIterator[str]:
+            async with _utils.group_by_temporal(_stream_text_deltas_ungrouped(), debounce_by) as group_iter:
+                async for items in group_iter:
+                    # Note: we are currently just dropping the part index on the group here
+                    yield ''.join([content for content, _ in items])
+
+        if delta:
+            async for text in _stream_text_deltas():
+                yield text
+        else:
+            # a quick benchmark shows it's faster to build up a string with concat when we're
+            # yielding at each step
+            deltas: list[str] = []
+            async for text in _stream_text_deltas():
+                deltas.append(text)
+            yield ''.join(deltas)
+

 @dataclass
 class FinalResult(Generic[ResultDataT]):
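The private helpers above back the public `stream_structured`/`validate_structured_result` path. A usage sketch of that public API (assumed from the surrounding code, not shown in this diff; the `result_type`, prompt, and model name are illustrative):

from pydantic_ai import Agent

agent = Agent('openai:gpt-4o', result_type=list[str])  # illustrative

async def main() -> None:
    async with agent.run_stream('Name three colors.') as result:
        # Each iteration yields the response-so-far plus a flag marking whether
        # it is final; partial responses are validated with allow_partial=True.
        async for message, is_last in result.stream_structured(debounce_by=0.01):
            data = await result.validate_structured_result(message, allow_partial=not is_last)
            print(data)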
