Skip to content

Commit

Permalink
allow overriding deps, e.g. in testing (#41)
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelcolvin authored Nov 12, 2024
1 parent 265a7a7 commit a8edf02
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 18 deletions.
1 change: 1 addition & 0 deletions docs/api/agent.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
- run_sync
- run_stream
- model
- override_deps
- last_run_messages
- system_prompt
- retriever_plain
Expand Down
42 changes: 34 additions & 8 deletions pydantic_ai/agent.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from __future__ import annotations as _annotations

import asyncio
from collections.abc import AsyncIterator, Sequence
from contextlib import asynccontextmanager
from collections.abc import AsyncIterator, Iterator, Sequence
from contextlib import asynccontextmanager, contextmanager
from dataclasses import dataclass
from typing import Any, Callable, Generic, Literal, cast, final, overload

Expand Down Expand Up @@ -52,6 +52,7 @@ class Agent(Generic[AgentDeps, ResultData]):
_deps_type: type[AgentDeps]
_max_result_retries: int
_current_result_retry: int
_override_deps_stack: list[AgentDeps]
last_run_messages: list[_messages.Message] | None = None
"""The messages from the last run, useful when a run raised an exception.
Expand Down Expand Up @@ -103,14 +104,15 @@ def __init__(
self._max_result_retries = result_retries if result_retries is not None else retries
self._current_result_retry = 0
self._result_validators = []
self._override_deps_stack = []

async def run(
self,
user_prompt: str,
*,
message_history: list[_messages.Message] | None = None,
model: models.Model | KnownModelName | None = None,
deps: AgentDeps | None = None,
deps: AgentDeps = None,
) -> result.RunResult[ResultData]:
"""Run the agent with a user prompt in async mode.
Expand All @@ -125,8 +127,7 @@ async def run(
"""
model_used, custom_model, agent_model = await self._get_agent_model(model)

# we could do runtime type checking of deps against `self._deps_type`, but that's a slippery slope
deps = cast(AgentDeps, deps)
deps = self._get_deps(deps)

new_message_index, messages = await self._prepare_messages(deps, user_prompt, message_history)
self.last_run_messages = messages
Expand Down Expand Up @@ -180,7 +181,7 @@ def run_sync(
*,
message_history: list[_messages.Message] | None = None,
model: models.Model | KnownModelName | None = None,
deps: AgentDeps | None = None,
deps: AgentDeps = None,
) -> result.RunResult[ResultData]:
"""Run the agent with a user prompt synchronously.
Expand All @@ -204,7 +205,7 @@ async def run_stream(
*,
message_history: list[_messages.Message] | None = None,
model: models.Model | KnownModelName | None = None,
deps: AgentDeps | None = None,
deps: AgentDeps = None,
) -> AsyncIterator[result.StreamedRunResult[AgentDeps, ResultData]]:
"""Run the agent with a user prompt in async mode, returning a streamed response.
Expand All @@ -219,7 +220,7 @@ async def run_stream(
"""
model_used, custom_model, agent_model = await self._get_agent_model(model)

deps = cast(AgentDeps, deps)
deps = self._get_deps(deps)

new_message_index, messages = await self._prepare_messages(deps, user_prompt, message_history)
self.last_run_messages = messages
Expand Down Expand Up @@ -273,6 +274,19 @@ async def run_stream(
# the model_response should have been fully streamed by now, we can add it's cost
cost += model_response.cost()

@contextmanager
def override_deps(self, overriding_deps: AgentDeps) -> Iterator[None]:
"""Context manager to temporarily override agent dependencies, this is particularly useful when testing.
Args:
overriding_deps: The dependencies to use instead of the dependencies passed to the agent run.
"""
self._override_deps_stack.append(overriding_deps)
try:
yield
finally:
self._override_deps_stack.pop()

def system_prompt(
self, func: _system_prompt.SystemPromptFunc[AgentDeps]
) -> _system_prompt.SystemPromptFunc[AgentDeps]:
Expand Down Expand Up @@ -551,6 +565,18 @@ def _unknown_tool(self, tool_name: str) -> _messages.RetryPrompt:
msg = 'No tools available.'
return _messages.RetryPrompt(content=f'Unknown tool name: {tool_name!r}. {msg}')

def _get_deps(self, deps: AgentDeps) -> AgentDeps:
"""Get deps for a run.
If we've overridden deps via `_override_deps_stack`, use that, otherwise use the deps passed to the call.
We could do runtime type checking of deps against `self._deps_type`, but that's a slippery slope.
"""
try:
return self._override_deps_stack[-1]
except IndexError:
return deps


@dataclass
class _MarkFinalResult(Generic[ResultData]):
Expand Down
39 changes: 39 additions & 0 deletions tests/test_deps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from dataclasses import dataclass

from pydantic_ai import Agent, CallContext
from pydantic_ai.models.test import TestModel


@dataclass
class MyDeps:
foo: int
bar: int


agent = Agent(TestModel(), deps_type=MyDeps)


@agent.retriever_context
async def test_retriever(ctx: CallContext[MyDeps]) -> str:
return f'{ctx.deps}'


def test_deps_used():
result = agent.run_sync('foobar', deps=MyDeps(foo=1, bar=2))
assert result.data == '{"test_retriever":"MyDeps(foo=1, bar=2)"}'


def test_deps_override():
with agent.override_deps(MyDeps(foo=3, bar=4)):
result = agent.run_sync('foobar', deps=MyDeps(foo=1, bar=2))
assert result.data == '{"test_retriever":"MyDeps(foo=3, bar=4)"}'

with agent.override_deps(MyDeps(foo=5, bar=6)):
result = agent.run_sync('foobar', deps=MyDeps(foo=1, bar=2))
assert result.data == '{"test_retriever":"MyDeps(foo=5, bar=6)"}'

result = agent.run_sync('foobar', deps=MyDeps(foo=1, bar=2))
assert result.data == '{"test_retriever":"MyDeps(foo=3, bar=4)"}'

result = agent.run_sync('foobar', deps=MyDeps(foo=1, bar=2))
assert result.data == '{"test_retriever":"MyDeps(foo=1, bar=2)"}'
29 changes: 19 additions & 10 deletions tests/typed_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ class MyDeps:
bar: int


typed_agent1 = Agent(deps_type=MyDeps, result_type=str)
assert_type(typed_agent1, Agent[MyDeps, str])
typed_agent = Agent(deps_type=MyDeps, result_type=str)
assert_type(typed_agent, Agent[MyDeps, str])


@contextmanager
Expand All @@ -29,53 +29,62 @@ def expect_error(error_type: type[Exception]) -> Iterator[None]:
raise AssertionError('Expected an error')


@typed_agent1.retriever_context
@typed_agent.retriever_context
async def ok_retriever(ctx: CallContext[MyDeps], x: str) -> str:
assert_type(ctx.deps, MyDeps)
total = ctx.deps.foo + ctx.deps.bar
return f'{x} {total}'


@typed_agent1.retriever_plain
@typed_agent.retriever_plain
def ok_retriever_plain(x: str) -> dict[str, str]:
return {'x': x}


@typed_agent1.retriever_context
@typed_agent.retriever_context
async def bad_retriever1(ctx: CallContext[MyDeps], x: str) -> str:
total = ctx.deps.foo + ctx.deps.spam # type: ignore[attr-defined]
return f'{x} {total}'


@typed_agent1.retriever_context # type: ignore[arg-type]
@typed_agent.retriever_context # type: ignore[arg-type]
async def bad_retriever2(ctx: CallContext[int], x: str) -> str:
return f'{x} {ctx.deps}'


@typed_agent1.retriever_plain # type: ignore[arg-type]
@typed_agent.retriever_plain # type: ignore[arg-type]
async def bad_retriever_return(x: int) -> list[int]:
return [x]


with expect_error(ValueError):

@typed_agent1.retriever_context # type: ignore[arg-type]
@typed_agent.retriever_context # type: ignore[arg-type]
async def bad_retriever3(x: str) -> str:
return x


def run_sync() -> None:
result = typed_agent1.run_sync('testing')
result = typed_agent.run_sync('testing', deps=MyDeps(foo=1, bar=2))
assert_type(result, RunResult[str])
assert_type(result.data, str)


async def run_stream() -> None:
async with typed_agent1.run_stream('testing') as streamed_result:
async with typed_agent.run_stream('testing', deps=MyDeps(foo=1, bar=2)) as streamed_result:
result_items = [chunk async for chunk in streamed_result.stream()]
assert_type(result_items, list[str])


def run_with_override() -> None:
with typed_agent.override_deps(MyDeps(1, 2)):
typed_agent.run_sync('testing', deps=MyDeps(3, 4))

# invalid deps
with typed_agent.override_deps(123): # type: ignore[arg-type]
typed_agent.run_sync('testing', deps=MyDeps(3, 4))


@dataclass
class Foo:
a: int
Expand Down

0 comments on commit a8edf02

Please sign in to comment.