Skip to content

Commit

Permalink
fix: preserve non-AI messages in message operations
Browse files Browse the repository at this point in the history
When updating tool names, the message operation was dropping non-AI messages.
This fix ensures all message types are preserved during message operations,
maintaining conversation context and history.

- Add test cases to verify message preservation
- Update message operation logic to keep non-AI messages
  • Loading branch information
sjang42 committed Jan 29, 2025
1 parent 9790c36 commit 3aa1b19
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 1 deletion.
114 changes: 113 additions & 1 deletion tests/unit_tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from pydantic import ValidationError
from typing_extensions import Annotated, TypedDict

from trustcall._base import _convert_any_typed_dicts_to_pydantic
from trustcall._base import _convert_any_typed_dicts_to_pydantic, _apply_message_ops, MessageOp
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage


def test_convert_any_typed_dicts_to_pydantic():
Expand Down Expand Up @@ -89,3 +90,114 @@ class RecursiveType(TypedDict):
cyclic["next"] = cyclic
with pytest.raises(ValueError): # or RecursionError, depending on implementation
model(**cyclic)


def test_message_ops_update_tool_name():
"""Test various scenarios for updating tool names in messages."""

# Test case 1: Mixed message types
messages = [
SystemMessage(content="system message"),
HumanMessage(content="user message"),
AIMessage(
content="",
tool_calls=[{
"id": "tool1",
"name": "old_name",
"args": {"arg1": "value1"}
}]
),
ToolMessage(
content="tool response",
tool_call_id="tool1",
name="old_name"
)
]

message_ops = [
MessageOp(
op="update_tool_name",
target={
"id": "tool1",
"name": "new_name"
}
)
]

result = _apply_message_ops(messages, message_ops)

# Verify message count and types
assert len(result) == 4, "All messages should be preserved"
assert isinstance(result[0], SystemMessage)
assert isinstance(result[1], HumanMessage)
assert isinstance(result[2], AIMessage)
assert isinstance(result[3], ToolMessage)

# Verify content preservation
assert result[0].content == "system message"
assert result[1].content == "user message"
assert result[2].tool_calls[0]["name"] == "new_name"
assert result[3].content == "tool response"

# Test case 2: Multiple tool calls in single AIMessage
messages = [
AIMessage(
content="",
tool_calls=[
{
"id": "tool1",
"name": "old_name1",
"args": {"arg1": "value1"}
},
{
"id": "tool2",
"name": "old_name2",
"args": {"arg2": "value2"}
}
]
)
]

message_ops = [
MessageOp(
op="update_tool_name",
target={
"id": "tool1",
"name": "new_name1"
}
)
]

result = _apply_message_ops(messages, message_ops)

# Verify selective update
assert len(result) == 1
assert result[0].tool_calls[0]["name"] == "new_name1" # Updated
assert result[0].tool_calls[1]["name"] == "old_name2" # Unchanged

# Test case 3: No matching tool_id
messages = [
AIMessage(
content="",
tool_calls=[{
"id": "tool1",
"name": "old_name",
"args": {"arg1": "value1"}
}]
)
]

message_ops = [
MessageOp(
op="update_tool_name",
target={
"id": "non_existent_tool",
"name": "new_name"
}
)
]

result = _apply_message_ops(messages, message_ops)

# Verify no changes for non-matching tool_id
assert result[0].tool_calls[0]["name"] == "old_name"
2 changes: 2 additions & 0 deletions trustcall/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1205,6 +1205,8 @@ def _apply_message_ops(
m = m.copy()
m.tool_calls = new
messages_.append(m)
else:
messages_.append(m)
messages = messages_

else:
Expand Down

0 comments on commit 3aa1b19

Please sign in to comment.