test against real models
samuelcolvin committed Nov 14, 2024
1 parent 6dc3e8d commit 862d890
Showing 4 changed files with 89 additions and 5 deletions.
26 changes: 21 additions & 5 deletions .github/workflows/ci.yml
@@ -56,6 +56,22 @@ jobs:

      - run: tree site

+  test-live:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+
+      - uses: astral-sh/setup-uv@v3
+        with:
+          enable-cache: true
+
+      - run: uv run --python 3.12 --frozen pytest tests/test_live.py
+        if: github.repository_owner == 'pydantic'
+        env:
+          PYDANTIC_AI_LIVE_TEST_DANGEROUS: 'CHARGE-ME!'
+          OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
+          GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }}
+
  test:
    name: test on ${{ matrix.python-version }}
    runs-on: ubuntu-latest
@@ -64,7 +80,7 @@ jobs:
      matrix:
        python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
    env:
-      PYTHON: ${{ matrix.python-version }}
+      UV_PYTHON: ${{ matrix.python-version }}
    steps:
      - uses: actions/checkout@v4

@@ -73,15 +89,15 @@
          enable-cache: true

      - run: mkdir coverage
-      - run: uv run --frozen --python ${{ matrix.python-version }} coverage run -m pytest
+      - run: uv run --frozen coverage run -m pytest
        env:
          COVERAGE_FILE: coverage/.coverage.${{ runner.os }}-py${{ matrix.python-version }}

-      - run: uv run --frozen --all-extras --python ${{ matrix.python-version }} coverage run -m pytest
+      - run: uv run --frozen --all-extras coverage run -m pytest
        env:
          COVERAGE_FILE: coverage/.coverage.${{ runner.os }}-py${{ matrix.python-version }}-all-extras

-      - run: uv run --frozen --all-extras --python ${{ matrix.python-version }} python tests/import_examples.py
+      - run: uv run --frozen --all-extras python tests/import_examples.py

      - name: store coverage files
        uses: actions/upload-artifact@v4
@@ -117,7 +133,7 @@ jobs:
  # https://github.com/marketplace/actions/alls-green#why used for branch protection checks
  check:
    if: always()
-    needs: [lint, docs, test, coverage]
+    needs: [lint, docs, test-live, test, coverage]
    runs-on: ubuntu-latest

    steps:
1 change: 1 addition & 0 deletions pydantic_ai/models/gemini.py
@@ -187,6 +187,7 @@ async def _make_request(self, messages: list[Message], streamed: bool) -> AsyncI

        async with self.http_client.stream('POST', url, content=request_json, headers=headers) as r:
            if r.status_code != 200:
+                await r.aread()
                raise exceptions.UnexpectedModelBehaviour(f'Unexpected response from gemini {r.status_code}', r.text)
            yield r
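The added `await r.aread()` is needed because the response comes from `http_client.stream(...)`: httpx does not load the body of a streamed response eagerly, so building the error message from `r.text` would otherwise raise `httpx.ResponseNotRead`. A minimal standalone sketch of the same pattern (the helper name and URL handling are illustrative, not part of this commit):

import httpx


async def read_error_body(client: httpx.AsyncClient, url: str) -> str:
    # With client.stream() the body is not read automatically; r.text only
    # becomes available after the response has been read.
    async with client.stream('GET', url) as r:
        if r.status_code != 200:
            await r.aread()  # load the body so r.text can be used below
            return r.text
        return ''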

1 change: 1 addition & 0 deletions pyproject.toml
@@ -146,6 +146,7 @@ filterwarnings = [
[tool.coverage.run]
# required to avoid warnings about files created by create_module fixture
include = ["pydantic_ai/**/*.py", "tests/**/*.py"]
+omit = ["tests/test_live.py"]
branch = true

# https://coverage.readthedocs.io/en/latest/config.html#report
66 changes: 66 additions & 0 deletions tests/test_live.py
@@ -0,0 +1,66 @@
"""Tests of pydantic-ai actually connecting to OpenAI and Gemini models.
WARNING: running these tests will consume your OpenAI and Gemini credits.
"""

import os

import httpx
import pytest

from pydantic_ai import Agent
from pydantic_ai.models.gemini import GeminiModel
from pydantic_ai.models.openai import OpenAIModel

pytestmark = [
pytest.mark.skipif(os.getenv('PYDANTIC_AI_LIVE_TEST_DANGEROUS') != 'CHARGE-ME!', reason='live tests disabled'),
pytest.mark.anyio,
]


@pytest.fixture
async def http_client():
async with httpx.AsyncClient(timeout=30) as client:
yield client


async def test_openai(http_client: httpx.AsyncClient):
agent = Agent(OpenAIModel('gpt-3.5-turbo', http_client=http_client))
result = await agent.run('What is the capital of France?')
print('OpenAI response:', result.data)
assert 'paris' in result.data.lower()
print('OpenAI cost:', result.cost())
cost = result.cost()
assert cost.total_tokens is not None and cost.total_tokens > 0


async def test_openai_stream(http_client: httpx.AsyncClient):
agent = Agent(OpenAIModel('gpt-3.5-turbo', http_client=http_client))
async with agent.run_stream('What is the capital of France?') as result:
data = await result.get_data()
print('OpenAI stream response:', data)
assert 'paris' in data.lower()
print('OpenAI stream cost:', result.cost())
cost = result.cost()
assert cost.total_tokens is not None and cost.total_tokens > 0


async def test_gemini(http_client: httpx.AsyncClient):
agent = Agent(GeminiModel('gemini-1.5-flash', http_client=http_client))
result = await agent.run('What is the capital of France?')
print('Gemini response:', result.data)
assert 'paris' in result.data.lower()
print('Gemini cost:', result.cost())
cost = result.cost()
assert cost.total_tokens is not None and cost.total_tokens > 0


async def test_gemini_stream(http_client: httpx.AsyncClient):
agent = Agent(GeminiModel('gemini-1.5-flash', http_client=http_client))
async with agent.run_stream('What is the capital of France?') as result:
data = await result.get_data()
print('Gemini stream response:', data)
assert 'paris' in data.lower()
print('Gemini stream cost:', result.cost())
cost = result.cost()
assert cost.total_tokens is not None and cost.total_tokens > 0
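The module-level `pytestmark` list applies both marks to every test in this file, so none of these calls run unless `PYDANTIC_AI_LIVE_TEST_DANGEROUS` is set to the exact opt-in value. A minimal sketch of the equivalent per-test decoration, purely illustrative (the test name is hypothetical, not part of this commit):

import os

import pytest


@pytest.mark.skipif(
    os.getenv('PYDANTIC_AI_LIVE_TEST_DANGEROUS') != 'CHARGE-ME!',
    reason='live tests disabled',
)
@pytest.mark.anyio
async def test_hypothetical_live_call():
    # Skipped at collection time unless the opt-in env var is set,
    # so an ordinary `pytest` run never spends API credits.
    ...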
