From fa05579f91e6ecf9444c90fdaea0b66fb1d0dad9 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Tue, 12 Nov 2024 23:22:34 +0000 Subject: [PATCH] Gemini coverage and other tweaks --- pydantic_ai/agent.py | 3 ++- pyproject.toml | 5 ++++- tests/models/test_gemini.py | 41 +++++++++++++++++++++++++++++++++++++ uv.lock | 29 +++++--------------------- 4 files changed, 52 insertions(+), 26 deletions(-) diff --git a/pydantic_ai/agent.py b/pydantic_ai/agent.py index a5318c707..93fe5fd17 100644 --- a/pydantic_ai/agent.py +++ b/pydantic_ai/agent.py @@ -389,6 +389,7 @@ async def _get_agent_model( if model is not None: custom_model = model_ = models.infer_model(model) elif self.model is not None: + # noinspection PyTypeChecker model_ = self.model custom_model = None else: @@ -476,7 +477,7 @@ async def _handle_streamed_model_response( ) -> _MarkFinalResult[models.EitherStreamedResponse] | list[_messages.Message]: """Process a streamed response from the model. - TODO: the response type change to `models.EitherStreamedResponse | list[_messages.Message]` once we drop 3.9 + TODO: change the response type to `models.EitherStreamedResponse | list[_messages.Message]` once we drop 3.9 (with 3.9 we get `TypeError: Subscripted generics cannot be used with class and instance checks`) Returns: diff --git a/pyproject.toml b/pyproject.toml index 398726d7f..ae0c96f8d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,7 +62,7 @@ dev = [ "pyright>=1.1.388", "pytest>=8.3.3", "pytest-pretty>=1.2.0", - "inline-snapshot>=0.13.3", + "inline-snapshot>=0.14", "ruff>=0.6.9", "coverage[toml]>=7.6.2", "devtools>=0.12.2", @@ -169,3 +169,6 @@ exclude_lines = [ [tool.logfire] ignore_no_config = true + +[tool.inline-snapshot.shortcuts] +fix=["create", "fix"] diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py index 045a3436a..9f30b1c91 100644 --- a/tests/models/test_gemini.py +++ b/tests/models/test_gemini.py @@ -396,6 +396,17 @@ async def test_text_success(get_gemini_client: GetGeminiClient): ) assert result.cost() == snapshot(Cost(request_tokens=1, response_tokens=2, total_tokens=3)) + result = await agent.run('Hello', message_history=result.new_messages()) + assert result.data == 'Hello world' + assert result.all_messages() == snapshot( + [ + UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)), + ModelTextResponse(content='Hello world', timestamp=IsNow(tz=timezone.utc)), + UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)), + ModelTextResponse(content='Hello world', timestamp=IsNow(tz=timezone.utc)), + ] + ) + async def test_request_structured_response(get_gemini_client: GetGeminiClient): response = gemini_response( @@ -646,3 +657,33 @@ async def bar(y: str) -> str: ] ) assert retriever_calls == snapshot(["foo(x='a')", "bar(y='b')"]) + + +async def test_stream_text_heterogeneous(get_gemini_client: GetGeminiClient): + responses = [ + gemini_response(_content_model_text('Hello ')), + gemini_response( + _GeminiContent( + role='model', + parts=[ + _GeminiTextPart(text='foo'), + _function_call_part_from_call( + ToolCall( + tool_name='get_location', + args=ArgsObject(args_object={'loc_name': 'San Fransisco'}), + ) + ), + ], + ) + ), + ] + json_data = _gemini_streamed_response_ta.dump_json(responses, by_alias=True) + stream = AsyncByteStreamList([json_data[:100], json_data[100:200], json_data[200:]]) + gemini_client = get_gemini_client(stream) + m = GeminiModel('gemini-1.5-flash', http_client=gemini_client) + agent = Agent(m) + + msg = 'Streamed response with unexpected content, expected all parts to be text' + async with agent.run_stream('Hello') as result: + with pytest.raises(UnexpectedModelBehaviour, match=msg): + await result.get_data() diff --git a/uv.lock b/uv.lock index 3fe7b05bd..cd48588dc 100644 --- a/uv.lock +++ b/uv.lock @@ -539,7 +539,7 @@ wheels = [ [[package]] name = "inline-snapshot" -version = "0.13.3" +version = "0.14.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "asttokens" }, @@ -547,13 +547,12 @@ dependencies = [ { name = "click" }, { name = "executing" }, { name = "rich" }, - { name = "toml", marker = "python_full_version < '3.11'" }, - { name = "types-toml", marker = "python_full_version < '3.11'" }, + { name = "tomli", marker = "python_full_version < '3.11'" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/03/d8/49d563d434a7a601a0bd35c59ba8d03cd7d2b2d36ffcbc3f50e134d20a95/inline_snapshot-0.13.3.tar.gz", hash = "sha256:de85c5dfd31426c51b2820a3acb4569153fb62414a6a9833378b79859b869763", size = 84106 } +sdist = { url = "https://files.pythonhosted.org/packages/45/a9/b6b9db4f2ef1e3261460701a429f8248e517cb8d18e27ff05f4690ac0a73/inline_snapshot-0.14.0.tar.gz", hash = "sha256:54fdf7831055d06a2423054875d640102865a164cc8291a8086e44dd9b4fd316", size = 209662 } wheels = [ - { url = "https://files.pythonhosted.org/packages/2d/3a/f878e7bc11160d0eef30076363f0cd01fba454485ac61a6cee34c026cda1/inline_snapshot-0.13.3-py3-none-any.whl", hash = "sha256:b1cf31cea026fcc2abaeb4066950e2a94bc387912323e42752819f0972f12179", size = 31423 }, + { url = "https://files.pythonhosted.org/packages/5e/a3/8ca14974625632d56d7e9f899d76d15dc4acd94ec15c179ca528beadeb4a/inline_snapshot-0.14.0-py3-none-any.whl", hash = "sha256:dc246d28b720f6050404b72cc1d171b0671e1494249197753d23771ff228748c", size = 31807 }, ] [[package]] @@ -1273,7 +1272,7 @@ dev = [ { name = "coverage", extras = ["toml"], specifier = ">=7.6.2" }, { name = "devtools", specifier = ">=0.12.2" }, { name = "dirty-equals", specifier = ">=0.8.0" }, - { name = "inline-snapshot", specifier = ">=0.13.3" }, + { name = "inline-snapshot", specifier = ">=0.14" }, { name = "mypy", specifier = ">=1.11.2" }, { name = "pyright", specifier = ">=1.1.388" }, { name = "pytest", specifier = ">=8.3.3" }, @@ -1714,15 +1713,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/54/43/f185bfd0ca1d213beb4293bed51d92254df23d8ceaf6c0e17146d508a776/starlette-0.41.2-py3-none-any.whl", hash = "sha256:fbc189474b4731cf30fcef52f18a8d070e3f3b46c6a04c97579e85e6ffca942d", size = 73259 }, ] -[[package]] -name = "toml" -version = "0.10.2" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/be/ba/1f744cdc819428fc6b5084ec34d9b30660f6f9daaf70eead706e3203ec3c/toml-0.10.2.tar.gz", hash = "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f", size = 22253 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/44/6f/7120676b6d73228c96e17f1f794d8ab046fc910d781c8d151120c3f1569e/toml-0.10.2-py2.py3-none-any.whl", hash = "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b", size = 16588 }, -] - [[package]] name = "tomli" version = "2.0.2" @@ -1744,15 +1734,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/48/5d/acf5905c36149bbaec41ccf7f2b68814647347b72075ac0b1fe3022fdc73/tqdm-4.66.5-py3-none-any.whl", hash = "sha256:90279a3770753eafc9194a0364852159802111925aa30eb3f9d85b0e805ac7cd", size = 78351 }, ] -[[package]] -name = "types-toml" -version = "0.10.8.20240310" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/86/47/3e4c75042792bff8e90d7991aa5c51812cc668828cc6cce711e97f63a607/types-toml-0.10.8.20240310.tar.gz", hash = "sha256:3d41501302972436a6b8b239c850b26689657e25281b48ff0ec06345b8830331", size = 4392 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/da/a2/d32ab58c0b216912638b140ab2170ee4b8644067c293b170e19fba340ccc/types_toml-0.10.8.20240310-py3-none-any.whl", hash = "sha256:627b47775d25fa29977d9c70dc0cbab3f314f32c8d8d0c012f2ef5de7aaec05d", size = 4777 }, -] - [[package]] name = "typing-extensions" version = "4.12.2"