Skip to content

Commit

Permalink
testing the OpenAI model (#10)
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelcolvin authored Oct 19, 2024
1 parent 17bd167 commit d73c40c
Show file tree
Hide file tree
Showing 4 changed files with 233 additions and 39 deletions.
10 changes: 4 additions & 6 deletions pydantic_ai/models/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
from collections.abc import Mapping
from dataclasses import dataclass
from datetime import datetime
from typing import Literal, assert_never
from typing import Literal

from httpx import AsyncClient as AsyncHTTPClient
from openai import AsyncOpenAI
from openai.types import ChatModel, chat
from typing_extensions import assert_never

from .. import shared
from ..messages import (
ArgsJson,
LLMMessage,
Expand All @@ -36,8 +36,6 @@ def __init__(
openai_client: AsyncOpenAI | None = None,
http_client: AsyncHTTPClient | None = None,
):
if model_name not in ChatModel.__args__:
raise shared.UserError(f'Invalid model name: {model_name}')
self.model_name: ChatModel = model_name
if openai_client is not None:
assert http_client is None, 'Cannot provide both `openai_client` and `http_client`'
Expand Down Expand Up @@ -144,7 +142,7 @@ def map_message(message: Message) -> chat.ChatCompletionMessageParam:
# LLMToolCalls ->
return chat.ChatCompletionAssistantMessageParam(
role='assistant',
tool_calls=[_guard_tool_call(t) for t in message.calls],
tool_calls=[_map_tool_call(t) for t in message.calls],
)
elif message.role == 'plain-response-forbidden':
# PlainResponseForbidden ->
Expand All @@ -162,7 +160,7 @@ def _guard_tool_id(t: ToolCall | ToolReturn | ToolRetry) -> str:
return t.tool_id


def _guard_tool_call(t: ToolCall) -> chat.ChatCompletionMessageToolCallParam:
def _map_tool_call(t: ToolCall) -> chat.ChatCompletionMessageToolCallParam:
assert isinstance(t.args, ArgsJson), f'Expected ArgsJson, got {t.args}'
return chat.ChatCompletionMessageToolCallParam(
id=_guard_tool_id(t),
Expand Down
24 changes: 23 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import os
from datetime import datetime
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Callable

import httpx
import pytest
from typing_extensions import TypeAlias

__all__ = 'IsNow', 'TestEnv'

Expand Down Expand Up @@ -44,3 +46,23 @@ def env():
@pytest.fixture
def anyio_backend():
    """Provide the backend name consumed by the anyio pytest plugin: always asyncio."""
    backend_name = 'asyncio'
    return backend_name


@pytest.fixture
async def client_with_handler():
    """Yield a factory that builds an `httpx.AsyncClient` served by a mock transport.

    The yielded `create_client(handler)` mounts `httpx.MockTransport(handler)` on
    `'all://'` and returns the resulting client. It may be called at most once per
    test (a second call trips the assertion); the client it created, if any, is
    closed when the fixture is torn down.
    """
    client: httpx.AsyncClient | None = None

    def create_client(handler: Callable[[httpx.Request], httpx.Response]) -> httpx.AsyncClient:
        nonlocal client
        assert client is None, 'client_with_handler can only be called once'
        client = httpx.AsyncClient(mounts={'all://': httpx.MockTransport(handler)})
        return client

    try:
        yield create_client
    finally:
        # Explicit None check: "was a client created?" should not depend on the
        # truthiness of the client object itself.
        if client is not None:  # pragma: no cover
            await client.aclose()


# Type of the factory yielded by the `client_with_handler` fixture: it takes a request handler
# (`httpx.Request` -> `httpx.Response`) and returns an `httpx.AsyncClient`.
ClientWithHandler: TypeAlias = Callable[[Callable[[httpx.Request], httpx.Response]], httpx.AsyncClient]
46 changes: 14 additions & 32 deletions tests/models/test_gemini.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations as _annotations

import json
from collections.abc import Awaitable, Callable
from collections.abc import Callable
from dataclasses import dataclass

import httpx
Expand Down Expand Up @@ -36,7 +36,7 @@
_GeminiTools, # pyright: ignore[reportPrivateUsage]
_GeminiUsageMetaData, # pyright: ignore[reportPrivateUsage]
)
from tests.conftest import IsNow, TestEnv
from tests.conftest import ClientWithHandler, IsNow, TestEnv

pytestmark = pytest.mark.anyio

Expand Down Expand Up @@ -317,30 +317,10 @@ class Location(BaseModel):


@pytest.fixture
async def get_gemini_client_handler(env: TestEnv):
async def get_gemini_client(client_with_handler: ClientWithHandler, env: TestEnv):
env.set('GEMINI_API_KEY', 'via-env-var')

client: httpx.AsyncClient | None = None

async def create_client(handler: Callable[[httpx.Request], httpx.Response]) -> httpx.AsyncClient:
nonlocal client
assert client is None, 'get_gemini_client can only be called once'
client = httpx.AsyncClient(mounts={'all://': httpx.MockTransport(handler)})
return client

try:
yield create_client
finally:
if client: # pragma: no cover
await client.aclose()


GetGeminiClientHandler: TypeAlias = Callable[[Callable[[httpx.Request], httpx.Response]], Awaitable[httpx.AsyncClient]]


@pytest.fixture
async def get_gemini_client(get_gemini_client_handler: GetGeminiClientHandler):
async def create_client(response_data: _GeminiResponse | list[_GeminiResponse]) -> httpx.AsyncClient:
def create_client(response_data: _GeminiResponse | list[_GeminiResponse]) -> httpx.AsyncClient:
index = 0

def handler(_request: httpx.Request) -> httpx.Response:
Expand All @@ -355,12 +335,12 @@ def handler(_request: httpx.Request) -> httpx.Response:
content = _gemini_response_ta.dump_json(r, by_alias=True)
return httpx.Response(200, content=content, headers={'Content-Type': 'application/json'})

return await get_gemini_client_handler(handler)
return client_with_handler(handler)

return create_client


GetGeminiClient: TypeAlias = 'Callable[[_GeminiResponse | list[_GeminiResponse]], Awaitable[httpx.AsyncClient]]'
GetGeminiClient: TypeAlias = 'Callable[[_GeminiResponse | list[_GeminiResponse]], httpx.AsyncClient]'


def gemini_response(content: _GeminiContent) -> _GeminiResponse:
Expand All @@ -372,7 +352,7 @@ def gemini_response(content: _GeminiContent) -> _GeminiResponse:

async def test_request_simple_success(get_gemini_client: GetGeminiClient):
response = gemini_response(_GeminiContent.model_text('Hello world'))
gemini_client = await get_gemini_client(response)
gemini_client = get_gemini_client(response)
m = GeminiModel('gemini-1.5-flash', http_client=gemini_client)
agent = Agent(m, deps=None)

Expand All @@ -386,7 +366,7 @@ async def test_request_structured_response(get_gemini_client: GetGeminiClient):
LLMToolCalls(calls=[ToolCall.from_object('final_result', {'response': [1, 2, 123]})])
)
)
gemini_client = await get_gemini_client(response)
gemini_client = get_gemini_client(response)
m = GeminiModel('gemini-1.5-flash', http_client=gemini_client)
agent = Agent(m, deps=None, result_type=list[int])

