diff --git a/tests/unit_tests/test_utils.py b/tests/unit_tests/test_utils.py index c2a6ce4..4bc4204 100644 --- a/tests/unit_tests/test_utils.py +++ b/tests/unit_tests/test_utils.py @@ -5,7 +5,8 @@ from pydantic import ValidationError from typing_extensions import Annotated, TypedDict -from trustcall._base import _convert_any_typed_dicts_to_pydantic +from trustcall._base import _convert_any_typed_dicts_to_pydantic, _apply_message_ops, MessageOp +from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage def test_convert_any_typed_dicts_to_pydantic(): @@ -89,3 +90,114 @@ class RecursiveType(TypedDict): cyclic["next"] = cyclic with pytest.raises(ValueError): # or RecursionError, depending on implementation model(**cyclic) + + +def test_message_ops_update_tool_name(): + """Test various scenarios for updating tool names in messages.""" + + # Test case 1: Mixed message types + messages = [ + SystemMessage(content="system message"), + HumanMessage(content="user message"), + AIMessage( + content="", + tool_calls=[{ + "id": "tool1", + "name": "old_name", + "args": {"arg1": "value1"} + }] + ), + ToolMessage( + content="tool response", + tool_call_id="tool1", + name="old_name" + ) + ] + + message_ops = [ + MessageOp( + op="update_tool_name", + target={ + "id": "tool1", + "name": "new_name" + } + ) + ] + + result = _apply_message_ops(messages, message_ops) + + # Verify message count and types + assert len(result) == 4, "All messages should be preserved" + assert isinstance(result[0], SystemMessage) + assert isinstance(result[1], HumanMessage) + assert isinstance(result[2], AIMessage) + assert isinstance(result[3], ToolMessage) + + # Verify content preservation + assert result[0].content == "system message" + assert result[1].content == "user message" + assert result[2].tool_calls[0]["name"] == "new_name" + assert result[3].content == "tool response" + + # Test case 2: Multiple tool calls in single AIMessage + messages = [ + AIMessage( + content="", + tool_calls=[ + { + "id": "tool1", + "name": "old_name1", + "args": {"arg1": "value1"} + }, + { + "id": "tool2", + "name": "old_name2", + "args": {"arg2": "value2"} + } + ] + ) + ] + + message_ops = [ + MessageOp( + op="update_tool_name", + target={ + "id": "tool1", + "name": "new_name1" + } + ) + ] + + result = _apply_message_ops(messages, message_ops) + + # Verify selective update + assert len(result) == 1 + assert result[0].tool_calls[0]["name"] == "new_name1" # Updated + assert result[0].tool_calls[1]["name"] == "old_name2" # Unchanged + + # Test case 3: No matching tool_id + messages = [ + AIMessage( + content="", + tool_calls=[{ + "id": "tool1", + "name": "old_name", + "args": {"arg1": "value1"} + }] + ) + ] + + message_ops = [ + MessageOp( + op="update_tool_name", + target={ + "id": "non_existent_tool", + "name": "new_name" + } + ) + ] + + result = _apply_message_ops(messages, message_ops) + + # Verify no changes for non-matching tool_id + assert result[0].tool_calls[0]["name"] == "old_name" diff --git a/trustcall/_base.py b/trustcall/_base.py index b4858dc..f2ce01b 100644 --- a/trustcall/_base.py +++ b/trustcall/_base.py @@ -1205,6 +1205,8 @@ def _apply_message_ops( m = m.copy() m.tool_calls = new messages_.append(m) + else: + messages_.append(m) messages = messages_ else: