remove Either from _handle_model_response and _handle_streamed_model_response #39

Merged (2 commits) on Nov 12, 2024
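
This PR drops the functional-style `_utils.Either` container from the two private response handlers and returns a plain union instead: a `_MarkFinalResult[...]` when the run is finished, or a `list[_messages.Message]` to send back to the model. Call sites then branch with a simple `isinstance` check rather than unpacking left/right sides.

For context, a minimal sketch of the kind of `Either` helper being removed. This is reconstructed from how the diff uses it (`Either(left=...)` / `Either(right=...)`, a truthy `.left` holding `.value`), not from the actual `_utils` source:

from dataclasses import dataclass
from typing import Any, Generic, Optional, TypeVar

Left = TypeVar('Left')
Right = TypeVar('Right')

_unset: Any = object()  # sentinel so None is still a valid left/right value


@dataclass
class Some(Generic[Left]):
    """Box around the left value, so `either.left` can be truth-tested."""

    value: Left


class Either(Generic[Left, Right]):
    """Two-member union: construct with exactly one of `left=` / `right=`."""

    def __init__(self, *, left: Left = _unset, right: Right = _unset) -> None:
        assert (left is _unset) != (right is _unset), 'exactly one of left/right required'
        self._left, self._right = left, right

    @property
    def left(self) -> Optional[Some[Left]]:
        return Some(self._left) if self._left is not _unset else None

    @property
    def right(self) -> Right:
        assert self._right is not _unset
        return self._right

The call-site pattern this forces (`if left := either.left: ... left.value ... else: ... either.right`) is exactly what the diff below replaces with `isinstance`.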
pydantic_ai/agent.py: 58 changes (35 additions, 23 deletions)
@@ -165,16 +165,17 @@ async def run(
             with _logfire.span('handle model response') as handle_span:
                 either = await self._handle_model_response(model_response, deps)

-                if left := either.left:
-                    # left means return a streamed result
+                if isinstance(either, _MarkFinalResult):
+                    # we have a final result, end the conversation
+                    result_data = either.data
                     run_span.set_attribute('all_messages', messages)
                     run_span.set_attribute('cost', cost)
-                    handle_span.set_attribute('result', left.value)
+                    handle_span.set_attribute('result', result_data)
                     handle_span.message = 'handle model response -> final result'
-                    return result.RunResult(messages, new_message_index, left.value, cost)
+                    return result.RunResult(messages, new_message_index, result_data, cost)
                 else:
-                    # right means continue the conversation
-                    tool_responses = either.right
+                    # continue the conversation
+                    tool_responses = either
                     handle_span.set_attribute('tool_responses', tool_responses)
                     response_msgs = ' '.join(m.role for m in tool_responses)
                     handle_span.message = f'handle model response -> {response_msgs}'
@@ -256,9 +257,8 @@ async def run_stream(
             with _logfire.span('handle model response') as handle_span:
                 either = await self._handle_streamed_model_response(model_response, deps)

-                if left := either.left:
-                    # left means return a streamed result
-                    result_stream = left.value
+                if isinstance(either, _MarkFinalResult):
+                    result_stream = either.data
                     run_span.set_attribute('all_messages', messages)
                     handle_span.set_attribute('result_type', result_stream.__class__.__name__)
                     handle_span.message = 'handle model response -> final result'
@@ -273,8 +273,7 @@
                     )
                     return
                 else:
-                    # right means continue the conversation
-                    tool_responses = either.right
+                    tool_responses = either
                     handle_span.set_attribute('tool_responses', tool_responses)
                     response_msgs = ' '.join(m.role for m in tool_responses)
                     handle_span.message = f'handle model response -> {response_msgs}'
@@ -410,7 +409,7 @@ async def _prepare_messages(

     async def _handle_model_response(
         self, model_response: _messages.ModelAnyResponse, deps: AgentDeps
-    ) -> _utils.Either[ResultData, list[_messages.Message]]:
+    ) -> _MarkFinalResult[ResultData] | list[_messages.Message]:
         """Process a non-streamed response from the model.

         Returns:
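
Note the new annotation uses PEP 604 `X | Y` union syntax even though the package still supports Python 3.9. That presumably relies on `from __future__ import annotations` (PEP 563 string annotations), since evaluating `|` between typing generics only became legal at runtime in 3.10. A hypothetical snippet illustrating the distinction:

from __future__ import annotations  # annotations stay unevaluated strings on 3.9

from dataclasses import dataclass
from typing import Generic, TypeVar

T = TypeVar('T')


@dataclass
class Marker(Generic[T]):
    data: T


# Fine on 3.9 thanks to the future import above; writing the same union as a
# runtime expression, e.g. `Alias = Marker[int] | list[str]`, would raise
# TypeError before Python 3.10.
def handle(done: bool) -> Marker[int] | list[str]:
    return Marker(1) if done else ['retry']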
@@ -424,15 +423,15 @@ async def _handle_model_response(
                     result_data = await self._validate_result(result_data_input, deps, None)
                 except _result.ToolRetryError as e:
                     self._incr_result_retry()
-                    return _utils.Either(right=[e.tool_retry])
+                    return [e.tool_retry]
                 else:
-                    return _utils.Either(left=result_data)
+                    return _MarkFinalResult(result_data)
             else:
                 self._incr_result_retry()
                 response = _messages.RetryPrompt(
                     content='Plain text responses are not permitted, please call one of the functions instead.',
                 )
-                return _utils.Either(right=[response])
+                return [response]
         elif model_response.role == 'model-structured-response':
             if self._result_schema is not None:
                 # if there's a result schema, and any of the calls match one of its tools, return the result
@@ -444,9 +443,9 @@ async def _handle_model_response(
                     result_data = await self._validate_result(result_data, deps, call)
                 except _result.ToolRetryError as e:
                     self._incr_result_retry()
-                    return _utils.Either(right=[e.tool_retry])
+                    return [e.tool_retry]
                 else:
-                    return _utils.Either(left=result_data)
+                    return _MarkFinalResult(result_data)

         if not model_response.calls:
             raise exceptions.UnexpectedModelBehaviour('Received empty tool call message')
@@ -462,22 +461,25 @@ async def _handle_model_response(

             with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]):
                 messages += await asyncio.gather(*tasks)
-            return _utils.Either(right=messages)
+            return messages
         else:
             assert_never(model_response)

     async def _handle_streamed_model_response(
         self, model_response: models.EitherStreamedResponse, deps: AgentDeps
-    ) -> _utils.Either[models.EitherStreamedResponse, list[_messages.Message]]:
+    ) -> _MarkFinalResult[models.EitherStreamedResponse] | list[_messages.Message]:
         """Process a streamed response from the model.

+        TODO: the response type should change to `models.EitherStreamedResponse | list[_messages.Message]` once we drop 3.9
+        (with 3.9 we get `TypeError: Subscripted generics cannot be used with class and instance checks`)
+
         Returns:
             Return `Either` — left: final result data, right: list of messages to send back to the model.
         """
         if isinstance(model_response, models.StreamTextResponse):
             # plain string response
             if self._allow_text_result:
-                return _utils.Either(left=model_response)
+                return _MarkFinalResult(model_response)
             else:
                 self._incr_result_retry()
                 response = _messages.RetryPrompt(
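
The TODO in the docstring above refers to a runtime limitation: before Python 3.10, `isinstance` rejects subscripted generics and `Union` aliases, so the handler cannot return the bare union and let callers check against `models.EitherStreamedResponse` directly. A minimal reproduction with stand-in classes (`StreamText` / `StreamStructured` are illustrative, not the real pydantic_ai types):

from typing import Union


class StreamText: ...


class StreamStructured: ...


EitherStreamed = Union[StreamText, StreamStructured]

resp = StreamText()
# Python 3.10+: prints True.
# Python 3.9: raises
#   TypeError: Subscripted generics cannot be used with class and instance checks
print(isinstance(resp, EitherStreamed))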
@@ -487,7 +489,7 @@ async def _handle_streamed_model_response(
                 async for _ in model_response:
                     pass

-                return _utils.Either(right=[response])
+                return [response]
         else:
             assert isinstance(model_response, models.StreamStructuredResponse), f'Unexpected response: {model_response}'
             if self._result_schema is not None:
@@ -502,7 +504,7 @@
                 structured_msg = model_response.get()

                 if self._result_schema.find_tool(structured_msg):
-                    return _utils.Either(left=model_response)
+                    return _MarkFinalResult(model_response)

             # the model is calling a retriever function, consume the response to get the next message
             async for _ in model_response:
@@ -522,7 +524,7 @@

             with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]):
                 messages += await asyncio.gather(*tasks)
-            return _utils.Either(right=messages)
+            return messages

     async def _validate_result(
         self, result_data: ResultData, deps: AgentDeps, tool_call: _messages.ToolCall | None
@@ -556,3 +558,13 @@ def _unknown_tool(self, tool_name: str) -> _messages.RetryPrompt:
         else:
             msg = 'No tools available.'
         return _messages.RetryPrompt(content=f'Unknown tool name: {tool_name!r}. {msg}')
+
+
+@dataclass
+class _MarkFinalResult(Generic[ResultData]):
+    """Marker class to indicate that the result is the final result.
+
+    This allows us to use `isinstance`, which wouldn't be possible if we were returning `ResultData` directly.
+    """
+
+    data: ResultData
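
Because the `isinstance` check is against the bare, unsubscripted class, it stays valid on Python 3.9, and a type checker still narrows the union on both branches. A self-contained sketch of the resulting pattern (`handle` and the message list are illustrative stand-ins, kept 3.9-compatible with `typing.Union` / `typing.List`):

from dataclasses import dataclass
from typing import Generic, List, TypeVar, Union

ResultData = TypeVar('ResultData')


@dataclass
class _MarkFinalResult(Generic[ResultData]):
    """Marker distinguishing 'final result' from 'more messages for the model'."""

    data: ResultData


def handle(either: Union[_MarkFinalResult[str], List[str]]) -> str:
    if isinstance(either, _MarkFinalResult):
        # narrowed to _MarkFinalResult[str]: the conversation is over
        return f'final: {either.data}'
    # narrowed to List[str]: send these messages back to the model
    return f'continue with {len(either)} message(s)'


print(handle(_MarkFinalResult('done')))  # final: done
print(handle(['tool-return']))           # continue with 1 message(s)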