diff --git a/pyproject.toml b/pyproject.toml index 8d42eaf..cf94402 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ dependencies = [ "jsonpatch<2.0,>=1.33", ] name = "trustcall" -version = "0.0.32" +version = "0.0.34" description = "Tenacious & trustworthy tool calling built on LangGraph." readme = "README.md" diff --git a/tests/unit_tests/test_extraction.py b/tests/unit_tests/test_extraction.py index 0d85e1c..5863742 100644 --- a/tests/unit_tests/test_extraction.py +++ b/tests/unit_tests/test_extraction.py @@ -422,14 +422,13 @@ def test_validate_existing(existing, tools, is_valid): extractor._validate_existing(existing) -@pytest.mark.asyncio @pytest.mark.parametrize("strict_mode", [True, False, "ignore"]) async def test_e2e_existing_schema_policy_behavior(strict_mode): class MyRecognizedSchema(BaseModel): """A recognized schema that the pipeline can handle.""" - user_id: str - notes: str + user_id: str # type: ignore + notes: str # type: ignore # Our existing data includes 2 top-level keys: recognized, unknown existing_schemas = { @@ -537,14 +536,13 @@ class MyRecognizedSchema(BaseModel): assert recognized_item.notes == "updated notes" -@pytest.mark.asyncio @pytest.mark.parametrize("strict_mode", [True, False, "ignore"]) async def test_e2e_existing_schema_policy_tuple_behavior(strict_mode): class MyRecognizedSchema(BaseModel): """A recognized schema that the pipeline can handle.""" - user_id: str - notes: str + user_id: str # type: ignore + notes: str # type: ignore existing_schemas = [ ( @@ -655,7 +653,6 @@ class MyRecognizedSchema(BaseModel): assert recognized_item.notes == "updated notes" -@pytest.mark.asyncio @pytest.mark.parametrize("enable_inserts", [True, False]) async def test_enable_deletes_flow(enable_inserts: bool) -> None: class MySchema(BaseModel): diff --git a/tests/unit_tests/test_strict_existing.py b/tests/unit_tests/test_strict_existing.py index c22abdc..719c540 100644 --- a/tests/unit_tests/test_strict_existing.py +++ b/tests/unit_tests/test_strict_existing.py @@ -1,4 +1,5 @@ import logging +from unittest.mock import patch import pytest from langchain_openai import ChatOpenAI @@ -260,7 +261,8 @@ def test_validate_existing_strictness( ): """Test various scenarios of validation.""" tools = {"DummySchema": DummySchema} - llm = ChatOpenAI(model="gpt-4o-mini") + with patch.dict("os.environ", {"OPENAI_API_KEY": "fake-api-key"}): + llm = ChatOpenAI(model="gpt-4o-mini") extractor = _ExtractUpdates( llm=llm, # We won't actually call the LLM here but we need it for parsing. tools=tools, diff --git a/trustcall/_base.py b/trustcall/_base.py index 6b16cd8..19b2725 100644 --- a/trustcall/_base.py +++ b/trustcall/_base.py @@ -664,7 +664,7 @@ def __init__( existing_schema_policy: bool | Literal["ignore"] = True, ): new_tools: list = [PatchDoc] - tool_choice = PatchDoc.__name__ if not enable_deletes else "any" + tool_choice = "PatchDoc" if not enable_deletes else "any" if enable_inserts: # Also let the LLM know that we can extract NEW schemas. tools_ = [ schema @@ -1052,7 +1052,7 @@ def _tear_down( async def ainvoke( self, state: ExtendedExtractState, config: RunnableConfig - ) -> dict: + ) -> Command[Literal["sync", "__end__"]]: """Generate a JSONPatch to correct the validation error and heal the tool call. Assumptions: @@ -1075,7 +1075,9 @@ async def ainvoke( goto=("sync",), ) - def invoke(self, state: ExtendedExtractState, config: RunnableConfig) -> dict: + def invoke( + self, state: ExtendedExtractState, config: RunnableConfig + ) -> Command[Literal["sync", "__end__"]]: try: msg = self.bound.invoke(state.messages, config) except Exception: diff --git a/uv.lock b/uv.lock index ce3c552..8f5b2e1 100644 --- a/uv.lock +++ b/uv.lock @@ -1725,7 +1725,7 @@ wheels = [ [[package]] name = "trustcall" -version = "0.0.31" +version = "0.0.33" source = { virtual = "." } dependencies = [ { name = "dydantic" },