diff --git a/examples/pydantic_ai_examples/chat_app.py b/examples/pydantic_ai_examples/chat_app.py
index 88e136b16..93e7e306c 100644
--- a/examples/pydantic_ai_examples/chat_app.py
+++ b/examples/pydantic_ai_examples/chat_app.py
@@ -89,11 +89,12 @@ def to_chat_message(m: ModelMessage) -> ChatMessage:
     first_part = m.parts[0]
     if isinstance(m, ModelRequest):
         if isinstance(first_part, UserPromptPart):
-            return {
-                'role': 'user',
-                'timestamp': first_part.timestamp.isoformat(),
-                'content': first_part.content,
-            }
+            if isinstance(first_part.content, str):
+                return {
+                    'role': 'user',
+                    'timestamp': first_part.timestamp.isoformat(),
+                    'content': first_part.content,
+                }
     elif isinstance(m, ModelResponse):
         if isinstance(first_part, TextPart):
             return {
diff --git a/main.py b/main.py
index 1061169d9..351fbc72e 100644
--- a/main.py
+++ b/main.py
@@ -1,16 +1,18 @@
-import requests
+from pathlib import Path
+
 from rich.pretty import pprint
 
 from pydantic_ai import Agent
-from pydantic_ai.messages import BinaryContent
+from pydantic_ai.messages import BinaryContent, ImageUrl
 
-image_path = 'https://goo.gle/instrument-img'
-image = requests.get(image_path)
+image_url = 'https://goo.gle/instrument-img'
 
 agent = Agent(model='google-gla:gemini-2.0-flash-exp')
 
-# data = Path('docs/img/logfire-with-httpx.png').read_bytes()
-# data2 = Path('docs/img/tree.png').read_bytes()
-
-output = agent.run_sync(["What's in the image?", BinaryContent(data=image.content, media_type='image/jpeg')])
+output = agent.run_sync(
+    [
+        "What's in the image?",
+        ImageUrl(url=image_url),
+    ]
+)
 pprint(output)
diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py
index cacf6be72..80d088557 100644
--- a/pydantic_ai_slim/pydantic_ai/models/gemini.py
+++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py
@@ -270,10 +270,11 @@ async def _process_streamed_response(self, http_response: HTTPResponse) -> Strea
             )
             if responses:
                 last = responses[-1]
-                last_content = last['candidates'][0].get('content')
-                if last['candidates'] and last_content and last_content.parts:
-                    start_response = last
-                    break
+                if last['candidates']:
+                    last_content = last['candidates'][0].get('content')
+                    if last_content and last_content.parts:
+                        start_response = last
+                        break
 
         if start_response is None:
             raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')
diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py
index aa3260be0..f5a1cce74 100644
--- a/tests/models/test_gemini.py
+++ b/tests/models/test_gemini.py
@@ -9,6 +9,7 @@
 
 import httpx
 import pytest
+from google.genai.types import Content, FunctionCall, Part
 from inline_snapshot import snapshot
 from pydantic import BaseModel, Field
 from typing_extensions import Literal, TypeAlias
@@ -30,16 +31,13 @@
     GeminiModel,
     GeminiModelSettings,
     _content_model_response,
-    _function_call_part_from_call,
     _gemini_response_ta,
     _gemini_streamed_response_ta,
     _GeminiCandidates,
-    _GeminiContent,
     _GeminiFunction,
     _GeminiFunctionCallingConfig,
     _GeminiResponse,
     _GeminiSafetyRating,
-    _GeminiTextPart,
     _GeminiToolConfig,
     _GeminiTools,
     _GeminiUsageMetaData,
@@ -436,7 +434,7 @@ def handler(request: httpx.Request) -> httpx.Response:
     return create_client
 
 
-def gemini_response(content: _GeminiContent, finish_reason: Literal['STOP'] | None = 'STOP') -> _GeminiResponse:
+def gemini_response(content: Content, finish_reason: Literal['STOP'] | None = 'STOP') -> _GeminiResponse:
     candidate = _GeminiCandidates(content=content, index=0, safety_ratings=[])
     if finish_reason:  # pragma: no cover
         candidate['finish_reason'] = finish_reason
@@ -796,16 +794,11 @@ async def test_stream_text_heterogeneous(get_gemini_client: GetGeminiClient):
     responses = [
         gemini_response(_content_model_response(ModelResponse(parts=[TextPart('Hello ')]))),
         gemini_response(
-            _GeminiContent(
+            Content(
                 role='model',
                 parts=[
-                    _GeminiTextPart(text='foo'),
-                    _function_call_part_from_call(
-                        ToolCallPart(
-                            tool_name='get_location',
-                            args={'loc_name': 'San Fransisco'},
-                        )
-                    ),
+                    Part.from_text(text='foo'),
+                    Part.from_function_call(name='get_location', args={'loc_name': 'San Fransisco'}),
                 ],
             )
         ),
@@ -833,13 +826,13 @@ async def test_empty_text_ignored():
     )
     # text included
     assert content == snapshot(
-        {
-            'role': 'model',
-            'parts': [
-                {'function_call': {'name': 'final_result', 'args': {'response': [1, 2, 123]}}},
-                {'text': 'xxx'},
+        Content(
+            parts=[
+                Part(function_call=FunctionCall(args={'response': [1, 2, 123]}, name='final_result')),
+                Part(text='xxx'),
             ],
-        }
+            role='model',
+        )
     )
 
     content = _content_model_response(
@@ -852,10 +845,10 @@ async def test_empty_text_ignored():
     )
     # text skipped
     assert content == snapshot(
-        {
-            'role': 'model',
-            'parts': [{'function_call': {'name': 'final_result', 'args': {'response': [1, 2, 123]}}}],
-        }
+        Content(
+            parts=[Part(function_call=FunctionCall(args={'response': [1, 2, 123]}, name='final_result'))],
+            role='model',
+        )
     )
 
 
diff --git a/tests/models/test_model_function.py b/tests/models/test_model_function.py
index f30a5c14b..8a2e56102 100644
--- a/tests/models/test_model_function.py
+++ b/tests/models/test_model_function.py
@@ -112,7 +112,14 @@ async def weather_model(messages: list[ModelMessage], info: AgentInfo) -> ModelR
     elif last.tool_name == 'get_weather':
         location_name: str | None = None
         for m in messages:
-            location_name = next((part.content for part in m.parts if isinstance(part, UserPromptPart)), None)
+            location_name = next(
+                (
+                    item
+                    for item in (part.content for part in m.parts if isinstance(part, UserPromptPart))
+                    if isinstance(item, str)
+                ),
+                None,
+            )
             if location_name is not None:
                 break
 
@@ -189,7 +196,7 @@ def test_weather():
 async def call_function_model(messages: list[ModelMessage], _: AgentInfo) -> ModelResponse:  # pragma: no cover
     last = messages[-1].parts[-1]
     if isinstance(last, UserPromptPart):
-        if last.content.startswith('{'):
+        if isinstance(last.content, str) and last.content.startswith('{'):
            details = json.loads(last.content)
            return ModelResponse(
                parts=[
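
Note (not part of the diff): the recurring `isinstance(..., str)` guards above exist because `UserPromptPart.content` is not always a plain string: with multi-modal input it can be a sequence mixing `str` with items such as `ImageUrl` or `BinaryContent`. A minimal sketch of the narrowing pattern, using the hypothetical helper name `first_text` (not from this PR):

    from __future__ import annotations

    from pydantic_ai.messages import UserPromptPart

    def first_text(part: UserPromptPart) -> str | None:
        # `content` may be a bare string...
        if isinstance(part.content, str):
            return part.content
        # ...or a sequence of mixed items; keep only the first plain string,
        # mirroring the generator rewritten in weather_model above.
        return next((item for item in part.content if isinstance(item, str)), None)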