Skip to content

Commit

Permalink
Fix more tests
Browse files — browse the repository at this point in the history
  • Loading branch information
Kludex committed Feb 21, 2025
1 parent 1dfd9b1 commit 10d6b08
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 41 deletions.
11 changes: 6 additions & 5 deletions examples/pydantic_ai_examples/chat_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,12 @@ def to_chat_message(m: ModelMessage) -> ChatMessage:
first_part = m.parts[0]
if isinstance(m, ModelRequest):
if isinstance(first_part, UserPromptPart):
return {
'role': 'user',
'timestamp': first_part.timestamp.isoformat(),
'content': first_part.content,
}
if isinstance(first_part.content, str):
return {
'role': 'user',
'timestamp': first_part.timestamp.isoformat(),
'content': first_part.content,
}
elif isinstance(m, ModelResponse):
if isinstance(first_part, TextPart):
return {
Expand Down
18 changes: 10 additions & 8 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
import requests
from pathlib import Path

from rich.pretty import pprint

from pydantic_ai import Agent
from pydantic_ai.messages import BinaryContent
from pydantic_ai.messages import BinaryContent, ImageUrl

image_path = 'https://goo.gle/instrument-img'
image = requests.get(image_path)
image_url = 'https://goo.gle/instrument-img'

agent = Agent(model='google-gla:gemini-2.0-flash-exp')

# data = Path('docs/img/logfire-with-httpx.png').read_bytes()
# data2 = Path('docs/img/tree.png').read_bytes()

output = agent.run_sync(["What's in the image?", BinaryContent(data=image.content, media_type='image/jpeg')])
output = agent.run_sync(
[
"What's in the image?",
ImageUrl(url=image_url),
]
)
pprint(output)
9 changes: 5 additions & 4 deletions pydantic_ai_slim/pydantic_ai/models/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,10 +270,11 @@ async def _process_streamed_response(self, http_response: HTTPResponse) -> Strea
)
if responses:
last = responses[-1]
last_content = last['candidates'][0].get('content')
if last['candidates'] and last_content and last_content.parts:
start_response = last
break
if last['candidates']:
last_content = last['candidates'][0].get('content')
if last_content and last_content.parts:
start_response = last
break

if start_response is None:
raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')
Expand Down
37 changes: 15 additions & 22 deletions tests/models/test_gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import httpx
import pytest
from google.genai.types import Content, FunctionCall, Part
from inline_snapshot import snapshot
from pydantic import BaseModel, Field
from typing_extensions import Literal, TypeAlias
Expand All @@ -30,16 +31,13 @@
GeminiModel,
GeminiModelSettings,
_content_model_response,
_function_call_part_from_call,
_gemini_response_ta,
_gemini_streamed_response_ta,
_GeminiCandidates,
_GeminiContent,
_GeminiFunction,
_GeminiFunctionCallingConfig,
_GeminiResponse,
_GeminiSafetyRating,
_GeminiTextPart,
_GeminiToolConfig,
_GeminiTools,
_GeminiUsageMetaData,
Expand Down Expand Up @@ -436,7 +434,7 @@ def handler(request: httpx.Request) -> httpx.Response:
return create_client


def gemini_response(content: _GeminiContent, finish_reason: Literal['STOP'] | None = 'STOP') -> _GeminiResponse:
def gemini_response(content: Content, finish_reason: Literal['STOP'] | None = 'STOP') -> _GeminiResponse:
candidate = _GeminiCandidates(content=content, index=0, safety_ratings=[])
if finish_reason: # pragma: no cover
candidate['finish_reason'] = finish_reason
Expand Down Expand Up @@ -796,16 +794,11 @@ async def test_stream_text_heterogeneous(get_gemini_client: GetGeminiClient):
responses = [
gemini_response(_content_model_response(ModelResponse(parts=[TextPart('Hello ')]))),
gemini_response(
_GeminiContent(
Content(
role='model',
parts=[
_GeminiTextPart(text='foo'),
_function_call_part_from_call(
ToolCallPart(
tool_name='get_location',
args={'loc_name': 'San Fransisco'},
)
),
Part.from_text(text='foo'),
Part.from_function_call(name='get_location', args={'loc_name': 'San Fransisco'}),
],
)
),
Expand Down Expand Up @@ -833,13 +826,13 @@ async def test_empty_text_ignored():
)
# text included
assert content == snapshot(
{
'role': 'model',
'parts': [
{'function_call': {'name': 'final_result', 'args': {'response': [1, 2, 123]}}},
{'text': 'xxx'},
Content(
parts=[
Part(function_call=FunctionCall(args={'response': [1, 2, 123]}, name='final_result')),
Part(text='xxx'),
],
}
role='model',
)
)

content = _content_model_response(
Expand All @@ -852,10 +845,10 @@ async def test_empty_text_ignored():
)
# text skipped
assert content == snapshot(
{
'role': 'model',
'parts': [{'function_call': {'name': 'final_result', 'args': {'response': [1, 2, 123]}}}],
}
Content(
parts=[Part(function_call=FunctionCall(args={'response': [1, 2, 123]}, name='final_result'))],
role='model',
)
)


Expand Down
11 changes: 9 additions & 2 deletions tests/models/test_model_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,14 @@ async def weather_model(messages: list[ModelMessage], info: AgentInfo) -> ModelR
elif last.tool_name == 'get_weather':
location_name: str | None = None
for m in messages:
location_name = next((part.content for part in m.parts if isinstance(part, UserPromptPart)), None)
location_name = next(
(
item
for item in (part.content for part in m.parts if isinstance(part, UserPromptPart))
if isinstance(item, str)
),
None,
)
if location_name is not None:
break

Expand Down Expand Up @@ -189,7 +196,7 @@ def test_weather():
async def call_function_model(messages: list[ModelMessage], _: AgentInfo) -> ModelResponse: # pragma: no cover
last = messages[-1].parts[-1]
if isinstance(last, UserPromptPart):
if last.content.startswith('{'):
if isinstance(last.content, str) and last.content.startswith('{'):
details = json.loads(last.content)
return ModelResponse(
parts=[
Expand Down

0 comments on commit 10d6b08

Please sign in to comment.