WIP: Add support for the use of response_format to force a particular json schema for the response

dmontagu committed Feb 22, 2025
1 parent 9b4de86 commit 9c64d68
Showing 13 changed files with 235 additions and 42 deletions.
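
A minimal sketch of how the new setting could be exercised from user code, assuming the `force_response_format` key introduced for the Cohere and Gemini model settings below reaches the configured model; the model name and result type are illustrative:

from pydantic import BaseModel

from pydantic_ai import Agent


class SeatPreference(BaseModel):
    row: int
    seat: str


agent = Agent('google-gla:gemini-1.5-flash', result_type=SeatPreference)
# With force_response_format, the model is asked to emit JSON matching the
# result type's schema directly (a StructuredOutputPart) instead of calling
# a result tool.
result = agent.run_sync(
    'I would like seat 14 A',
    model_settings={'force_response_format': True},
)
print(result.data)  # SeatPreference(row=14, seat='A')
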
4 changes: 3 additions & 1 deletion examples/pydantic_ai_examples/flight_booking.py
@@ -105,7 +105,9 @@ class Failed(BaseModel):
 
 
 # This agent is responsible for extracting the user's seat selection
-seat_preference_agent = Agent[None, SeatPreference | Failed](
+seat_preference_agent = Agent[
+    None, SeatPreference | Failed
+](
     'openai:gpt-4o',
     result_type=SeatPreference | Failed,  # type: ignore
     system_prompt=(
12 changes: 6 additions & 6 deletions examples/pydantic_ai_examples/rag.py
@@ -67,9 +67,9 @@ async def retrieve(context: RunContext[Deps], search_query: str) -> str:
         model='text-embedding-3-small',
     )
 
-    assert len(embedding.data) == 1, (
-        f'Expected 1 embedding, got {len(embedding.data)}, doc query: {search_query!r}'
-    )
+    assert (
+        len(embedding.data) == 1
+    ), f'Expected 1 embedding, got {len(embedding.data)}, doc query: {search_query!r}'
     embedding = embedding.data[0].embedding
     embedding_json = pydantic_core.to_json(embedding).decode()
     rows = await context.deps.pool.fetch(
@@ -149,9 +149,9 @@ async def insert_doc_section(
         input=section.embedding_content(),
         model='text-embedding-3-small',
     )
-    assert len(embedding.data) == 1, (
-        f'Expected 1 embedding, got {len(embedding.data)}, doc section: {section}'
-    )
+    assert (
+        len(embedding.data) == 1
+    ), f'Expected 1 embedding, got {len(embedding.data)}, doc section: {section}'
     embedding = embedding.data[0].embedding
     embedding_json = pydantic_core.to_json(embedding).decode()
     await pool.execute(
43 changes: 43 additions & 0 deletions pydantic_ai_slim/pydantic_ai/_agent_graph.py
@@ -10,6 +10,7 @@
 from typing import Any, Generic, Literal, Union, cast
 
 import logfire_api
+from pydantic import ValidationError
 from typing_extensions import TypeVar, assert_never
 
 from pydantic_graph import BaseNode, Graph, GraphRunContext
@@ -370,12 +371,16 @@ async def _run_stream(
 
         async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]:
             texts: list[str] = []
+            structured_outputs: list[str] = []
             tool_calls: list[_messages.ToolCallPart] = []
             for part in self.model_response.parts:
                 if isinstance(part, _messages.TextPart):
                     # ignore empty content for text parts, see #437
                     if part.content:
                         texts.append(part.content)
+                elif isinstance(part, _messages.StructuredOutputPart):
+                    if part.content:
+                        structured_outputs.append(part.content)
                 elif isinstance(part, _messages.ToolCallPart):
                     tool_calls.append(part)
                 else:
@@ -391,6 +396,9 @@ async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]:
             elif texts:
                 # No events are emitted during the handling of text responses, so we don't need to yield anything
                 self._next_node = await self._handle_text_response(ctx, texts)
+            elif structured_outputs:
+                # No events are emitted during the handling of structured outputs, so we don't need to yield anything
+                self._next_node = await self._handle_structured_outputs_response(ctx, structured_outputs)
             else:
                 raise exceptions.UnexpectedModelBehavior('Received empty model response')
 
@@ -487,6 +495,41 @@ async def _handle_text_response(
             )
         )
 
+    async def _handle_structured_outputs_response(
+        self,
+        ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
+        structured_outputs: list[str],
+    ) -> ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]]:
+        if len(structured_outputs) != 1:
+            raise exceptions.UnexpectedModelBehavior('Received multiple structured outputs in a single response')
+        result_schema = ctx.deps.result_schema
+        if not result_schema:
+            raise exceptions.UnexpectedModelBehavior('Must specify a non-str result_type when using structured outputs')
+
+        structured_output = structured_outputs[0]
+        try:
+            result_data_input = result_schema.structured_output_validator.validate_json(structured_output)
+        except ValidationError as e:
+            ctx.state.increment_retries(ctx.deps.max_result_retries)
+            return ModelRequestNode[DepsT, NodeRunEndT](
+                _messages.ModelRequest(
+                    parts=[
+                        _messages.RetryPromptPart(
+                            content='Structured output validation failed: ' + str(e),
+                        )
+                    ]
+                )
+            )
+
+        try:
+            result_data = await _validate_result(result_data_input, ctx, None)
+        except _result.ToolRetryError as e:
+            ctx.state.increment_retries(ctx.deps.max_result_retries)
+            return ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[e.tool_retry]))
+        else:
+            return self._handle_final_result(ctx, result.FinalResult(result_data, tool_name=None), [])
+
 
 def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]) -> RunContext[DepsT]:
     """Build a `RunContext` object from the current agent graph run context."""
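
Outside the graph machinery, the validate-or-retry behaviour added above reduces to pydantic's `TypeAdapter.validate_json`. A self-contained sketch; the types and payload are illustrative, borrowed from the flight-booking example:

from pydantic import BaseModel, TypeAdapter, ValidationError


class SeatPreference(BaseModel):
    row: int
    seat: str


class Failed(BaseModel):
    reason: str


# the agent would store this on ResultSchema as `structured_output_validator`
validator = TypeAdapter(SeatPreference | Failed)

raw = '{"row": 14, "seat": "A"}'  # content of a StructuredOutputPart
try:
    data = validator.validate_json(raw)
except ValidationError as e:
    # the node above wraps this in a RetryPromptPart and re-prompts the model
    print('Structured output validation failed:', e)
else:
    print(data)  # SeatPreference(row=14, seat='A')
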
1 change: 1 addition & 0 deletions pydantic_ai_slim/pydantic_ai/_result.py
@@ -83,6 +83,7 @@ class ResultSchema(Generic[ResultDataT]):
     Similar to `Tool` but for the final result of running an agent.
     """
 
+    structured_output_validator: TypeAdapter[ResultDataT]
     tools: dict[str, ResultTool[ResultDataT]]
     allow_text_result: bool
 
48 changes: 46 additions & 2 deletions pydantic_ai_slim/pydantic_ai/messages.py
@@ -217,7 +217,22 @@ def has_content(self) -> bool:
         return bool(self.args)
 
 
-ModelResponsePart = Annotated[Union[TextPart, ToolCallPart], pydantic.Discriminator('part_kind')]
+@dataclass
+class StructuredOutputPart:
+    """A structured output response from a model."""
+
+    content: str
+    """The structured content of the response as a JSON-serialized string."""
+
+    part_kind: Literal['structured-output'] = 'structured-output'
+    """Part type identifier, this is available on all parts as a discriminator."""
+
+    def has_content(self) -> bool:
+        """Return `True` if the structured content is non-empty."""
+        return bool(self.content)
+
+
+ModelResponsePart = Annotated[Union[TextPart, ToolCallPart, StructuredOutputPart], pydantic.Discriminator('part_kind')]
 """A message part returned by a model."""
 
 
@@ -275,6 +290,33 @@ def apply(self, part: ModelResponsePart) -> TextPart:
         return replace(part, content=part.content + self.content_delta)
 
 
+@dataclass
+class StructuredOutputPartDelta:
+    """A partial update (delta) for a `StructuredOutputPart` to append new text content."""
+
+    content_delta: str
+    """The incremental text content to add to the existing `StructuredOutputPart` content."""
+
+    part_delta_kind: Literal['structured-output'] = 'structured-output'
+    """Part delta type identifier, used as a discriminator."""
+
+    def apply(self, part: ModelResponsePart) -> StructuredOutputPart:
+        """Apply this delta to an existing `StructuredOutputPart`.
+
+        Args:
+            part: The existing model response part, which must be a `StructuredOutputPart`.
+
+        Returns:
+            A new `StructuredOutputPart` with updated content.
+
+        Raises:
+            ValueError: If `part` is not a `StructuredOutputPart`.
+        """
+        if not isinstance(part, StructuredOutputPart):
+            raise ValueError('Cannot apply StructuredOutputPartDeltas to non-StructuredOutputParts')
+        return replace(part, content=part.content + self.content_delta)
+
+
 @dataclass
 class ToolCallPartDelta:
     """A partial update (delta) for a `ToolCallPart` to modify tool name, arguments, or tool call ID."""
@@ -408,7 +450,9 @@ def _apply_to_part(self, part: ToolCallPart) -> ToolCallPart:
         return part
 
 
-ModelResponsePartDelta = Annotated[Union[TextPartDelta, ToolCallPartDelta], pydantic.Discriminator('part_delta_kind')]
+ModelResponsePartDelta = Annotated[
+    Union[TextPartDelta, ToolCallPartDelta, StructuredOutputPartDelta], pydantic.Discriminator('part_delta_kind')
+]
 """A partial update (delta) for any model response part."""


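Assuming streaming treats structured output like text, deltas accumulate by plain string concatenation. A small sketch using the two classes added above:

from pydantic_ai.messages import StructuredOutputPart, StructuredOutputPartDelta

part = StructuredOutputPart(content='{"row": 14')  # first streamed chunk
delta = StructuredOutputPartDelta(content_delta=', "seat": "A"}')
part = delta.apply(part)  # returns a new part, the original is unchanged
assert part.content == '{"row": 14, "seat": "A"}'
assert part.has_content()
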
34 changes: 31 additions & 3 deletions pydantic_ai_slim/pydantic_ai/models/cohere.py
@@ -5,7 +5,6 @@
 from itertools import chain
 from typing import Literal, Union, cast
 
-from cohere import TextAssistantMessageContentItem
 from httpx import AsyncClient as AsyncHTTPClient
 from typing_extensions import assert_never
 
@@ -17,6 +16,7 @@
     ModelResponse,
     ModelResponsePart,
     RetryPromptPart,
+    StructuredOutputPart,
     SystemPromptPart,
     TextPart,
     ToolCallPart,
@@ -37,7 +37,11 @@
     AsyncClientV2,
     ChatMessageV2,
     ChatResponse,
+    JsonObjectResponseFormatV2,
+    ResponseFormatV2,
     SystemChatMessageV2,
+    TextAssistantMessageContentItem,
+    TextResponseFormatV2,
     ToolCallV2,
     ToolCallV2Function,
     ToolChatMessageV2,
@@ -152,7 +156,30 @@ async def _chat(
         model_settings: CohereModelSettings,
         model_request_parameters: ModelRequestParameters,
     ) -> ChatResponse:
-        tools = self._get_tools(model_request_parameters)
+        if model_settings.get('force_response_format', False):
+            tools: list[ToolV2] = OMIT
+            response_format: ResponseFormatV2
+            if (n_result_tools := len(model_request_parameters.result_tools)) == 0:
+                response_format = TextResponseFormatV2()
+            elif n_result_tools == 1 and not model_request_parameters.allow_text_result:
+                result_tool = model_request_parameters.result_tools[0]
+                response_format = JsonObjectResponseFormatV2(
+                    type='json_object',
+                    json_schema=result_tool.parameters_json_schema,
+                )
+            else:
+                json_schemas = [t.parameters_json_schema for t in model_request_parameters.result_tools]
+                if model_request_parameters.allow_text_result:
+                    json_schemas.append({'type': 'string'})
+                response_format = JsonObjectResponseFormatV2(
+                    type='json_object',
+                    json_schema={'anyOf': json_schemas},
+                )
+        else:
+            tools = self._get_tools(model_request_parameters)
+            response_format = OMIT
+
         cohere_messages = list(chain(*(self._map_message(m) for m in messages)))
         return await self.client.chat(
             model=self._model_name,
@@ -162,6 +189,7 @@ async def _chat(
             temperature=model_settings.get('temperature', OMIT),
             p=model_settings.get('top_p', OMIT),
             seed=model_settings.get('seed', OMIT),
+            response_format=response_format,
             presence_penalty=model_settings.get('presence_penalty', OMIT),
             frequency_penalty=model_settings.get('frequency_penalty', OMIT),
         )
@@ -193,7 +221,7 @@ def _map_message(self, message: ModelMessage) -> Iterable[ChatMessageV2]:
         texts: list[str] = []
         tool_calls: list[ToolCallV2] = []
         for item in message.parts:
-            if isinstance(item, TextPart):
+            if isinstance(item, (TextPart, StructuredOutputPart)):
                 texts.append(item.content)
             elif isinstance(item, ToolCallPart):
                 tool_calls.append(self._map_tool_call(item))
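
For the multi-tool branch above, the schema handed to Cohere is an `anyOf` over the result tools' parameter schemas, with a bare string schema appended when plain-text results are allowed. Roughly, with illustrative member schemas:

# two result tools plus allow_text_result=True would produce:
json_schema = {
    'anyOf': [
        {'type': 'object', 'properties': {'row': {'type': 'integer'}, 'seat': {'type': 'string'}}},
        {'type': 'object', 'properties': {'reason': {'type': 'string'}}},
        {'type': 'string'},  # only present when allow_text_result is True
    ]
}
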
6 changes: 3 additions & 3 deletions pydantic_ai_slim/pydantic_ai/models/function.py
@@ -109,9 +109,9 @@ async def request_stream(
             model_settings,
         )
 
-        assert self.stream_function is not None, (
-            'FunctionModel must receive a `stream_function` to support streamed requests'
-        )
+        assert (
+            self.stream_function is not None
+        ), 'FunctionModel must receive a `stream_function` to support streamed requests'
 
         response_stream = PeekableAsyncStream(self.stream_function(messages, agent_info))
 
31 changes: 24 additions & 7 deletions pydantic_ai_slim/pydantic_ai/models/gemini.py
@@ -183,26 +183,40 @@ async def _make_request(
         model_settings: GeminiModelSettings,
         model_request_parameters: ModelRequestParameters,
     ) -> AsyncIterator[HTTPResponse]:
-        tools = self._get_tools(model_request_parameters)
-        tool_config = self._get_tool_config(model_request_parameters, tools)
         sys_prompt_parts, contents = self._message_to_gemini_content(messages)
 
         request_data = _GeminiRequest(contents=contents)
         if sys_prompt_parts:
             request_data['system_instruction'] = _GeminiTextContent(role='user', parts=sys_prompt_parts)
-        if tools is not None:
-            request_data['tools'] = tools
-        if tool_config is not None:
-            request_data['tool_config'] = tool_config
 
         generation_config: _GeminiGenerationConfig = {}
+        if model_settings.get('force_response_format', False):
+            if (n_result_tools := len(model_request_parameters.result_tools)) == 0:
+                generation_config['response_mime_type'] = 'text/plain'
+            elif n_result_tools == 1 and not model_request_parameters.allow_text_result:
+                generation_config['response_mime_type'] = 'application/json'
+                generation_config['response_schema'] = model_request_parameters.result_tools[0].parameters_json_schema
+            else:
+                json_schemas = [t.parameters_json_schema for t in model_request_parameters.result_tools]
+                if model_request_parameters.allow_text_result:
+                    json_schemas.append({'type': 'string'})
+                generation_config['response_mime_type'] = 'application/json'
+                generation_config['response_schema'] = {'anyOf': json_schemas}
+        else:
+            tools = self._get_tools(model_request_parameters)
+            tool_config = self._get_tool_config(model_request_parameters, tools)
+            if tools is not None:
+                request_data['tools'] = tools
+            if tool_config is not None:
+                request_data['tool_config'] = tool_config
 
         if model_settings:
             if (max_tokens := model_settings.get('max_tokens')) is not None:
                 generation_config['max_output_tokens'] = max_tokens
             if (temperature := model_settings.get('temperature')) is not None:
                 generation_config['temperature'] = temperature
             if (top_p := model_settings.get('top_p')) is not None:
                 generation_config['top_p'] = top_p
             if (seed := model_settings.get('seed')) is not None:
                 generation_config['seed'] = seed
             if (presence_penalty := model_settings.get('presence_penalty')) is not None:
                 generation_config['presence_penalty'] = presence_penalty
             if (frequency_penalty := model_settings.get('frequency_penalty')) is not None:
@@ -465,9 +479,12 @@ class _GeminiGenerationConfig(TypedDict, total=False):
     See <https://ai.google.dev/api/generate-content#generationconfig> for API docs.
     """
 
+    response_mime_type: Literal['text/plain', 'application/json']
+    response_schema: dict[str, Any]
     max_output_tokens: int
     temperature: float
     top_p: float
     seed: int
     presence_penalty: float
     frequency_penalty: float
 
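On the wire the snake_case keys are rendered in camelCase for the Gemini REST API, so a single-result-tool request would carry a generation config along these lines; the schema itself is illustrative:

generation_config = {
    'responseMimeType': 'application/json',
    'responseSchema': {
        'type': 'object',
        'properties': {'row': {'type': 'integer'}, 'seat': {'type': 'string'}},
        'required': ['row', 'seat'],
    },
}
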
1 change: 1 addition & 0 deletions pydantic_ai_slim/pydantic_ai/models/groq.py
@@ -202,6 +202,7 @@ async def _completions_create(
             tools=tools or NOT_GIVEN,
             tool_choice=tool_choice or NOT_GIVEN,
             stream=stream,
+            response_format=response_format,
             max_tokens=model_settings.get('max_tokens', NOT_GIVEN),
             temperature=model_settings.get('temperature', NOT_GIVEN),
             top_p=model_settings.get('top_p', NOT_GIVEN),
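
The construction of `response_format` for Groq falls in a portion of the diff that is not rendered here; since Groq's chat API is OpenAI-compatible, it would plausibly be JSON-object mode:

response_format = {'type': 'json_object'}  # hypothetical; the actual value is not shown in this diff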