Expand Down Expand Up @@ -422,7 +402,7 @@ async def test_request_tool_call(get_gemini_client: GetGeminiClient):
),
gemini_response(_GeminiContent.model_text('final response')),
]
gemini_client = await get_gemini_client(responses)
gemini_client = get_gemini_client(responses)
m = GeminiModel('gemini-1.5-flash', http_client=gemini_client)
agent = Agent(m, deps=None, system_prompt='this is the system prompt')

Expand Down Expand Up @@ -464,11 +444,13 @@ async def get_location(loc_name: str) -> str:
)


async def test_unexpected_response(get_gemini_client_handler: GetGeminiClientHandler):
async def test_unexpected_response(client_with_handler: ClientWithHandler, env: TestEnv):
env.set('GEMINI_API_KEY', 'via-env-var')

def handler(_: httpx.Request):
return httpx.Response(401, content='invalid request')

gemini_client = await get_gemini_client_handler(handler)
gemini_client = client_with_handler(handler)
m = GeminiModel('gemini-1.5-flash', http_client=gemini_client)
agent = Agent(m, deps=None, system_prompt='this is the system prompt')

Expand Down Expand Up @@ -500,7 +482,7 @@ async def test_heterogeneous_responses(get_gemini_client: GetGeminiClient):
],
)
)
gemini_client = await get_gemini_client(response)
gemini_client = get_gemini_client(response)
m = GeminiModel('gemini-1.5-flash', http_client=gemini_client)
agent = Agent(m, deps=None)
with pytest.raises(AgentError, match='Error while running model gemini-1.5-flash') as exc_info:
Expand Down
192 changes: 192 additions & 0 deletions tests/models/test_openai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
from __future__ import annotations as _annotations

