From 8dd128e10aeda3632821d12ec0c8b9e4cc3238d4 Mon Sep 17 00:00:00 2001
From: William Fu-Hinthorn <13333726+hinthornw@users.noreply.github.com>
Date: Tue, 21 Jan 2025 13:57:55 -0800
Subject: [PATCH] Update to support string model identifiers

Allow create_extractor() to accept a model name string in addition to a
BaseChatModel instance. Strings are resolved through
langchain.chat_models.init_chat_model, provided by the new optional
"standard" extra (langchain>=0.3).
---
 pyproject.toml            |  3 ++-
 tests/evals/test_evals.py | 11 +++++------
 trustcall/_base.py        | 13 ++++++++++++-
 uv.lock                   |  6 +++++-
 4 files changed, 24 insertions(+), 9 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index cb08024..8eacaa5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -10,7 +10,7 @@ dependencies = [
     "jsonpatch<2.0,>=1.33",
 ]
 name = "trustcall"
-version = "0.0.26"
+version = "0.0.27"
 description = "Tenacious & trustworthy tool calling built on LangGraph."
 readme = "README.md"
@@ -31,6 +31,7 @@ dev = [
     "anyio>=4.7.0",
     "pytest-asyncio-cooperative>=0.37.0",
 ]
+standard = ["langchain>=0.3"]
 
 [tool.setuptools]
 packages = ["trustcall"]
diff --git a/tests/evals/test_evals.py b/tests/evals/test_evals.py
index b823d31..e6e70dd 100644
--- a/tests/evals/test_evals.py
+++ b/tests/evals/test_evals.py
@@ -5,7 +5,6 @@
 import langsmith as ls
 import pytest
 from dydantic import create_model_from_schema
-from langchain.chat_models import init_chat_model
 from langsmith import aevaluate, expect, traceable
 from langsmith.evaluation import EvaluationResults
 from langsmith.schemas import Example, Run
@@ -61,8 +60,7 @@ async def predict_with_model(
         ),
         ("user", inputs["input_str"]),
     ]
-    llm = init_chat_model(model_name, temperature=0.8)
-    extractor = create_extractor(llm, tools=[tool_def], tool_choice=tool_def["name"])
+    extractor = create_extractor(model_name, tools=[tool_def], tool_choice=tool_def["name"])
     existing = inputs.get("current_value", {})
     extractor_inputs: dict = {"messages": messages}
     if existing:
@@ -226,7 +224,7 @@ def query_docs(query: str) -> str:
         return "I am a document."
 
     extractor = create_extractor(
-        init_chat_model("gpt-4o"), tools=[query_docs], tool_choice="query_docs"
+        "gpt-4o", tools=[query_docs], tool_choice="query_docs"
     )
     extractor.invoke({"messages": [("user", "What are the docs about?")]})
 
@@ -246,8 +244,9 @@ def validate_query_length(cls, v: str) -> str:
             )
             return v
 
-    llm = init_chat_model("gpt-4o-mini")
-    extractor = create_extractor(llm, tools=[query_docs], tool_choice="any")
+    extractor = create_extractor(
+        "gpt-4o", tools=[query_docs], tool_choice="any"
+    )
     extractor.invoke(
         {
             "messages": [
diff --git a/trustcall/_base.py b/trustcall/_base.py
index 501f9a6..a0ae6dc 100644
--- a/trustcall/_base.py
+++ b/trustcall/_base.py
@@ -119,7 +119,7 @@ class ExtractionOutputs(TypedDict):
 
 
 def create_extractor(
-    llm: BaseChatModel,
+    llm: str | BaseChatModel,
     *,
     tools: Sequence[TOOL_T],
     tool_choice: Optional[str] = None,
@@ -258,6 +258,17 @@ def create_extractor(
 ...         }
 ...     )
     """  # noqa
+    if isinstance(llm, str):
+        try:
+            from langchain.chat_models import init_chat_model
+        except ImportError:
+            raise ImportError(
+                "Creating extractors from a string requires langchain>=0.3.0,"
+                " as well as the provider-specific package"
+                " (like langchain-openai, langchain-anthropic, etc.)."
+                " Please install langchain to continue."
+            )
+        llm = init_chat_model(llm)
     builder = StateGraph(ExtractionState)
 
     def format_exception(error: BaseException, call: ToolCall, schema: Type[BaseModel]):
diff --git a/uv.lock b/uv.lock
index 20e45e8..7f1b9d4 100644
--- a/uv.lock
+++ b/uv.lock
@@ -1658,7 +1658,7 @@ wheels = [
 
 [[package]]
 name = "trustcall"
-version = "0.0.26"
+version = "0.0.27"
 source = { virtual = "." }
 dependencies = [
     { name = "dydantic" },
@@ -1682,6 +1682,9 @@ dev = [
     { name = "ruff" },
     { name = "vcrpy" },
 ]
+standard = [
+    { name = "langchain" },
+]
 
 [package.metadata]
 requires-dist = [
@@ -1706,6 +1709,7 @@ dev = [
     { name = "ruff", specifier = ">=0.4.10,<1.0.0" },
     { name = "vcrpy", specifier = ">=6.0.1,<7.0.0" },
 ]
+standard = [{ name = "langchain", specifier = ">=0.3" }]
 
 [[package]]
 name = "typing-extensions"
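
Usage after this change (a minimal sketch, not part of the patch; it assumes
the new "standard" extra and an OpenAI provider package are installed, e.g.
`pip install "trustcall[standard]" langchain-openai`, and that an
OPENAI_API_KEY is available in the environment):

    from pydantic import BaseModel

    from trustcall import create_extractor

    class UserInfo(BaseModel):
        """Details to extract about a user."""

        name: str
        age: int

    # A bare model name now works: create_extractor resolves the string
    # through langchain.chat_models.init_chat_model instead of requiring
    # a pre-built BaseChatModel instance.
    extractor = create_extractor("gpt-4o", tools=[UserInfo], tool_choice="UserInfo")
    result = extractor.invoke({"messages": [("user", "Alice is 30 years old.")]})
    print(result["responses"])  # list of validated UserInfo instances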