Skip to content

Commit

Permalink
Update for coerced existing types
Browse files Browse the repository at this point in the history
  • Loading branch information
hinthornw committed Jan 23, 2025
1 parent db8edb6 commit 9790c36
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 6 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ dependencies = [
"jsonpatch<2.0,>=1.33",
]
name = "trustcall"
version = "0.0.27"
version = "0.0.28"
description = "Tenacious & trustworthy tool calling built on LangGraph."
readme = "README.md"

Expand Down
58 changes: 53 additions & 5 deletions trustcall/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ class SchemaInstance(NamedTuple):
"""

record_id: str
schema_name: str
schema_name: str | Literal["__any__"]
record: Dict[str, Any]


Expand Down Expand Up @@ -653,7 +653,7 @@ def _setup(self, state: ExtractionState):
existing = state.existing
if not existing:
raise ValueError("No existing schemas provided.")
self._validate_existing(existing)
existing = self._validate_existing(existing)
schema_strings = []
if isinstance(existing, dict):
for k, v in existing.items():
Expand Down Expand Up @@ -720,7 +720,18 @@ def _teardown(
(e for e in existing if e[0] == json_doc_id),
)
if not tool_name:
raise ValueError("Could not find tool name")
raise ValueError(
f"Could not find tool name for json_doc_id {json_doc_id}"
)
except StopIteration:
logger.error(
f"Could not find existing schema in dict for {json_doc_id}"
)
if rt:
rt.error = (
f"Could not find existing schema for {json_doc_id}"
)
continue
except (ValueError, IndexError, TypeError):
logger.error(
f"Could not find existing schema in list for {json_doc_id}"
Expand Down Expand Up @@ -782,8 +793,10 @@ def _validate_existing(self, existing: ExistingType):
" with keys matching one of the provided tool names:"
f" {self._provided_tools}"
)
elif isinstance(existing, list):
return existing
if isinstance(existing, list):
# For list types, validate each item's schema_name
coerced = []
for i, item in enumerate(existing):
if isinstance(item, SchemaInstance):
if item.schema_name not in self.tools:
Expand All @@ -793,6 +806,7 @@ def _validate_existing(self, existing: ExistingType):
f"name. Provided: {item}, Expected: SchemaInstance"
f" with schema_name in {self._provided_tools}"
)
coerced.append(coerced)
elif isinstance(item, tuple) and len(item) == 3:
if item[1] not in self.tools:
raise ValueError(
Expand All @@ -801,12 +815,46 @@ def _validate_existing(self, existing: ExistingType):
f" Expected: Tuple(str, str, dict) with second"
f" element in {self._provided_tools}"
)
coerced.append(SchemaInstance(item[0], item[1], item[2]))
elif isinstance(item, tuple) and len(item) == 2:
# Assume record_ID, item
if hasattr(item[1], "__name__"):
schema_name = item[1].__name__
else:
schema_name = item[1].__repr_name__()

if schema_name not in self.tools:
raise ValueError(
f"Schema name '{schema_name}' at index {i} does"
f" not match any tool name. Provided: {item},"
f" Expected: Tuple(str, str, dict) with second"
f" element in {self._provided_tools}"
)
val = (
item[1].model_dump(mode="json")
if isinstance(item[1], BaseModel)
else item[1]
)
coerced.append(SchemaInstance(item[0], schema_name, val))
elif isinstance(item, BaseModel):
if hasattr(item, "__name__"):
schema_name = item.__name__
else:
schema_name = item.__repr_name__()
coerced.append(
SchemaInstance(
str(uuid.uuid4()),
schema_name,
item.model_dump(mode="json"),
)
)
else:
raise ValueError(
f"Invalid item at index {i} in existing list."
f" Provided: {item}, Expected: SchemaInstance"
f" or Tuple(str, str, dict)"
f" or Tuple(str, str, dict) or BaseModel"
)
return coerced
else:
raise ValueError(
f"Invalid type for existing. Provided: {type(existing)},"
Expand Down

0 comments on commit 9790c36

Please sign in to comment.