Skip to content

Commit

Permalink
Gemini coverage and other tweaks (#42)
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelcolvin authored Nov 12, 2024
1 parent a8edf02 commit d3d8eef
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 26 deletions.
3 changes: 2 additions & 1 deletion pydantic_ai/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,7 @@ async def _get_agent_model(
if model is not None:
custom_model = model_ = models.infer_model(model)
elif self.model is not None:
# noinspection PyTypeChecker
model_ = self.model
custom_model = None
else:
Expand Down Expand Up @@ -476,7 +477,7 @@ async def _handle_streamed_model_response(
) -> _MarkFinalResult[models.EitherStreamedResponse] | list[_messages.Message]:
"""Process a streamed response from the model.
TODO: the response type change to `models.EitherStreamedResponse | list[_messages.Message]` once we drop 3.9
TODO: change the response type to `models.EitherStreamedResponse | list[_messages.Message]` once we drop 3.9
(with 3.9 we get `TypeError: Subscripted generics cannot be used with class and instance checks`)
Returns:
Expand Down
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ dev = [
"pyright>=1.1.388",
"pytest>=8.3.3",
"pytest-pretty>=1.2.0",
"inline-snapshot>=0.13.3",
"inline-snapshot>=0.14",
"ruff>=0.6.9",
"coverage[toml]>=7.6.2",
"devtools>=0.12.2",
Expand Down Expand Up @@ -169,3 +169,6 @@ exclude_lines = [

[tool.logfire]
ignore_no_config = true

[tool.inline-snapshot.shortcuts]
fix=["create", "fix"]
41 changes: 41 additions & 0 deletions tests/models/test_gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,17 @@ async def test_text_success(get_gemini_client: GetGeminiClient):
)
assert result.cost() == snapshot(Cost(request_tokens=1, response_tokens=2, total_tokens=3))

result = await agent.run('Hello', message_history=result.new_messages())
assert result.data == 'Hello world'
assert result.all_messages() == snapshot(
[
UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)),
ModelTextResponse(content='Hello world', timestamp=IsNow(tz=timezone.utc)),
UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)),
ModelTextResponse(content='Hello world', timestamp=IsNow(tz=timezone.utc)),
]
)


async def test_request_structured_response(get_gemini_client: GetGeminiClient):
response = gemini_response(
Expand Down Expand Up @@ -646,3 +657,33 @@ async def bar(y: str) -> str:
]
)
assert retriever_calls == snapshot(["foo(x='a')", "bar(y='b')"])


async def test_stream_text_heterogeneous(get_gemini_client: GetGeminiClient):
responses = [
gemini_response(_content_model_text('Hello ')),
gemini_response(
_GeminiContent(
role='model',
parts=[
_GeminiTextPart(text='foo'),
_function_call_part_from_call(
ToolCall(
tool_name='get_location',
args=ArgsObject(args_object={'loc_name': 'San Fransisco'}),
)
),
],
)
),
]
json_data = _gemini_streamed_response_ta.dump_json(responses, by_alias=True)
stream = AsyncByteStreamList([json_data[:100], json_data[100:200], json_data[200:]])
gemini_client = get_gemini_client(stream)
m = GeminiModel('gemini-1.5-flash', http_client=gemini_client)
agent = Agent(m)

msg = 'Streamed response with unexpected content, expected all parts to be text'
async with agent.run_stream('Hello') as result:
with pytest.raises(UnexpectedModelBehaviour, match=msg):
await result.get_data()
29 changes: 5 additions & 24 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit d3d8eef

Please sign in to comment.