From a8edf02ab9e35b449a59bd41d4c4e319606c1987 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Tue, 12 Nov 2024 20:13:01 +0000 Subject: [PATCH] allow overriding deps, e.g. in testing (#41) --- docs/api/agent.md | 1 + pydantic_ai/agent.py | 42 ++++++++++++++++++++++++++++++++++-------- tests/test_deps.py | 39 +++++++++++++++++++++++++++++++++++++++ tests/typed_agent.py | 29 +++++++++++++++++++---------- 4 files changed, 93 insertions(+), 18 deletions(-) create mode 100644 tests/test_deps.py diff --git a/docs/api/agent.md b/docs/api/agent.md index fd6ae8cac..fe2f6e02b 100644 --- a/docs/api/agent.md +++ b/docs/api/agent.md @@ -8,6 +8,7 @@ - run_sync - run_stream - model + - override_deps - last_run_messages - system_prompt - retriever_plain diff --git a/pydantic_ai/agent.py b/pydantic_ai/agent.py index e6a6aa27d..a5318c707 100644 --- a/pydantic_ai/agent.py +++ b/pydantic_ai/agent.py @@ -1,8 +1,8 @@ from __future__ import annotations as _annotations import asyncio -from collections.abc import AsyncIterator, Sequence -from contextlib import asynccontextmanager +from collections.abc import AsyncIterator, Iterator, Sequence +from contextlib import asynccontextmanager, contextmanager from dataclasses import dataclass from typing import Any, Callable, Generic, Literal, cast, final, overload @@ -52,6 +52,7 @@ class Agent(Generic[AgentDeps, ResultData]): _deps_type: type[AgentDeps] _max_result_retries: int _current_result_retry: int + _override_deps_stack: list[AgentDeps] last_run_messages: list[_messages.Message] | None = None """The messages from the last run, useful when a run raised an exception. @@ -103,6 +104,7 @@ def __init__( self._max_result_retries = result_retries if result_retries is not None else retries self._current_result_retry = 0 self._result_validators = [] + self._override_deps_stack = [] async def run( self, @@ -110,7 +112,7 @@ async def run( *, message_history: list[_messages.Message] | None = None, model: models.Model | KnownModelName | None = None, - deps: AgentDeps | None = None, + deps: AgentDeps = None, ) -> result.RunResult[ResultData]: """Run the agent with a user prompt in async mode. @@ -125,8 +127,7 @@ async def run( """ model_used, custom_model, agent_model = await self._get_agent_model(model) - # we could do runtime type checking of deps against `self._deps_type`, but that's a slippery slope - deps = cast(AgentDeps, deps) + deps = self._get_deps(deps) new_message_index, messages = await self._prepare_messages(deps, user_prompt, message_history) self.last_run_messages = messages @@ -180,7 +181,7 @@ def run_sync( *, message_history: list[_messages.Message] | None = None, model: models.Model | KnownModelName | None = None, - deps: AgentDeps | None = None, + deps: AgentDeps = None, ) -> result.RunResult[ResultData]: """Run the agent with a user prompt synchronously. @@ -204,7 +205,7 @@ async def run_stream( *, message_history: list[_messages.Message] | None = None, model: models.Model | KnownModelName | None = None, - deps: AgentDeps | None = None, + deps: AgentDeps = None, ) -> AsyncIterator[result.StreamedRunResult[AgentDeps, ResultData]]: """Run the agent with a user prompt in async mode, returning a streamed response. @@ -219,7 +220,7 @@ async def run_stream( """ model_used, custom_model, agent_model = await self._get_agent_model(model) - deps = cast(AgentDeps, deps) + deps = self._get_deps(deps) new_message_index, messages = await self._prepare_messages(deps, user_prompt, message_history) self.last_run_messages = messages @@ -273,6 +274,19 @@ async def run_stream( # the model_response should have been fully streamed by now, we can add it's cost cost += model_response.cost() + @contextmanager + def override_deps(self, overriding_deps: AgentDeps) -> Iterator[None]: + """Context manager to temporarily override agent dependencies, this is particularly useful when testing. + + Args: + overriding_deps: The dependencies to use instead of the dependencies passed to the agent run. + """ + self._override_deps_stack.append(overriding_deps) + try: + yield + finally: + self._override_deps_stack.pop() + def system_prompt( self, func: _system_prompt.SystemPromptFunc[AgentDeps] ) -> _system_prompt.SystemPromptFunc[AgentDeps]: @@ -551,6 +565,18 @@ def _unknown_tool(self, tool_name: str) -> _messages.RetryPrompt: msg = 'No tools available.' return _messages.RetryPrompt(content=f'Unknown tool name: {tool_name!r}. {msg}') + def _get_deps(self, deps: AgentDeps) -> AgentDeps: + """Get deps for a run. + + If we've overridden deps via `_override_deps_stack`, use that, otherwise use the deps passed to the call. + + We could do runtime type checking of deps against `self._deps_type`, but that's a slippery slope. + """ + try: + return self._override_deps_stack[-1] + except IndexError: + return deps + @dataclass class _MarkFinalResult(Generic[ResultData]): diff --git a/tests/test_deps.py b/tests/test_deps.py new file mode 100644 index 000000000..158da715c --- /dev/null +++ b/tests/test_deps.py @@ -0,0 +1,39 @@ +from dataclasses import dataclass + +from pydantic_ai import Agent, CallContext +from pydantic_ai.models.test import TestModel + + +@dataclass +class MyDeps: + foo: int + bar: int + + +agent = Agent(TestModel(), deps_type=MyDeps) + + +@agent.retriever_context +async def test_retriever(ctx: CallContext[MyDeps]) -> str: + return f'{ctx.deps}' + + +def test_deps_used(): + result = agent.run_sync('foobar', deps=MyDeps(foo=1, bar=2)) + assert result.data == '{"test_retriever":"MyDeps(foo=1, bar=2)"}' + + +def test_deps_override(): + with agent.override_deps(MyDeps(foo=3, bar=4)): + result = agent.run_sync('foobar', deps=MyDeps(foo=1, bar=2)) + assert result.data == '{"test_retriever":"MyDeps(foo=3, bar=4)"}' + + with agent.override_deps(MyDeps(foo=5, bar=6)): + result = agent.run_sync('foobar', deps=MyDeps(foo=1, bar=2)) + assert result.data == '{"test_retriever":"MyDeps(foo=5, bar=6)"}' + + result = agent.run_sync('foobar', deps=MyDeps(foo=1, bar=2)) + assert result.data == '{"test_retriever":"MyDeps(foo=3, bar=4)"}' + + result = agent.run_sync('foobar', deps=MyDeps(foo=1, bar=2)) + assert result.data == '{"test_retriever":"MyDeps(foo=1, bar=2)"}' diff --git a/tests/typed_agent.py b/tests/typed_agent.py index 42c91d00e..6ade228ec 100644 --- a/tests/typed_agent.py +++ b/tests/typed_agent.py @@ -15,8 +15,8 @@ class MyDeps: bar: int -typed_agent1 = Agent(deps_type=MyDeps, result_type=str) -assert_type(typed_agent1, Agent[MyDeps, str]) +typed_agent = Agent(deps_type=MyDeps, result_type=str) +assert_type(typed_agent, Agent[MyDeps, str]) @contextmanager @@ -29,53 +29,62 @@ def expect_error(error_type: type[Exception]) -> Iterator[None]: raise AssertionError('Expected an error') -@typed_agent1.retriever_context +@typed_agent.retriever_context async def ok_retriever(ctx: CallContext[MyDeps], x: str) -> str: assert_type(ctx.deps, MyDeps) total = ctx.deps.foo + ctx.deps.bar return f'{x} {total}' -@typed_agent1.retriever_plain +@typed_agent.retriever_plain def ok_retriever_plain(x: str) -> dict[str, str]: return {'x': x} -@typed_agent1.retriever_context +@typed_agent.retriever_context async def bad_retriever1(ctx: CallContext[MyDeps], x: str) -> str: total = ctx.deps.foo + ctx.deps.spam # type: ignore[attr-defined] return f'{x} {total}' -@typed_agent1.retriever_context # type: ignore[arg-type] +@typed_agent.retriever_context # type: ignore[arg-type] async def bad_retriever2(ctx: CallContext[int], x: str) -> str: return f'{x} {ctx.deps}' -@typed_agent1.retriever_plain # type: ignore[arg-type] +@typed_agent.retriever_plain # type: ignore[arg-type] async def bad_retriever_return(x: int) -> list[int]: return [x] with expect_error(ValueError): - @typed_agent1.retriever_context # type: ignore[arg-type] + @typed_agent.retriever_context # type: ignore[arg-type] async def bad_retriever3(x: str) -> str: return x def run_sync() -> None: - result = typed_agent1.run_sync('testing') + result = typed_agent.run_sync('testing', deps=MyDeps(foo=1, bar=2)) assert_type(result, RunResult[str]) assert_type(result.data, str) async def run_stream() -> None: - async with typed_agent1.run_stream('testing') as streamed_result: + async with typed_agent.run_stream('testing', deps=MyDeps(foo=1, bar=2)) as streamed_result: result_items = [chunk async for chunk in streamed_result.stream()] assert_type(result_items, list[str]) +def run_with_override() -> None: + with typed_agent.override_deps(MyDeps(1, 2)): + typed_agent.run_sync('testing', deps=MyDeps(3, 4)) + + # invalid deps + with typed_agent.override_deps(123): # type: ignore[arg-type] + typed_agent.run_sync('testing', deps=MyDeps(3, 4)) + + @dataclass class Foo: a: int