Skip to content

Commit

Permalink
remove Either from _handle_model_response and `_handle_streamed_m…
Browse files Browse the repository at this point in the history
…odel_response` (#39)
  • Loading branch information
samuelcolvin authored Nov 12, 2024
1 parent 5882769 commit 4c03c17
Showing 1 changed file with 35 additions and 23 deletions.
58 changes: 35 additions & 23 deletions pydantic_ai/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,16 +165,17 @@ async def run(
with _logfire.span('handle model response') as handle_span:
either = await self._handle_model_response(model_response, deps)

if left := either.left:
# left means return a streamed result
if isinstance(either, _MarkFinalResult):
# we have a final result, end the conversation
result_data = either.data
run_span.set_attribute('all_messages', messages)
run_span.set_attribute('cost', cost)
handle_span.set_attribute('result', left.value)
handle_span.set_attribute('result', result_data)
handle_span.message = 'handle model response -> final result'
return result.RunResult(messages, new_message_index, left.value, cost)
return result.RunResult(messages, new_message_index, result_data, cost)
else:
# right means continue the conversation
tool_responses = either.right
# continue the conversation
tool_responses = either
handle_span.set_attribute('tool_responses', tool_responses)
response_msgs = ' '.join(m.role for m in tool_responses)
handle_span.message = f'handle model response -> {response_msgs}'
Expand Down Expand Up @@ -256,9 +257,8 @@ async def run_stream(
with _logfire.span('handle model response') as handle_span:
either = await self._handle_streamed_model_response(model_response, deps)

if left := either.left:
# left means return a streamed result
result_stream = left.value
if isinstance(either, _MarkFinalResult):
result_stream = either.data
run_span.set_attribute('all_messages', messages)
handle_span.set_attribute('result_type', result_stream.__class__.__name__)
handle_span.message = 'handle model response -> final result'
Expand All @@ -273,8 +273,7 @@ async def run_stream(
)
return
else:
# right means continue the conversation
tool_responses = either.right
tool_responses = either
handle_span.set_attribute('tool_responses', tool_responses)
response_msgs = ' '.join(m.role for m in tool_responses)
handle_span.message = f'handle model response -> {response_msgs}'
Expand Down Expand Up @@ -410,7 +409,7 @@ async def _prepare_messages(

async def _handle_model_response(
self, model_response: _messages.ModelAnyResponse, deps: AgentDeps
) -> _utils.Either[ResultData, list[_messages.Message]]:
) -> _MarkFinalResult[ResultData] | list[_messages.Message]:
"""Process a non-streamed response from the model.
Returns:
Expand All @@ -424,15 +423,15 @@ async def _handle_model_response(
result_data = await self._validate_result(result_data_input, deps, None)
except _result.ToolRetryError as e:
self._incr_result_retry()
return _utils.Either(right=[e.tool_retry])
return [e.tool_retry]
else:
return _utils.Either(left=result_data)
return _MarkFinalResult(result_data)
else:
self._incr_result_retry()
response = _messages.RetryPrompt(
content='Plain text responses are not permitted, please call one of the functions instead.',
)
return _utils.Either(right=[response])
return [response]
elif model_response.role == 'model-structured-response':
if self._result_schema is not None:
# if there's a result schema, and any of the calls match one of its tools, return the result
Expand All @@ -444,9 +443,9 @@ async def _handle_model_response(
result_data = await self._validate_result(result_data, deps, call)
except _result.ToolRetryError as e:
self._incr_result_retry()
return _utils.Either(right=[e.tool_retry])
return [e.tool_retry]
else:
return _utils.Either(left=result_data)
return _MarkFinalResult(result_data)

if not model_response.calls:
raise exceptions.UnexpectedModelBehaviour('Received empty tool call message')
Expand All @@ -462,22 +461,25 @@ async def _handle_model_response(

with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]):
messages += await asyncio.gather(*tasks)
return _utils.Either(right=messages)
return messages
else:
assert_never(model_response)

async def _handle_streamed_model_response(
self, model_response: models.EitherStreamedResponse, deps: AgentDeps
) -> _utils.Either[models.EitherStreamedResponse, list[_messages.Message]]:
) -> _MarkFinalResult[models.EitherStreamedResponse] | list[_messages.Message]:
"""Process a streamed response from the model.
TODO: the response type change to `models.EitherStreamedResponse | list[_messages.Message]` once we drop 3.9
(with 3.9 we get `TypeError: Subscripted generics cannot be used with class and instance checks`)
Returns:
Return `Either` — left: final result data, right: list of messages to send back to the model.
"""
if isinstance(model_response, models.StreamTextResponse):
# plain string response
if self._allow_text_result:
return _utils.Either(left=model_response)
return _MarkFinalResult(model_response)
else:
self._incr_result_retry()
response = _messages.RetryPrompt(
Expand All @@ -487,7 +489,7 @@ async def _handle_streamed_model_response(
async for _ in model_response:
pass

return _utils.Either(right=[response])
return [response]
else:
assert isinstance(model_response, models.StreamStructuredResponse), f'Unexpected response: {model_response}'
if self._result_schema is not None:
Expand All @@ -502,7 +504,7 @@ async def _handle_streamed_model_response(
structured_msg = model_response.get()

if self._result_schema.find_tool(structured_msg):
return _utils.Either(left=model_response)
return _MarkFinalResult(model_response)

# the model is calling a retriever function, consume the response to get the next message
async for _ in model_response:
Expand All @@ -522,7 +524,7 @@ async def _handle_streamed_model_response(

with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]):
messages += await asyncio.gather(*tasks)
return _utils.Either(right=messages)
return messages

async def _validate_result(
self, result_data: ResultData, deps: AgentDeps, tool_call: _messages.ToolCall | None
Expand Down Expand Up @@ -556,3 +558,13 @@ def _unknown_tool(self, tool_name: str) -> _messages.RetryPrompt:
else:
msg = 'No tools available.'
return _messages.RetryPrompt(content=f'Unknown tool name: {tool_name!r}. {msg}')


@dataclass
class _MarkFinalResult(Generic[ResultData]):
"""Marker class to indicate that the result is the final result.
This allows us to use `isinstance`, which wouldn't be possible if we were returning `ResultData` directly.
"""

data: ResultData

0 comments on commit 4c03c17

Please sign in to comment.