Skip to content

Commit

Permalink
Improve ChatAdapter's handling of typed values and Pydantic models (s…
Browse files Browse the repository at this point in the history
…tanfordnlp#1663)

* Improve ChatAdapter's handling of typed values and Pydantic models

* Fixes for Literal

* Fixes for formatting complex-typed values
  • Loading branch information
okhat authored Oct 21, 2024
1 parent 313aa66 commit fe3d9d1
Showing 1 changed file with 61 additions and 26 deletions.
87 changes: 61 additions & 26 deletions dspy/adapters/chat_adapter.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,29 @@
import re
import ast
import json
import re
import enum
import inspect
import pydantic
import textwrap
from typing import Any, Dict, KeysView, List, Literal, NamedTuple, get_args, get_origin

import pydantic
from pydantic import TypeAdapter
from pydantic.fields import FieldInfo
from typing import Any, Dict, KeysView, List, Literal, NamedTuple, get_args, get_origin

from dspy.adapters.base import Adapter
from ..signatures.field import OutputField
from ..signatures.signature import SignatureMeta
from ..signatures.utils import get_dspy_field_type
from .base import Adapter

field_header_pattern = re.compile(r"\[\[ ## (\w+) ## \]\]")


class FieldInfoWithName(NamedTuple):
"""
A tuple containing a field name and its corresponding FieldInfo object.
"""

name: str
info: FieldInfo


# Built-in field indicating that a chat turn (i.e. a user or assistant reply to a chat
# thread) has been completed.
# Built-in field indicating that a chat turn has been completed.
BuiltInCompletedOutputFieldInfo = FieldInfoWithName(name="completed", info=OutputField())


Expand Down Expand Up @@ -114,6 +111,16 @@ def format_input_list_field_value(value: List[Any]) -> str:
return "\n".join([f"[{idx+1}] {format_blob(txt)}" for idx, txt in enumerate(value)])


def _serialize_for_json(value):
if isinstance(value, pydantic.BaseModel):
return value.model_dump()
elif isinstance(value, list):
return [_serialize_for_json(item) for item in value]
elif isinstance(value, dict):
return {key: _serialize_for_json(val) for key, val in value.items()}
else:
return value

def _format_field_value(field_info: FieldInfo, value: Any) -> str:
"""
Formats the value of the specified field according to the field's DSPy type (input or output),
Expand All @@ -125,24 +132,17 @@ def _format_field_value(field_info: FieldInfo, value: Any) -> str:
Returns:
The formatted value of the field, represented as a string.
"""
dspy_field_type: Literal["input", "output"] = get_dspy_field_type(field_info)
if isinstance(value, list):
if dspy_field_type == "input" or field_info.annotation is str:
# If the field is an input field or has no special type requirements, format it as
# numbered list so that it's organized in a way suitable for presenting long context
# to an LLM (i.e. not JSON)
return format_input_list_field_value(value)
else:
# If the field is an output field that has strict parsing requirements, format the
# value as a stringified JSON Array. This ensures that downstream routines can parse
# the field value correctly using methods from the `ujson` or `json` packages.
return json.dumps(value)
elif isinstance(value, pydantic.BaseModel):
return value.model_dump_json()

if isinstance(value, list) and field_info.annotation is str:
# If the field has no special type requirements, format it as a nice numbere list for the LM.
return format_input_list_field_value(value)
elif isinstance(value, pydantic.BaseModel) or isinstance(value, dict) or isinstance(value, list):
return json.dumps(_serialize_for_json(value))
else:
return str(value)



def format_fields(fields_with_values: Dict[FieldInfoWithName, Any]) -> str:
"""
Formats the values of the specified fields according to the field's DSPy type (input or output),
Expand All @@ -166,15 +166,20 @@ def format_fields(fields_with_values: Dict[FieldInfoWithName, Any]) -> str:
def parse_value(value, annotation):
if annotation is str:
return str(value)

parsed_value = value
if isinstance(value, str):

if isinstance(annotation, enum.EnumMeta):
parsed_value = annotation[value]
elif isinstance(value, str):
try:
parsed_value = json.loads(value)
except json.JSONDecodeError:
try:
parsed_value = ast.literal_eval(value)
except (ValueError, SyntaxError):
parsed_value = value

return TypeAdapter(annotation).validate_python(parsed_value)


Expand Down Expand Up @@ -222,6 +227,16 @@ def format_turn(signature: SignatureMeta, values: Dict[str, Any], role, incomple
content.append(formatted_fields)

if role == "user":
# def type_info(v):
# return f" (must be formatted as a valid Python {get_annotation_name(v.annotation)})" \
# if v.annotation is not str else ""
#
# content.append(
# "Respond with the corresponding output fields, starting with the field "
# + ", then ".join(f"`[[ ## {f} ## ]]`{type_info(v)}" for f, v in signature.output_fields.items())
# + ", and then ending with the marker for `[[ ## completed ## ]]`."
# )

content.append(
"Respond with the corresponding output fields, starting with the field "
+ ", then ".join(f"`{f}`" for f in signature.output_fields)
Expand Down Expand Up @@ -260,10 +275,30 @@ def prepare_instructions(signature: SignatureMeta):
parts.append("Your output fields are:\n" + enumerate_fields(signature.output_fields))
parts.append("All interactions will be structured in the following way, with the appropriate values filled in.")

def field_metadata(field_name, field_info):
type_ = field_info.annotation

if get_dspy_field_type(field_info) == 'input' or type_ is str:
desc = ""
elif type_ is bool:
desc = "must be True or False"
elif type_ in (int, float):
desc = f"must be a single {type_.__name__} value"
elif inspect.isclass(type_) and issubclass(type_, enum.Enum):
desc= f"must be one of: {'; '.join(type_.__members__)}"
elif hasattr(type_, '__origin__') and type_.__origin__ is Literal:
desc = f"must be one of: {'; '.join([str(x) for x in type_.__args__])}"
else:
desc = "must be pareseable according to the following JSON schema: "
desc += json.dumps(pydantic.TypeAdapter(type_).json_schema())

desc = (" " * 8) + f"# note: the value you produce {desc}" if desc else ""
return f"{{{field_name}}}{desc}"

def format_signature_fields_for_instructions(fields: Dict[str, FieldInfo]):
return format_fields(
fields_with_values={
FieldInfoWithName(name=field_name, info=field_info): f"{{{field_name}}}"
FieldInfoWithName(name=field_name, info=field_info): field_metadata(field_name, field_info)
for field_name, field_info in fields.items()
}
)
Expand Down

0 comments on commit fe3d9d1

Please sign in to comment.