Merge pull request stanfordnlp#1413 from LLukas22/pydantic-optionals
fix(dspy): Allow `Union` with pydantic objects in `TypedPredictor`
arnavsinghvi11 authored Sep 19, 2024
2 parents fb823c7 + cec972e commit 2cc029b
Showing 4 changed files with 213 additions and 44 deletions.
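In short: output annotations that are not plain pydantic models (for example a Union of pydantic models, or Optional[...]) are no longer forced through a synthetic {"value": ...} wrapper model; they are serialized, parsed, and given a schema via pydantic.TypeAdapter. A minimal sketch of the behaviour this enables; the signature and models below are illustrative and not taken from the diff:

```python
# Illustrative only: a TypedPredictor whose output field is a Union of pydantic models.
from typing import Union

import pydantic
import dspy
from dspy.functional import TypedPredictor


class Cat(pydantic.BaseModel):
    name: str
    meows: bool


class Dog(pydantic.BaseModel):
    name: str
    barks: bool


class PetSignature(dspy.Signature):
    description: str = dspy.InputField()
    pet: Union[Cat, Dog] = dspy.OutputField()


predictor = TypedPredictor(PetSignature)
# With a configured LM, the JSON answer is validated against Union[Cat, Dog]
# via pydantic.TypeAdapter instead of a generated "Output" wrapper model:
# prediction = predictor(description="a small animal that barks")
```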
47 changes: 21 additions & 26 deletions dspy/functional/functional.py
@@ -116,9 +116,12 @@ def __repr__(self):
"""Return a string representation of the TypedPredictor object."""
return f"TypedPredictor({self.signature})"

- def _make_example(self, type_) -> str:
+ def _make_example(self, field) -> str:
  # Note: DSPy will cache this call so we only pay the first time TypedPredictor is called.
- schema = json.dumps(type_.model_json_schema())
+ if hasattr(field, "model_json_schema"):
+ pass
+ schema = field.json_schema_extra["schema"]
+ parser = field.json_schema_extra["parser"]
if self.wrap_json:
schema = "```json\n" + schema + "\n```\n"
json_object = dspy.Predict(
@@ -127,9 +130,9 @@ def _make_example(self, type_) -> str:
"Make a very succinct json object that validates with the following schema",
),
)(json_schema=schema).json_object
- # We use the model_validate_json method to make sure the example is valid
+ # We use the parser to make sure the json object is valid.
  try:
- type_.model_validate_json(_unwrap_json(json_object, type_.model_validate_json))
+ parser(_unwrap_json(json_object, parser))
except (pydantic.ValidationError, ValueError):
return "" # Unable to make an example
return json_object
@@ -225,32 +228,21 @@ def parse(x):
format=lambda x: x if isinstance(x, str) else str(x),
parser=type_,
)
- elif False:
- # TODO: I don't like forcing the model to write "value" in the output.
- if not (inspect.isclass(type_) and issubclass(type_, pydantic.BaseModel)):
- type_ = pydantic.create_model("Output", value=(type_, ...), __base__=pydantic.BaseModel)
- to_json = lambda x, type_=type_: type_(value=x).model_dump_json()[9:-1] # {"value":"123"}
- from_json = lambda x, type_=type_: type_.model_validate_json('{"value":' + x + "}").value
- schema = json.dumps(type_.model_json_schema()["properties"]["value"])
- else:
- to_json = lambda x: x.model_dump_json()
- from_json = lambda x, type_=type_: type_.model_validate_json(x)
- schema = json.dumps(type_.model_json_schema())
  else:
  # Anything else we wrap in a pydantic object
- if not (
+ if (
  inspect.isclass(type_)
  and typing.get_origin(type_) not in (list, tuple) # To support Python 3.9
  and issubclass(type_, pydantic.BaseModel)
  ):
- type_ = pydantic.create_model("Output", value=(type_, ...), __base__=pydantic.BaseModel)
- to_json = lambda x, type_=type_: type_(value=x).model_dump_json()
- from_json = lambda x, type_=type_: type_.model_validate_json(x).value
- schema = json.dumps(type_.model_json_schema())
- else:
  to_json = lambda x: x.model_dump_json()
  from_json = lambda x, type_=type_: type_.model_validate_json(x)
  schema = json.dumps(type_.model_json_schema())
+ else:
+ adapter = pydantic.TypeAdapter(type_)
+ to_json = lambda x: adapter.serializer.to_json(x)
+ from_json = lambda x, type_=adapter: type_.validate_json(x)
+ schema = json.dumps(adapter.json_schema())
if self.wrap_json:
to_json = lambda x, inner=to_json: "```json\n" + inner(x) + "\n```\n"
schema = "```json\n" + schema + "\n```"
@@ -260,6 +252,7 @@ def parse(x):
+ (". Respond with a single JSON object. JSON Schema: " + schema),
format=lambda x, to_json=to_json: (x if isinstance(x, str) else to_json(x)),
parser=lambda x, from_json=from_json: from_json(_unwrap_json(x, from_json)),
schema=schema,
type_=type_,
)
else: # If input field
@@ -321,7 +314,7 @@ def forward(self, **kwargs) -> dspy.Prediction:
if (
try_i + 1 < self.max_retries
and prefix not in current_desc
- and (example := self._make_example(field.annotation))
+ and (example := self._make_example(field))
):
signature = signature.with_updated_fields(
name,
@@ -405,9 +398,13 @@ def _func_to_signature(func):
return dspy.Signature(fields, instructions)


- def _unwrap_json(output, from_json: Callable[[str], Union[pydantic.BaseModel, str]]):
+ def _unwrap_json(output, from_json: Callable[[str], Union[pydantic.BaseModel, str, None]]):
try:
- return from_json(output).model_dump_json()
+ parsing_result = from_json(output)
+ if isinstance(parsing_result, pydantic.BaseModel):
+ return parsing_result.model_dump_json()
+ else:
+ return output
except (ValueError, pydantic.ValidationError, AttributeError):
output = output.strip()
if output.startswith("```"):
@@ -416,6 +413,4 @@ def _unwrap_json(output, from_json: Callable[[str], Union[pydantic.BaseModel, st
if not output.endswith("```"):
raise ValueError("Don't write anything after the final json ```") from None
output = output[7:-3].strip()
- if not output.startswith("{") or not output.endswith("}"):
- raise ValueError("json output should start and end with { and }") from None
return ujson.dumps(ujson.loads(output)) # ujson is a bit more robust than the standard json
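For context, the new TypeAdapter fallback branch in _prepare_signature above leans entirely on pydantic v2 for serialization, parsing, and schema generation. A standalone sketch of that round trip, with an illustrative Point model rather than anything from this repository:

```python
# Illustrative sketch of the TypeAdapter round trip used by the new fallback branch.
import json
from typing import Union

import pydantic


class Point(pydantic.BaseModel):
    x: float
    y: float


adapter = pydantic.TypeAdapter(Union[Point, str])

# to_json handles both Union members (the diff calls adapter.serializer.to_json).
print(adapter.serializer.to_json(Point(x=1.0, y=2.0)))  # e.g. b'{"x":1.0,"y":2.0}'
print(adapter.serializer.to_json("fallback"))           # e.g. b'"fallback"'

# validate_json picks the matching member of the Union.
assert adapter.validate_json('{"x": 1.0, "y": 2.0}') == Point(x=1.0, y=2.0)
assert adapter.validate_json('"fallback"') == "fallback"

# The JSON schema string sent to the LM comes straight from the adapter.
print(json.dumps(adapter.json_schema()))
```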
42 changes: 33 additions & 9 deletions tests/functional/test_functional.py
@@ -36,7 +36,7 @@ def hard_questions(topics: List[str]) -> List[str]:
pass

expected = ["What is the speed of light?", "What is the speed of sound?"]
lm = DummyLM(['{"value": ["What is the speed of light?", "What is the speed of sound?"]}'])
lm = DummyLM(['["What is the speed of light?", "What is the speed of sound?"]'])
dspy.settings.configure(lm=lm)

question = hard_questions(topics=["Physics", "Music"])
@@ -557,7 +557,7 @@ def test_parse_type_string():


def test_literal():
lm = DummyLM([f'{{"value": "{i}"}}' for i in range(100)])
lm = DummyLM(['"2"', '"3"'])
dspy.settings.configure(lm=lm)

@predictor
@@ -567,8 +567,22 @@ def f() -> Literal["2", "3"]:
assert f() == "2"


def test_literal_missmatch():
lm = DummyLM([f'"{i}"' for i in range(5, 100)])
dspy.settings.configure(lm=lm)

@predictor(max_retries=1)
def f() -> Literal["2", "3"]:
pass

with pytest.raises(Exception) as e_info:
f()

assert e_info.value.args[1]["f"] == "Input should be '2' or '3': (error type: literal_error)"


def test_literal_int():
lm = DummyLM([f'{{"value": {i}}}' for i in range(100)])
lm = DummyLM(["2", "3"])
dspy.settings.configure(lm=lm)

@predictor
@@ -578,6 +592,20 @@ def f() -> Literal[2, 3]:
assert f() == 2


def test_literal_int_missmatch():
lm = DummyLM([f"{i}" for i in range(5, 100)])
dspy.settings.configure(lm=lm)

@predictor(max_retries=1)
def f() -> Literal[2, 3]:
pass

with pytest.raises(Exception) as e_info:
f()

assert e_info.value.args[1]["f"] == "Input should be 2 or 3: (error type: literal_error)"


def test_fields_on_base_signature():
class SimpleOutput(dspy.Signature):
output: float = dspy.OutputField(gt=0, lt=1)
@@ -896,9 +924,7 @@ def _test_demos_missing_input():


def test_conlist():
- dspy.settings.configure(
- lm=DummyLM(['{"value": []}', '{"value": [1]}', '{"value": [1, 2]}', '{"value": [1, 2, 3]}'])
- )
+ dspy.settings.configure(lm=DummyLM(["[]", "[1]", "[1, 2]", "[1, 2, 3]"]))

@predictor
def make_numbers(input: str) -> Annotated[list[int], Field(min_items=2)]:
@@ -908,9 +934,7 @@ def make_numbers(input: str) -> Annotated[list[int], Field(min_items=2)]:


def test_conlist2():
- dspy.settings.configure(
- lm=DummyLM(['{"value": []}', '{"value": [1]}', '{"value": [1, 2]}', '{"value": [1, 2, 3]}'])
- )
+ dspy.settings.configure(lm=DummyLM(["[]", "[1]", "[1, 2]", "[1, 2, 3]"]))

make_numbers = TypedPredictor("input:str -> output:Annotated[List[int], Field(min_items=2)]")
assert make_numbers(input="What are the first two numbers?").output == [1, 2]
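A side note on why the conlist tests can now feed bare JSON arrays: constrained Annotated types also go through pydantic.TypeAdapter, so a too-short list raises a ValidationError that drives TypedPredictor's retry loop. An illustrative sketch, using pydantic v2's min_length spelling rather than the deprecated min_items used in the test:

```python
# Illustrative: validating a length-constrained list the way the TypeAdapter branch does.
from typing import Annotated

import pydantic
from pydantic import Field

numbers = pydantic.TypeAdapter(Annotated[list[int], Field(min_length=2)])

print(numbers.validate_json("[1, 2]"))  # [1, 2]

try:
    numbers.validate_json("[1]")  # too short, raises ValidationError
except pydantic.ValidationError as err:
    print(err.errors()[0]["type"])  # 'too_short'
```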
14 changes: 5 additions & 9 deletions tests/functional/test_signature_opt_typed.py
@@ -10,7 +10,8 @@
from dspy.evaluate import Evaluate
from dspy.evaluate.metrics import answer_exact_match
from dspy.functional import TypedPredictor

import json
from pydantic_core import to_jsonable_python

hotpotqa = [
ex.with_inputs("question")
@@ -109,7 +110,7 @@ class BasicQA(dspy.Signature):
[
# Seed prompts
"some thoughts",
'{"value": [{"instructions": "I", "question_desc": "$q", "question_prefix": "Q:", "answer_desc": "$a", "answer_prefix": "A:"}]}',
'[{"instructions": "I", "question_desc": "$q", "question_prefix": "Q:", "answer_desc": "$a", "answer_prefix": "A:"}]',
]
)
dspy.settings.configure(lm=qa_model)
@@ -163,18 +164,13 @@ class ExpectedSignature2(dspy.Signature):

info2 = make_info(ExpectedSignature2)

T = TypeVar("T")

class OutputWrapper(pydantic.BaseModel, Generic[T]):
value: list[T]

qa_model = DummyLM([])
prompt_model = DummyLM(
[
"some thoughts",
OutputWrapper[type(info1)](value=[info1]).model_dump_json(),
json.dumps([to_jsonable_python(info1)]),
"some thoughts",
OutputWrapper[type(info2)](value=[info2]).model_dump_json(),
json.dumps([to_jsonable_python(info2)]),
]
)
dspy.settings.configure(lm=qa_model)
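For reference, the updated test above serializes the seed prompts with pydantic_core.to_jsonable_python plus json.dumps instead of the deleted OutputWrapper generic model. A minimal sketch with an illustrative model:

```python
# Illustrative: turning a pydantic object into the bare JSON list the test feeds the DummyLM.
import json

import pydantic
from pydantic_core import to_jsonable_python


class SignatureInfo(pydantic.BaseModel):
    instructions: str
    question_prefix: str


info = SignatureInfo(instructions="I", question_prefix="Q:")
print(json.dumps([to_jsonable_python(info)]))
# -> [{"instructions": "I", "question_prefix": "Q:"}]
```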
154 changes: 154 additions & 0 deletions tests/functional/test_signature_typed.py
@@ -0,0 +1,154 @@
from typing import Any, Optional, Union
import pytest

import pydantic
import dspy
from dspy.functional import TypedPredictor
from dspy.signatures.signature import signature_to_template


def get_field_and_parser(signature: dspy.Signature) -> tuple[Any, Any]:
module = TypedPredictor(signature)
signature = module._prepare_signature()
assert "answer" in signature.fields, "'answer' not in signature.fields"
field = signature.fields.get("answer")
parser = field.json_schema_extra.get("parser")
return field, parser


class Mysubmodel(pydantic.BaseModel):
sub_floating: float


class MyModel(pydantic.BaseModel):
floating: float
string: str
boolean: bool
integer: int
optional: Optional[str]
sequence_of_strings: list[str]
union: Union[str, float]
submodel: Mysubmodel
optional_submodel: Optional[Mysubmodel]
optional_existing_submodule: Optional[Mysubmodel]


def build_model_instance() -> MyModel:
return MyModel(
floating=3.14,
string="foobar",
boolean=True,
integer=42,
optional=None,
sequence_of_strings=["foo", "bar"],
union=3.14,
submodel=Mysubmodel(sub_floating=42.42),
optional_submodel=None,
optional_existing_submodule=Mysubmodel(sub_floating=42.42),
)


@pytest.mark.parametrize(
"test_type,serialized, expected", [(str, "foo", "foo"), (int, "42", 42), (float, "42.42", 42.42)]
)
def test_basic_types(test_type: type, serialized: str, expected: Any):
class MySignature(dspy.Signature):
question: str = dspy.InputField()
answer: test_type = dspy.OutputField()

_, parser = get_field_and_parser(MySignature)
assert parser is test_type, "Parser is not correct for 'answer'"
assert parser(serialized) == expected, f"{test_type}({serialized})!= {expected}"


def test_boolean():
class MySignature(dspy.Signature):
question: str = dspy.InputField()
answer: bool = dspy.OutputField()

_, parser = get_field_and_parser(MySignature)
assert parser("true"), f"Parsing 'true' failed"
assert not parser("false"), f"Parsing 'false' failed"


@pytest.mark.parametrize(
"test_type,serialized, expected",
[(list[str], '["foo", "bar"]', ["foo", "bar"]), (tuple[int, float], "[42, 3.14]", (42, 3.14))],
)
def test_sequences(test_type: type, serialized: str, expected: Any):
class MySignature(dspy.Signature):
question: str = dspy.InputField()
answer: test_type = dspy.OutputField()

_, parser = get_field_and_parser(MySignature)

assert parser(serialized) == expected, f"Parsing {expected} failed"


@pytest.mark.parametrize(
"test_type,serialized, expected",
[
(Optional[str], '"foobar"', "foobar"),
(Optional[str], "null", None),
(Union[str, float], "3.14", 3.14),
(Union[str, bool], "true", True),
],
)
def test_unions(test_type: type, serialized: str, expected: Any):
class MySignature(dspy.Signature):
question: str = dspy.InputField()
answer: test_type = dspy.OutputField()

_, parser = get_field_and_parser(MySignature)

assert parser(serialized) == expected, f"Parsing {expected} failed"


def test_pydantic():
class MySignature(dspy.Signature):
question: str = dspy.InputField()
answer: MyModel = dspy.OutputField()

_, parser = get_field_and_parser(MySignature)

instance = build_model_instance()
parsed_instance = parser(instance.model_dump_json())

assert parsed_instance == instance, f"{instance} != {parsed_instance}"


def test_optional_pydantic():
class MySignature(dspy.Signature):
question: str = dspy.InputField()
answer: Optional[MyModel] = dspy.OutputField()

_, parser = get_field_and_parser(MySignature)

instance = build_model_instance()
parsed_instance = parser(instance.model_dump_json())
assert parsed_instance == instance, f"{instance} != {parsed_instance}"

# Check null case
parsed_instance = parser("null")
assert parsed_instance == None, "Optional[MyModel] should be None"


def test_dataclass():
from dataclasses import dataclass

@dataclass(frozen=True)
class MyDataclass:
string: str
number: int
floating: float
boolean: bool

class MySignature(dspy.Signature):
question: str = dspy.InputField()
answer: MyDataclass = dspy.OutputField()

_, parser = get_field_and_parser(MySignature)

instance = MyDataclass("foobar", 42, 3.14, True)
parsed_instance = parser('{"string": "foobar", "number": 42, "floating": 3.14, "boolean": true}')
assert parsed_instance == instance, f"{instance} != {parsed_instance}"
