Skip to content

Commit

Permalink
test structured responses
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelcolvin committed Nov 14, 2024
1 parent 862d890 commit 5010820
Showing 1 changed file with 26 additions and 1 deletion.
27 changes: 26 additions & 1 deletion tests/test_live.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import httpx
import pytest
from pydantic import BaseModel

from pydantic_ai import Agent
from pydantic_ai.models.gemini import GeminiModel
Expand Down Expand Up @@ -45,6 +46,20 @@ async def test_openai_stream(http_client: httpx.AsyncClient):
assert cost.total_tokens is not None and cost.total_tokens > 0


class MyModel(BaseModel):
city: str


async def test_openai_structured(http_client: httpx.AsyncClient):
agent = Agent(OpenAIModel('gpt-4o-mini', http_client=http_client), result_type=MyModel)
result = await agent.run('What is the capital of the UK?')
print('OpenAI structured response:', result.data)
assert result.data.city.lower() == 'london'
print('OpenAI structured cost:', result.cost())
cost = result.cost()
assert cost.total_tokens is not None and cost.total_tokens > 0


async def test_gemini(http_client: httpx.AsyncClient):
agent = Agent(GeminiModel('gemini-1.5-flash', http_client=http_client))
result = await agent.run('What is the capital of France?')
Expand All @@ -56,11 +71,21 @@ async def test_gemini(http_client: httpx.AsyncClient):


async def test_gemini_stream(http_client: httpx.AsyncClient):
agent = Agent(GeminiModel('gemini-1.5-flash', http_client=http_client))
agent = Agent(GeminiModel('gemini-1.5-pro', http_client=http_client))
async with agent.run_stream('What is the capital of France?') as result:
data = await result.get_data()
print('Gemini stream response:', data)
assert 'paris' in data.lower()
print('Gemini stream cost:', result.cost())
cost = result.cost()
assert cost.total_tokens is not None and cost.total_tokens > 0


async def test_gemini_structured(http_client: httpx.AsyncClient):
agent = Agent(GeminiModel('gemini-1.5-pro', http_client=http_client), result_type=MyModel)
result = await agent.run('What is the capital of the UK?')
print('Gemini structured response:', result.data)
assert result.data.city.lower() == 'london'
print('Gemini structured cost:', result.cost())
cost = result.cost()
assert cost.total_tokens is not None and cost.total_tokens > 0

0 comments on commit 5010820

Please sign in to comment.