make all_messages() a method

samuelcolvin committed Nov 4, 2024
1 parent 10362a3 commit 3c0f4d7
Showing 9 changed files with 63 additions and 33 deletions.
2 changes: 1 addition & 1 deletion preview.README.md
@@ -62,5 +62,5 @@ print(result.response)
 
 # `result.all_messages` includes details of messages exchanged, useful if you want to continue
 # the conversation later, via the `message_history` argument of `run_sync`.
-print(result.all_messages)
+print(result.all_messages())
 ```
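For orientation, a minimal sketch of the pattern the README describes after this change; the agent setup here is illustrative, not taken from this diff:

```python
from pydantic_ai import Agent

agent = Agent('openai:gpt-4o')  # model name is illustrative

result = agent.run_sync('Tell me a joke.')
print(result.response)

# all_messages() is now a method; its return value can be passed back
# via `message_history` to continue the conversation in a later run.
result2 = agent.run_sync('Explain it.', message_history=result.all_messages())
print(result2.response)
```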
2 changes: 1 addition & 1 deletion pydantic_ai/agent.py
@@ -139,7 +139,7 @@ async def run(
             run_span.set_attribute('cost', cost)
             handle_span.set_attribute('result', left.value)
             handle_span.message = 'handle model response -> final result'
-            return shared.RunResult(left.value, messages, new_message_index, cost)
+            return shared.RunResult(left.value, cost, messages, new_message_index)
         else:
             tool_responses = either.right
             handle_span.set_attribute('tool_responses', tool_responses)
13 changes: 9 additions & 4 deletions pydantic_ai/shared.py
@@ -58,20 +58,25 @@ class RunResult(Generic[ResultData]):
     """Result of a run."""
 
     response: ResultData
-    all_messages: list[messages.Message]
-    new_message_index: int
     cost: Cost
+    _all_messages: list[messages.Message]
+    _new_message_index: int
 
+    def all_messages(self) -> list[messages.Message]:
+        """Return the history of messages."""
+        # this is a method to be consistent with the other methods
+        return self._all_messages
+
     def all_messages_json(self) -> bytes:
         """Return the history of messages as JSON bytes."""
-        return messages.MessagesTypeAdapter.dump_json(self.all_messages)
+        return messages.MessagesTypeAdapter.dump_json(self.all_messages())
 
     def new_messages(self) -> list[messages.Message]:
         """Return new messages associated with this run.
 
         System prompts and any messages from older runs are excluded.
         """
-        return self.all_messages[self.new_message_index :]
+        return self.all_messages()[self._new_message_index :]
 
     def new_messages_json(self) -> bytes:
         """Return new messages from [new_messages][] as JSON bytes."""
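Taken together, the RunResult surface after this change looks like the following sketch; the run itself is hypothetical, but the accessors are the ones defined above:

```python
# `result` is a RunResult from some earlier agent run (not shown here)
history = result.all_messages()    # full message history, now a method call
new = result.new_messages()        # system prompts and older runs excluded

# both JSON helpers return bytes, suitable for storing and replaying later
all_bytes = result.all_messages_json()
new_bytes = result.new_messages_json()
```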
2 changes: 1 addition & 1 deletion pydantic_ai_examples/chat_app.py
@@ -39,7 +39,7 @@ async def stream_messages():
         database.add_messages(response.new_messages_json())
         # stream the last message which will be the agent response, we can't just yield `new_messages_json()`
         # since we already stream the user prompt
-        yield MessageTypeAdapter.dump_json(response.all_messages[-1]) + b'\n'
+        yield MessageTypeAdapter.dump_json(response.all_messages()[-1]) + b'\n'
 
     return StreamingResponse(stream_messages(), media_type='text/plain')
 
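Since each streamed chunk here is one JSON-encoded message terminated by a newline, a client can decode the body line by line. A hypothetical consumer sketch, assuming an httpx response opened in streaming mode:

```python
import json

async def read_messages(response):
    # `response` is assumed to be an httpx.Response in streaming mode
    async for line in response.aiter_lines():
        if line:  # skip blank keep-alive lines
            yield json.loads(line)
```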
6 changes: 3 additions & 3 deletions tests/models/test_gemini.py
@@ -360,7 +360,7 @@ async def test_request_simple_success(get_gemini_client: GetGeminiClient):
 
     result = await agent.run('Hello')
     assert result.response == 'Hello world'
-    assert result.all_messages == snapshot(
+    assert result.all_messages() == snapshot(
         [
             UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)),
             LLMResponse(content='Hello world', timestamp=IsNow(tz=timezone.utc)),
@@ -381,7 +381,7 @@ async def test_request_structured_response(get_gemini_client: GetGeminiClient):
 
     result = await agent.run('Hello')
     assert result.response == [1, 2, 123]
-    assert result.all_messages == snapshot(
+    assert result.all_messages() == snapshot(
         [
             UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)),
             LLMToolCalls(
@@ -424,7 +424,7 @@ async def get_location(loc_name: str) -> str:
 
     result = await agent.run('Hello')
     assert result.response == 'final response'
-    assert result.all_messages == snapshot(
+    assert result.all_messages() == snapshot(
         [
             SystemPrompt(content='this is the system prompt'),
             UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)),
10 changes: 5 additions & 5 deletions tests/models/test_model_function.py
@@ -37,7 +37,7 @@ def test_simple():
     agent = Agent(FunctionModel(return_last), deps=None)
     result = agent.run_sync('Hello')
     assert result.response == snapshot("content='Hello' role='user' message_count=1")
-    assert result.all_messages == snapshot(
+    assert result.all_messages() == snapshot(
         [
             UserPrompt(
                 content='Hello',
@@ -52,9 +52,9 @@ def test_simple():
         ]
     )
 
-    result2 = agent.run_sync('World', message_history=result.all_messages)
+    result2 = agent.run_sync('World', message_history=result.all_messages())
     assert result2.response == snapshot("content='World' role='user' message_count=3")
-    assert result2.all_messages == snapshot(
+    assert result2.all_messages() == snapshot(
         [
             UserPrompt(
                 content='Hello',
@@ -127,7 +127,7 @@ async def get_weather(_: CallContext[None], lat: int, lng: int):
 def test_weather():
     result = weather_agent.run_sync('London')
     assert result.response == 'Raining in London'
-    assert result.all_messages == snapshot(
+    assert result.all_messages() == snapshot(
         [
             UserPrompt(
                 content='London',
@@ -322,7 +322,7 @@ def f(messages: list[Message], info: AgentInfo) -> LLMMessage:
 def test_call_all():
     result = agent_all.run_sync('Hello', model=TestModel())
     assert result.response == snapshot('{"foo":"1","bar":"2","baz":"3","qux":"4","quz":"a"}')
-    assert result.all_messages == snapshot(
+    assert result.all_messages() == snapshot(
         [
             SystemPrompt(content='foobar'),
             UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)),
2 changes: 1 addition & 1 deletion tests/models/test_model_test.py
@@ -82,7 +82,7 @@ async def my_ret(x: int) -> str:
     result = agent.run_sync('Hello', model=TestModel())
     assert call_count == 2
     assert result.response == snapshot('{"my_ret":"1"}')
-    assert result.all_messages == snapshot(
+    assert result.all_messages() == snapshot(
         [
             UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)),
             LLMToolCalls(
4 changes: 2 additions & 2 deletions tests/models/test_openai.py
@@ -113,7 +113,7 @@ async def test_request_structured_response():
 
     result = await agent.run('Hello')
     assert result.response == [1, 2, 123]
-    assert result.all_messages == snapshot(
+    assert result.all_messages() == snapshot(
         [
             UserPrompt(content='Hello', timestamp=IsNow(tz=datetime.timezone.utc)),
             LLMToolCalls(
@@ -185,7 +185,7 @@ async def get_location(loc_name: str) -> str:
 
     result = await agent.run('Hello')
     assert result.response == 'final response'
-    assert result.all_messages == snapshot(
+    assert result.all_messages() == snapshot(
         [
             SystemPrompt(content='this is the system prompt'),
             UserPrompt(content='Hello', timestamp=IsNow(tz=datetime.timezone.utc)),
55 changes: 40 additions & 15 deletions tests/test_agent.py
@@ -70,7 +70,7 @@ def return_model(messages: list[Message], info: AgentInfo) -> LLMMessage:
     result = agent.run_sync('Hello')
     assert isinstance(result.response, Foo)
     assert result.response.model_dump() == {'a': 42, 'b': 'foo'}
-    assert result.all_messages == snapshot(
+    assert result.all_messages() == snapshot(
         [
             UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)),
             LLMToolCalls(
@@ -120,7 +120,7 @@ def validate_result(ctx: CallContext[None], r: Foo) -> Foo:
     result = agent.run_sync('Hello')
     assert isinstance(result.response, Foo)
     assert result.response.model_dump() == {'a': 42, 'b': 'foo'}
-    assert result.all_messages == snapshot(
+    assert result.all_messages() == snapshot(
         [
             UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)),
             LLMToolCalls(
@@ -153,7 +153,7 @@ def return_tuple(_: list[Message], info: AgentInfo) -> LLMMessage:
     result = agent.run_sync('Hello')
     assert result.response == ('foo', 'bar')
     assert call_index == 2
-    assert result.all_messages == snapshot(
+    assert result.all_messages() == snapshot(
         [
             UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)),
             LLMResponse(content='hello', timestamp=IsNow(tz=timezone.utc)),
@@ -355,7 +355,7 @@ def validate_result(ctx: CallContext[None], r: Any) -> Any:
     assert got_tool_call_name == snapshot('final_result_Bar')
 
 
-def test_run_with_history():
+def test_run_with_history_new():
     m = TestModel()
 
     agent = Agent(m, deps=None, system_prompt='Foobar')
@@ -364,11 +364,25 @@
     async def ret_a(x: str) -> str:
         return f'{x}-apple'
 
-    result = agent.run_sync('Hello')
-    assert result == snapshot(
+    result1 = agent.run_sync('Hello')
+    assert result1.new_messages() == snapshot(
+        [
+            UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)),
+            LLMToolCalls(
+                calls=[ToolCall(tool_name='ret_a', args=ArgsObject(args_object={'x': 'a'}))],
+                timestamp=IsNow(tz=timezone.utc),
+            ),
+            ToolReturn(tool_name='ret_a', content='a-apple', timestamp=IsNow(tz=timezone.utc)),
+            LLMResponse(content='{"ret_a":"a-apple"}', timestamp=IsNow(tz=timezone.utc)),
+        ]
+    )
+
+    # if we pass new_messages, system prompt is inserted before the message_history messages
+    result2 = agent.run_sync('Hello again', message_history=result1.new_messages())
+    assert result2 == snapshot(
         RunResult(
             response='{"ret_a":"a-apple"}',
-            all_messages=[
+            _all_messages=[
                 SystemPrompt(content='Foobar'),
                 UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)),
                 LLMToolCalls(
@@ -377,17 +391,31 @@ async def ret_a(x: str) -> str:
                 ),
                 ToolReturn(tool_name='ret_a', content='a-apple', timestamp=IsNow(tz=timezone.utc)),
                 LLMResponse(content='{"ret_a":"a-apple"}', timestamp=IsNow(tz=timezone.utc)),
+                # second call, notice no repeated system prompt
+                UserPrompt(content='Hello again', timestamp=IsNow(tz=timezone.utc)),
+                LLMToolCalls(
+                    calls=[ToolCall(tool_name='ret_a', args=ArgsObject(args_object={'x': 'a'}))],
+                    timestamp=IsNow(tz=timezone.utc),
+                ),
+                ToolReturn(tool_name='ret_a', content='a-apple', timestamp=IsNow(tz=timezone.utc)),
+                LLMResponse(content='{"ret_a":"a-apple"}', timestamp=IsNow(tz=timezone.utc)),
             ],
-            new_message_index=1,
+            _new_message_index=5,
             cost=Cost(),
         )
     )
+    new_msg_roles = [msg.role for msg in result2.new_messages()]
+    assert new_msg_roles == snapshot(['user', 'llm-tool-calls', 'tool-return', 'llm-response'])
+    assert result2.new_messages_json().startswith(b'[{"content":"Hello again",')
 
-    result = agent.run_sync('Hello again', message_history=result.all_messages)
-    assert result == snapshot(
+    # if we pass all_messages, system prompt is NOT inserted before the message_history messages,
+    # so only one system prompt
+    result3 = agent.run_sync('Hello again', message_history=result1.all_messages())
+    # same as result2 except for datetimes
+    assert result3 == snapshot(
         RunResult(
             response='{"ret_a":"a-apple"}',
-            all_messages=[
+            _all_messages=[
                 SystemPrompt(content='Foobar'),
                 UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)),
                 LLMToolCalls(
@@ -405,10 +433,7 @@ async def ret_a(x: str) -> str:
                 ToolReturn(tool_name='ret_a', content='a-apple', timestamp=IsNow(tz=timezone.utc)),
                 LLMResponse(content='{"ret_a":"a-apple"}', timestamp=IsNow(tz=timezone.utc)),
             ],
-            new_message_index=5,
+            _new_message_index=5,
             cost=Cost(),
         )
     )
-    new_msg_roles = [msg.role for msg in result.new_messages()]
-    assert new_msg_roles == snapshot(['user', 'llm-tool-calls', 'tool-return', 'llm-response'])
-    assert result.new_messages_json().startswith(b'[{"content":"Hello again",')
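
The behaviour this test pins down, as a short sketch reusing the test's own agent:

```python
result1 = agent.run_sync('Hello')

# new_messages() excludes the system prompt, so the next run re-inserts it
result2 = agent.run_sync('Hello again', message_history=result1.new_messages())

# all_messages() already begins with the system prompt, so no second one is added
result3 = agent.run_sync('Hello again', message_history=result1.all_messages())
```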