import datetime
import json
from typing import Any, cast

import pytest
from inline_snapshot import snapshot
from openai import AsyncOpenAI
from openai.types import chat
from openai.types.chat.chat_completion import ChatCompletionMessage, Choice # pyright: ignore[reportPrivateImportUsage]
from openai.types.chat.chat_completion_message_tool_call import Function

from pydantic_ai import Agent, ModelRetry
from pydantic_ai.messages import (
ArgsJson,
LLMResponse,
LLMToolCalls,
SystemPrompt,
ToolCall,
ToolRetry,
ToolReturn,
UserPrompt,
)
from pydantic_ai.models.openai import OpenAIModel
from tests.conftest import IsNow

# Module-level marker: pytest applies the `anyio` mark to every test in this file.
pytestmark = pytest.mark.anyio


def test_init():
    """The `api_key` argument reaches the client and `name()` carries the `openai:` prefix."""
    model = OpenAIModel('gpt-4', api_key='foobar')

    configured_key = model.client.api_key
    assert configured_key == 'foobar'

    reported_name = model.name()
    assert reported_name == 'openai:gpt-4'


class MockOpenAI:
    """Stand-in for `AsyncOpenAI` that replays canned chat completions.

    Only the `chat.completions.create` attribute path is provided. When constructed
    with a single completion, that completion is returned on every call; when
    constructed with a list, the entries are returned in order, one per call.
    """

    def __init__(self, completions: chat.ChatCompletion | list[chat.ChatCompletion]):
        self.completions = completions
        self.index = 0
        # Throwaway classes so that `self.chat.completions.create` resolves to our coroutine method.
        completions_cls = type('Completions', (), {'create': self.chat_completions_create})
        self.chat = type('Chat', (), {'completions': completions_cls})

    @classmethod
    def create_mock(cls, completions: chat.ChatCompletion | list[chat.ChatCompletion]) -> AsyncOpenAI:
        """Build the mock and present it to type checkers as an `AsyncOpenAI`."""
        mock = cls(completions)
        return cast(AsyncOpenAI, mock)

    async def chat_completions_create(self, *_args: Any, **_kwargs: Any) -> chat.ChatCompletion:
        """Return the next canned completion; all arguments are ignored."""
        canned = self.completions
        result = canned[self.index] if isinstance(canned, list) else canned
        # The call counter advances for both the single and the list case.
        self.index += 1
        return result


def completion_message(message: ChatCompletionMessage) -> chat.ChatCompletion:
    """Wrap `message` as the only choice of a `ChatCompletion` with fixed id, model and creation time."""
    only_choice = Choice(finish_reason='stop', index=0, message=message)
    return chat.ChatCompletion(
        id='123',
        choices=[only_choice],
        created=1704067200,  # 2024-01-01
        model='gpt-4',
        object='chat.completion',
    )


