diff --git a/pydantic_ai_slim/pydantic_ai/__init__.py b/pydantic_ai_slim/pydantic_ai/__init__.py
index 6f28e304..1b77a420 100644
--- a/pydantic_ai_slim/pydantic_ai/__init__.py
+++ b/pydantic_ai_slim/pydantic_ai/__init__.py
@@ -1,11 +1,15 @@
 from importlib.metadata import version
 
-from .agent import Agent, capture_run_messages
+from .agent import Agent, EndStrategy, HandleResponseNode, ModelRequestNode, UserPromptNode, capture_run_messages
 from .exceptions import AgentRunError, ModelRetry, UnexpectedModelBehavior, UsageLimitExceeded, UserError
 from .tools import RunContext, Tool
 
 __all__ = (
     'Agent',
+    'EndStrategy',
+    'HandleResponseNode',
+    'ModelRequestNode',
+    'UserPromptNode',
     'capture_run_messages',
     'RunContext',
     'Tool',
diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py
index a509fa39..f74b63ff 100644
--- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py
+++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py
@@ -32,6 +32,16 @@
     ToolDefinition,
 )
 
+__all__ = (
+    'GraphAgentState',
+    'GraphAgentDeps',
+    'UserPromptNode',
+    'ModelRequestNode',
+    'HandleResponseNode',
+    'build_run_context',
+    'capture_run_messages',
+)
+
 _logfire = logfire_api.Logfire(otel_scope='pydantic-ai')
 
 # while waiting for https://github.com/pydantic/logfire/issues/745
@@ -98,13 +108,18 @@ class GraphAgentDeps(Generic[DepsT, ResultDataT]):
 
 
 @dataclasses.dataclass
-class BaseUserPromptNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT], ABC):
+class UserPromptNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT], ABC):
     user_prompt: str
 
     system_prompts: tuple[str, ...]
     system_prompt_functions: list[_system_prompt.SystemPromptRunner[DepsT]]
     system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[DepsT]]
 
+    async def run(
+        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]
+    ) -> ModelRequestNode[DepsT, NodeRunEndT]:
+        return ModelRequestNode[DepsT, NodeRunEndT](request=await self._get_first_message(ctx))
+
     async def _get_first_message(
         self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]
     ) -> _messages.ModelRequest:
@@ -173,14 +188,6 @@ async def _sys_parts(self, run_context: RunContext[DepsT]) -> list[_messages.ModelRequestPart]:
         return messages
 
 
-@dataclasses.dataclass
-class UserPromptNode(BaseUserPromptNode[DepsT, NodeRunEndT]):
-    async def run(
-        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]
-    ) -> ModelRequestNode[DepsT, NodeRunEndT]:
-        return ModelRequestNode[DepsT, NodeRunEndT](request=await self._get_first_message(ctx))
-
-
 async def _prepare_request_parameters(
     ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
 ) -> models.ModelRequestParameters:
@@ -229,11 +236,10 @@ async def run(
 
     @asynccontextmanager
     async def _stream(
         self,
-        ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]],
-    ) -> AsyncIterator[models.StreamedResponse]:
+        ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, T]],
+    ) -> AsyncIterator[result.AgentStream[DepsT, T]]:
         # TODO: Consider changing this to return something more similar to a `StreamedRunResult`, then make it public
-        if self._did_stream:
-            raise exceptions.AgentRunError('stream() can only be called once')
+        assert not self._did_stream, 'stream() should only be called once per node'
         model_settings, model_request_parameters = await self._prepare_request(ctx)
         with _logfire.span('model request', run_step=ctx.state.run_step) as span:
diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py
index f67885f6..392443d7 100644
--- a/pydantic_ai_slim/pydantic_ai/agent.py
+++ b/pydantic_ai_slim/pydantic_ai/agent.py
@@ -25,7 +25,6 @@
     result,
     usage as _usage,
 )
-from ._agent_graph import EndStrategy, capture_run_messages  # imported for re-export
 from .result import FinalResult, ResultDataT, StreamedRunResult
 from .settings import ModelSettings, merge_model_settings
 from .tools import (
@@ -40,7 +39,24 @@
     ToolPrepareFunc,
 )
 
-__all__ = 'Agent', 'AgentRun', 'AgentRunResult', 'capture_run_messages', 'EndStrategy'
+# Re-exporting like this improves auto-import behavior in PyCharm
+capture_run_messages = _agent_graph.capture_run_messages
+EndStrategy = _agent_graph.EndStrategy
+HandleResponseNode = _agent_graph.HandleResponseNode
+ModelRequestNode = _agent_graph.ModelRequestNode
+UserPromptNode = _agent_graph.UserPromptNode
+
+
+__all__ = (
+    'Agent',
+    'AgentRun',
+    'AgentRunResult',
+    'capture_run_messages',
+    'EndStrategy',
+    'HandleResponseNode',
+    'ModelRequestNode',
+    'UserPromptNode',
+)
 
 _logfire = logfire_api.Logfire(otel_scope='pydantic-ai')
diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py
index e0492d9e..eef023c9 100644
--- a/pydantic_ai_slim/pydantic_ai/models/__init__.py
+++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py
@@ -17,7 +17,6 @@
 import httpx
 from typing_extensions import Literal
 
-from .. import _utils, messages as _messages
 from .._parts_manager import ModelResponsePartsManager
 from ..exceptions import UserError
 from ..messages import ModelMessage, ModelResponse, ModelResponseStreamEvent
@@ -235,6 +234,8 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
 
         This method should be implemented by subclasses to translate the vendor-specific stream of events into
         pydantic_ai-format events.
+
+        It should use the `_parts_manager` to handle deltas, and should update the `_usage` attributes as it goes.
         """
         raise NotImplementedError()
         # noinspection PyUnreachableCode
@@ -262,72 +263,6 @@ def timestamp(self) -> datetime:
         """Get the timestamp of the response."""
         raise NotImplementedError()
 
-    async def stream_debounced_events(
-        self, *, debounce_by: float | None = 0.1
-    ) -> AsyncIterator[list[ModelResponseStreamEvent]]:
-        """Stream the response as an async iterable of debounced lists of [`ModelResponseStreamEvent`][pydantic_ai.messages.ModelResponseStreamEvent]s."""
-        async with _utils.group_by_temporal(self, debounce_by) as group_iter:
-            async for items in group_iter:
-                yield items
-
-    async def stream_structured(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[_messages.ModelResponse]:
-        """Stream the response as an async iterable of [`ModelResponse`][pydantic_ai.messages.ModelResponse]s."""
-
-        async def _stream_structured_ungrouped() -> AsyncIterator[None]:
-            # yield None  # TODO: Might want to yield right away to ensure we can eagerly emit a ModelResponse even if we are waiting
-            async for _event in self:
-                yield None
-
-        async with _utils.group_by_temporal(_stream_structured_ungrouped(), debounce_by) as group_iter:
-            async for _items in group_iter:
-                yield self.get()  # current state of the response
-
-    async def stream_text(self, *, delta: bool = False, debounce_by: float | None = 0.1) -> AsyncIterator[str]:
-        """Stream the response as an async iterable of text."""
-
-        # Define a "merged" version of the iterator that will yield items that have already been retrieved
-        # and items that we receive while streaming. We define a dedicated async iterator for this so we can
-        # pass the combined stream to the group_by_temporal function within `_stream_text_deltas` below.
-        async def _stream_text_deltas_ungrouped() -> AsyncIterator[tuple[str, int]]:
-            # yields tuples of (text_content, part_index)
-            # we don't currently make use of the part_index, but in principle this may be useful
-            # so we retain it here for now to make possible future refactors simpler
-            msg = self.get()
-            for i, part in enumerate(msg.parts):
-                if isinstance(part, _messages.TextPart) and part.content:
-                    yield part.content, i
-
-            async for event in self:
-                if (
-                    isinstance(event, _messages.PartStartEvent)
-                    and isinstance(event.part, _messages.TextPart)
-                    and event.part.content
-                ):
-                    yield event.part.content, event.index
-                elif (
-                    isinstance(event, _messages.PartDeltaEvent)
-                    and isinstance(event.delta, _messages.TextPartDelta)
-                    and event.delta.content_delta
-                ):
-                    yield event.delta.content_delta, event.index
-
-        async def _stream_text_deltas() -> AsyncIterator[str]:
-            async with _utils.group_by_temporal(_stream_text_deltas_ungrouped(), debounce_by) as group_iter:
-                async for items in group_iter:
-                    # Note: we are currently just dropping the part index on the group here
-                    yield ''.join([content for content, _ in items])
-
-        if delta:
-            async for text in _stream_text_deltas():
-                yield text
-        else:
-            # a quick benchmark shows it's faster to build up a string with concat when we're
-            # yielding at each step
-            deltas: list[str] = []
-            async for text in _stream_text_deltas():
-                deltas.append(text)
-            yield ''.join(deltas)
-
 
 ALLOW_MODEL_REQUESTS = True
 """Whether to allow requests to models.
diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py
index 41cc20b8..7646de5b 100644
--- a/pydantic_ai_slim/pydantic_ai/result.py
+++ b/pydantic_ai_slim/pydantic_ai/result.py
@@ -9,7 +9,7 @@
 import logfire_api
 from typing_extensions import TypeVar
 
-from . import _result, exceptions, messages as _messages, models
+from . import _result, _utils, exceptions, messages as _messages, models
 from .tools import AgentDepsT, RunContext
 from .usage import Usage, UsageLimits
 
@@ -160,7 +160,6 @@ async def stream(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[ResultDataT]:
         Returns:
             An async iterable of the response data.
         """
-        self._stream_response.stream_structured(debounce_by=debounce_by)
        async for structured_message, is_last in self.stream_structured(debounce_by=debounce_by):
             result = await self.validate_structured_result(structured_message, allow_partial=not is_last)
             yield result
@@ -183,11 +182,11 @@ async def stream_text(self, *, delta: bool = False, debounce_by: float | None = 0.1) -> AsyncIterator[str]:
 
         with _logfire.span('response stream text') as lf_span:
             if delta:
-                async for text in self._stream_response.stream_text(delta=delta, debounce_by=debounce_by):
+                async for text in self._stream_response_text(delta=delta, debounce_by=debounce_by):
                     yield text
             else:
                 combined_validated_text = ''
-                async for text in self._stream_response.stream_text(delta=delta, debounce_by=debounce_by):
+                async for text in self._stream_response_text(delta=delta, debounce_by=debounce_by):
                     combined_validated_text = await self._validate_text_result(text)
                     yield combined_validated_text
                 lf_span.set_attribute('combined_text', combined_validated_text)
@@ -214,7 +213,7 @@ async def stream_structured(
                     yield msg, False
                     break
 
-        async for msg in self._stream_response.stream_structured(debounce_by=debounce_by):
+        async for msg in self._stream_response_structured(debounce_by=debounce_by):
             yield msg, False
 
         msg = self._stream_response.get()
@@ -289,6 +288,61 @@ async def _marked_completed(self, message: _messages.ModelResponse) -> None:
             self._all_messages.append(message)
         await self._on_complete()
 
+    async def _stream_response_structured(
+        self, *, debounce_by: float | None = 0.1
+    ) -> AsyncIterator[_messages.ModelResponse]:
+        async with _utils.group_by_temporal(self._stream_response, debounce_by) as group_iter:
+            async for _items in group_iter:
+                yield self._stream_response.get()
+
+    async def _stream_response_text(
+        self, *, delta: bool = False, debounce_by: float | None = 0.1
+    ) -> AsyncIterator[str]:
+        """Stream the response as an async iterable of text."""
+
+        # Define a "merged" version of the iterator that will yield items that have already been retrieved
+        # and items that we receive while streaming. We define a dedicated async iterator for this so we can
+        # pass the combined stream to the group_by_temporal function within `_stream_text_deltas` below.
+        async def _stream_text_deltas_ungrouped() -> AsyncIterator[tuple[str, int]]:
+            # yields tuples of (text_content, part_index)
+            # we don't currently make use of the part_index, but in principle this may be useful
+            # so we retain it here for now to make possible future refactors simpler
+            msg = self._stream_response.get()
+            for i, part in enumerate(msg.parts):
+                if isinstance(part, _messages.TextPart) and part.content:
+                    yield part.content, i
+
+            async for event in self._stream_response:
+                if (
+                    isinstance(event, _messages.PartStartEvent)
+                    and isinstance(event.part, _messages.TextPart)
+                    and event.part.content
+                ):
+                    yield event.part.content, event.index
+                elif (
+                    isinstance(event, _messages.PartDeltaEvent)
+                    and isinstance(event.delta, _messages.TextPartDelta)
+                    and event.delta.content_delta
+                ):
+                    yield event.delta.content_delta, event.index
+
+        async def _stream_text_deltas() -> AsyncIterator[str]:
+            async with _utils.group_by_temporal(_stream_text_deltas_ungrouped(), debounce_by) as group_iter:
+                async for items in group_iter:
+                    # Note: we are currently just dropping the part index on the group here
+                    yield ''.join([content for content, _ in items])
+
+        if delta:
+            async for text in _stream_text_deltas():
+                yield text
+        else:
+            # a quick benchmark shows it's faster to build up a string with concat when we're
+            # yielding at each step
+            deltas: list[str] = []
+            async for text in _stream_text_deltas():
+                deltas.append(text)
+            yield ''.join(deltas)
+
 
 @dataclass
 class FinalResult(Generic[ResultDataT]):
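
The re-exports in `__init__.py` and `agent.py` make the graph node types part of the public API. A rough sketch of what that enables downstream — this assumes the `Agent.iter`/`AgentRun` API on this branch, and the model name and prompt are placeholders, so treat it as illustrative rather than part of the patch:

```python
import asyncio

from pydantic_ai import Agent, HandleResponseNode, ModelRequestNode, UserPromptNode

agent = Agent('openai:gpt-4o')  # placeholder model


async def main() -> None:
    # `Agent.iter` is assumed from this branch: it yields each graph node as the
    # run executes, and the node classes exported by this patch make the
    # isinstance checks below possible for callers outside the package.
    async with agent.iter('What is the capital of France?') as agent_run:
        async for node in agent_run:
            if isinstance(node, UserPromptNode):
                print('built the first request from the user prompt')
            elif isinstance(node, ModelRequestNode):
                print('about to make a model request')
            elif isinstance(node, HandleResponseNode):
                print('handling the model response (tool calls, result validation)')
    print(agent_run.result.data)


asyncio.run(main())
```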
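
The streaming helpers moved from `StreamedResponse` into `StreamedRunResult` (`_stream_response_text`, `_stream_response_structured`) are still reached through the existing public surface, so callers should see no behavior change. A minimal sketch, assuming `Agent.run_stream` and a configured model:

```python
import asyncio

from pydantic_ai import Agent

agent = Agent('openai:gpt-4o')  # placeholder model


async def main() -> None:
    async with agent.run_stream('Tell me a short story.') as response:
        # stream_text now delegates to StreamedRunResult._stream_response_text;
        # debounce_by=0.1 groups events arriving within 100ms into a single yield
        async for chunk in response.stream_text(delta=True, debounce_by=0.1):
            print(chunk, end='', flush=True)


asyncio.run(main())
```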