-
Notifications
You must be signed in to change notification settings - Fork 527
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
6dc3e8d
commit 862d890
Showing
4 changed files
with
89 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
"""Tests of pydantic-ai actually connecting to OpenAI and Gemini models. | ||
WARNING: running these tests will consume your OpenAI and Gemini credits. | ||
""" | ||
|
||
import os | ||
|
||
import httpx | ||
import pytest | ||
|
||
from pydantic_ai import Agent | ||
from pydantic_ai.models.gemini import GeminiModel | ||
from pydantic_ai.models.openai import OpenAIModel | ||
|
||
pytestmark = [ | ||
pytest.mark.skipif(os.getenv('PYDANTIC_AI_LIVE_TEST_DANGEROUS') != 'CHARGE-ME!', reason='live tests disabled'), | ||
pytest.mark.anyio, | ||
] | ||
|
||
|
||
@pytest.fixture | ||
async def http_client(): | ||
async with httpx.AsyncClient(timeout=30) as client: | ||
yield client | ||
|
||
|
||
async def test_openai(http_client: httpx.AsyncClient): | ||
agent = Agent(OpenAIModel('gpt-3.5-turbo', http_client=http_client)) | ||
result = await agent.run('What is the capital of France?') | ||
print('OpenAI response:', result.data) | ||
assert 'paris' in result.data.lower() | ||
print('OpenAI cost:', result.cost()) | ||
cost = result.cost() | ||
assert cost.total_tokens is not None and cost.total_tokens > 0 | ||
|
||
|
||
async def test_openai_stream(http_client: httpx.AsyncClient): | ||
agent = Agent(OpenAIModel('gpt-3.5-turbo', http_client=http_client)) | ||
async with agent.run_stream('What is the capital of France?') as result: | ||
data = await result.get_data() | ||
print('OpenAI stream response:', data) | ||
assert 'paris' in data.lower() | ||
print('OpenAI stream cost:', result.cost()) | ||
cost = result.cost() | ||
assert cost.total_tokens is not None and cost.total_tokens > 0 | ||
|
||
|
||
async def test_gemini(http_client: httpx.AsyncClient): | ||
agent = Agent(GeminiModel('gemini-1.5-flash', http_client=http_client)) | ||
result = await agent.run('What is the capital of France?') | ||
print('Gemini response:', result.data) | ||
assert 'paris' in result.data.lower() | ||
print('Gemini cost:', result.cost()) | ||
cost = result.cost() | ||
assert cost.total_tokens is not None and cost.total_tokens > 0 | ||
|
||
|
||
async def test_gemini_stream(http_client: httpx.AsyncClient): | ||
agent = Agent(GeminiModel('gemini-1.5-flash', http_client=http_client)) | ||
async with agent.run_stream('What is the capital of France?') as result: | ||
data = await result.get_data() | ||
print('Gemini stream response:', data) | ||
assert 'paris' in data.lower() | ||
print('Gemini stream cost:', result.cost()) | ||
cost = result.cost() | ||
assert cost.total_tokens is not None and cost.total_tokens > 0 |