From 4c03c1736ad15eba02bc7472f339decebc2e0c4d Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Tue, 12 Nov 2024 10:35:56 +0000 Subject: [PATCH] remove `Either` from `_handle_model_response` and `_handle_streamed_model_response` (#39) --- pydantic_ai/agent.py | 58 ++++++++++++++++++++++++++------------------ 1 file changed, 35 insertions(+), 23 deletions(-) diff --git a/pydantic_ai/agent.py b/pydantic_ai/agent.py index 133141979..dbeac1565 100644 --- a/pydantic_ai/agent.py +++ b/pydantic_ai/agent.py @@ -165,16 +165,17 @@ async def run( with _logfire.span('handle model response') as handle_span: either = await self._handle_model_response(model_response, deps) - if left := either.left: - # left means return a streamed result + if isinstance(either, _MarkFinalResult): + # we have a final result, end the conversation + result_data = either.data run_span.set_attribute('all_messages', messages) run_span.set_attribute('cost', cost) - handle_span.set_attribute('result', left.value) + handle_span.set_attribute('result', result_data) handle_span.message = 'handle model response -> final result' - return result.RunResult(messages, new_message_index, left.value, cost) + return result.RunResult(messages, new_message_index, result_data, cost) else: - # right means continue the conversation - tool_responses = either.right + # continue the conversation + tool_responses = either handle_span.set_attribute('tool_responses', tool_responses) response_msgs = ' '.join(m.role for m in tool_responses) handle_span.message = f'handle model response -> {response_msgs}' @@ -256,9 +257,8 @@ async def run_stream( with _logfire.span('handle model response') as handle_span: either = await self._handle_streamed_model_response(model_response, deps) - if left := either.left: - # left means return a streamed result - result_stream = left.value + if isinstance(either, _MarkFinalResult): + result_stream = either.data run_span.set_attribute('all_messages', messages) handle_span.set_attribute('result_type', result_stream.__class__.__name__) handle_span.message = 'handle model response -> final result' @@ -273,8 +273,7 @@ async def run_stream( ) return else: - # right means continue the conversation - tool_responses = either.right + tool_responses = either handle_span.set_attribute('tool_responses', tool_responses) response_msgs = ' '.join(m.role for m in tool_responses) handle_span.message = f'handle model response -> {response_msgs}' @@ -410,7 +409,7 @@ async def _prepare_messages( async def _handle_model_response( self, model_response: _messages.ModelAnyResponse, deps: AgentDeps - ) -> _utils.Either[ResultData, list[_messages.Message]]: + ) -> _MarkFinalResult[ResultData] | list[_messages.Message]: """Process a non-streamed response from the model. Returns: @@ -424,15 +423,15 @@ async def _handle_model_response( result_data = await self._validate_result(result_data_input, deps, None) except _result.ToolRetryError as e: self._incr_result_retry() - return _utils.Either(right=[e.tool_retry]) + return [e.tool_retry] else: - return _utils.Either(left=result_data) + return _MarkFinalResult(result_data) else: self._incr_result_retry() response = _messages.RetryPrompt( content='Plain text responses are not permitted, please call one of the functions instead.', ) - return _utils.Either(right=[response]) + return [response] elif model_response.role == 'model-structured-response': if self._result_schema is not None: # if there's a result schema, and any of the calls match one of its tools, return the result @@ -444,9 +443,9 @@ async def _handle_model_response( result_data = await self._validate_result(result_data, deps, call) except _result.ToolRetryError as e: self._incr_result_retry() - return _utils.Either(right=[e.tool_retry]) + return [e.tool_retry] else: - return _utils.Either(left=result_data) + return _MarkFinalResult(result_data) if not model_response.calls: raise exceptions.UnexpectedModelBehaviour('Received empty tool call message') @@ -462,22 +461,25 @@ async def _handle_model_response( with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]): messages += await asyncio.gather(*tasks) - return _utils.Either(right=messages) + return messages else: assert_never(model_response) async def _handle_streamed_model_response( self, model_response: models.EitherStreamedResponse, deps: AgentDeps - ) -> _utils.Either[models.EitherStreamedResponse, list[_messages.Message]]: + ) -> _MarkFinalResult[models.EitherStreamedResponse] | list[_messages.Message]: """Process a streamed response from the model. + TODO: the response type change to `models.EitherStreamedResponse | list[_messages.Message]` once we drop 3.9 + (with 3.9 we get `TypeError: Subscripted generics cannot be used with class and instance checks`) + Returns: Return `Either` — left: final result data, right: list of messages to send back to the model. """ if isinstance(model_response, models.StreamTextResponse): # plain string response if self._allow_text_result: - return _utils.Either(left=model_response) + return _MarkFinalResult(model_response) else: self._incr_result_retry() response = _messages.RetryPrompt( @@ -487,7 +489,7 @@ async def _handle_streamed_model_response( async for _ in model_response: pass - return _utils.Either(right=[response]) + return [response] else: assert isinstance(model_response, models.StreamStructuredResponse), f'Unexpected response: {model_response}' if self._result_schema is not None: @@ -502,7 +504,7 @@ async def _handle_streamed_model_response( structured_msg = model_response.get() if self._result_schema.find_tool(structured_msg): - return _utils.Either(left=model_response) + return _MarkFinalResult(model_response) # the model is calling a retriever function, consume the response to get the next message async for _ in model_response: @@ -522,7 +524,7 @@ async def _handle_streamed_model_response( with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]): messages += await asyncio.gather(*tasks) - return _utils.Either(right=messages) + return messages async def _validate_result( self, result_data: ResultData, deps: AgentDeps, tool_call: _messages.ToolCall | None @@ -556,3 +558,13 @@ def _unknown_tool(self, tool_name: str) -> _messages.RetryPrompt: else: msg = 'No tools available.' return _messages.RetryPrompt(content=f'Unknown tool name: {tool_name!r}. {msg}') + + +@dataclass +class _MarkFinalResult(Generic[ResultData]): + """Marker class to indicate that the result is the final result. + + This allows us to use `isinstance`, which wouldn't be possible if we were returning `ResultData` directly. + """ + + data: ResultData