async def test_request_simple_success():
    """A plain text completion from the mocked client becomes the run's response."""
    completion = completion_message(ChatCompletionMessage(content='world', role='assistant'))
    model = OpenAIModel('gpt-4', openai_client=MockOpenAI.create_mock(completion))
    agent = Agent(model, deps=None)

    result = await agent.run('Hello')
    assert result.response == 'world'


async def test_request_structured_response():
    """A `final_result` tool call yields the `list[int]` result and is recorded in the message history."""
    final_result_call = chat.ChatCompletionMessageToolCall(
        id='123',
        function=Function(arguments='{"response": [1, 2, 123]}', name='final_result'),
        type='function',
    )
    assistant_message = ChatCompletionMessage(content=None, role='assistant', tool_calls=[final_result_call])
    openai_client = MockOpenAI.create_mock(completion_message(assistant_message))
    model = OpenAIModel('gpt-4', openai_client=openai_client)
    agent = Agent(model, deps=None, result_type=list[int])

    result = await agent.run('Hello')
    assert result.response == [1, 2, 123]
    assert result.message_history == snapshot(
        [
            UserPrompt(content='Hello', timestamp=IsNow()),
            LLMToolCalls(
                calls=[
                    ToolCall(
                        tool_name='final_result',
                        args=ArgsJson(args_json='{"response": [1, 2, 123]}'),
                        tool_id='123',
                    )
                ],
                timestamp=datetime.datetime(2024, 1, 1),
            ),
        ]
    )


async def test_request_tool_call():
    """A retried and then successful `get_location` tool call, followed by a text reply, is fully recorded."""

    def location_call(call_id: str, arguments: str) -> chat.ChatCompletion:
        # Canned completion in which the assistant calls `get_location` with the given JSON arguments.
        tool_call = chat.ChatCompletionMessageToolCall(
            id=call_id,
            function=Function(arguments=arguments, name='get_location'),
            type='function',
        )
        return completion_message(ChatCompletionMessage(content=None, role='assistant', tool_calls=[tool_call]))

    # Replayed in order: a call the tool rejects, a call it accepts, then the final text answer.
    responses = [
        location_call('1', '{"loc_name": "San Fransisco"}'),
        location_call('2', '{"loc_name": "London"}'),
        completion_message(ChatCompletionMessage(content='final response', role='assistant')),
    ]
    model = OpenAIModel('gpt-4', openai_client=MockOpenAI.create_mock(responses))
    agent = Agent(model, deps=None, system_prompt='this is the system prompt')

    # Only 'London' is accepted; any other location asks the model to retry.
    @agent.retriever_plain
    async def get_location(loc_name: str) -> str:
        if loc_name != 'London':
            raise ModelRetry('Wrong location, please try again')
        return json.dumps({'lat': 51, 'lng': 0})

    result = await agent.run('Hello')
    assert result.response == 'final response'
    assert result.message_history == snapshot(
        [
            SystemPrompt(content='this is the system prompt'),
            UserPrompt(content='Hello', timestamp=IsNow()),
            LLMToolCalls(
                calls=[
                    ToolCall(
                        tool_name='get_location',
                        args=ArgsJson(args_json='{"loc_name": "San Fransisco"}'),
                        tool_id='1',
                    )
                ],
                timestamp=datetime.datetime(2024, 1, 1, 0, 0),
            ),
            ToolRetry(
                tool_name='get_location', content='Wrong location, please try again', tool_id='1', timestamp=IsNow()
            ),
            LLMToolCalls(
                calls=[
                    ToolCall(
                        tool_name='get_location',
                        args=ArgsJson(args_json='{"loc_name": "London"}'),
                        tool_id='2',
                    )
                ],
                timestamp=datetime.datetime(2024, 1, 1, 0, 0),
            ),
            ToolReturn(
                tool_name='get_location',
                content='{"lat": 51, "lng": 0}',
                tool_id='2',
                timestamp=IsNow(),
            ),
            LLMResponse(content='final response', timestamp=datetime.datetime(2024, 1, 1, 0, 0)),
        ]
    )

0 comments on commit d73c40c

Please sign in to comment.