Skip to content

Commit

Permalink
Remove dead code paths (and flags) from Predict, Adapters, and Settings (stanfordnlp#1944)
Browse files Browse the repository at this point in the history

* Simplify built-in modules (remove extended_signature and new_signature) and remove assertions temporarily

* Update test_retry.py

* Fully turn CoT into a typical Module

* Remove dead code paths (and flags) from Predict, Adapters, and Settings
  • Loading branch information
okhat authored Dec 16, 2024
1 parent ae86009 commit 741ac10
Show file tree
Hide file tree
Showing 6 changed files with 16 additions and 44 deletions.
10 changes: 5 additions & 5 deletions dspy/adapters/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def __init_subclass__(cls, **kwargs) -> None:
cls.format = with_callbacks(cls.format)
cls.parse = with_callbacks(cls.parse)

def __call__(self, lm, lm_kwargs, signature, demos, inputs, _parse_values=True):
def __call__(self, lm, lm_kwargs, signature, demos, inputs):
inputs_ = self.format(signature, demos, inputs)
inputs_ = dict(prompt=inputs_) if isinstance(inputs_, str) else dict(messages=inputs_)

Expand All @@ -27,7 +27,7 @@ def __call__(self, lm, lm_kwargs, signature, demos, inputs, _parse_values=True):
if isinstance(output, dict):
output, output_logprobs = output["text"], output["logprobs"]

value = self.parse(signature, output, _parse_values=_parse_values)
value = self.parse(signature, output)

assert set(value.keys()) == set(signature.output_fields.keys()), \
f"Expected {signature.output_fields.keys()} but got {value.keys()}"
Expand All @@ -41,16 +41,16 @@ def __call__(self, lm, lm_kwargs, signature, demos, inputs, _parse_values=True):

except Exception as e:
from .json_adapter import JSONAdapter
if _parse_values and not isinstance(self, JSONAdapter):
return JSONAdapter()(lm, lm_kwargs, signature, demos, inputs, _parse_values=_parse_values)
if not isinstance(self, JSONAdapter):
return JSONAdapter()(lm, lm_kwargs, signature, demos, inputs)
raise e

@abstractmethod
def format(self, signature, demos, inputs):
raise NotImplementedError

@abstractmethod
def parse(self, signature, completion, _parse_values):
def parse(self, signature, completion):
raise NotImplementedError

def format_finetune_data(self, signature, demos, inputs, outputs):
Expand Down
4 changes: 2 additions & 2 deletions dspy/adapters/chat_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def format(self, signature: Signature, demos: list[dict[str, Any]], inputs: dict
messages.append(format_turn(signature, inputs, role="user"))
return messages

def parse(self, signature, completion, _parse_values=True):
def parse(self, signature, completion):
sections = [(None, [])]

for line in completion.splitlines():
Expand All @@ -74,7 +74,7 @@ def parse(self, signature, completion, _parse_values=True):
for k, v in sections:
if (k not in fields) and (k in signature.output_fields):
try:
fields[k] = parse_value(v, signature.output_fields[k].annotation) if _parse_values else v
fields[k] = parse_value(v, signature.output_fields[k].annotation)
except Exception as e:
raise ValueError(
f"Error parsing field {k}: {e}.\n\n\t\tOn attempting to parse the value\n```\n{v}\n```"
Expand Down
6 changes: 3 additions & 3 deletions dspy/adapters/json_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class JSONAdapter(Adapter):
def __init__(self):
pass

def __call__(self, lm, lm_kwargs, signature, demos, inputs, _parse_values=True):
def __call__(self, lm, lm_kwargs, signature, demos, inputs):
inputs = self.format(signature, demos, inputs)
inputs = dict(prompt=inputs) if isinstance(inputs, str) else dict(messages=inputs)

Expand All @@ -58,7 +58,7 @@ def __call__(self, lm, lm_kwargs, signature, demos, inputs, _parse_values=True):
values = []

for output in outputs:
value = self.parse(signature, output, _parse_values=_parse_values)
value = self.parse(signature, output)
assert set(value.keys()) == set(
signature.output_fields.keys()
), f"Expected {signature.output_fields.keys()} but got {value.keys()}"
Expand Down Expand Up @@ -90,7 +90,7 @@ def format(self, signature, demos, inputs):

return messages

def parse(self, signature, completion, _parse_values=True):
def parse(self, signature, completion):
fields = json_repair.loads(completion)
fields = {k: v for k, v in fields.items() if k in signature.output_fields}

Expand Down
7 changes: 0 additions & 7 deletions dspy/dsp/utils/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,11 @@
adapter=None,
rm=None,
branch_idx=0,
reranker=None,
compiled_lm=None,
force_reuse_cached_compilation=False,
compiling=False,
skip_logprobs=False,
trace=[],
release=0,
bypass_assert=False,
bypass_suggest=False,
assert_failures=0,
suggest_failures=0,
langchain_history=[],
experimental=False,
backoff_time=10,
callbacks=[],
Expand Down
27 changes: 5 additions & 22 deletions dspy/predict/predict.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import logging
import random
from functools import lru_cache

from pydantic import BaseModel

Expand All @@ -12,18 +10,12 @@
from dspy.utils.callback import with_callbacks


@lru_cache(maxsize=None)
def warn_once(msg: str):
logging.warning(msg)


class Predict(Module, Parameter):
def __init__(self, signature, _parse_values=True, callbacks=None, **config):
def __init__(self, signature, callbacks=None, **config):
self.stage = random.randbytes(8).hex()
self.signature = ensure_signature(signature)
self.config = config
self.callbacks = callbacks or []
self._parse_values = _parse_values
self.reset()

def reset(self):
Expand Down Expand Up @@ -80,7 +72,7 @@ def load_state(self, state):
self.signature = self.signature.load_state(state["signature"])

if "extended_signature" in state: # legacy, up to and including 2.5, for CoT.
self.signature = self.signature.load_state(state["extended_signature"])
raise NotImplementedError("Loading extended_signature is no longer supported in DSPy 2.6+")

return self

Expand All @@ -90,7 +82,6 @@ def __call__(self, **kwargs):

def forward(self, **kwargs):
import dspy
assert not dspy.settings.compiling, "It's no longer ever the case that .compiling is True"

# Extract the three privileged keyword arguments.
assert "new_signature" not in kwargs, "new_signature is no longer a valid keyword argument."
Expand All @@ -115,7 +106,9 @@ def forward(self, **kwargs):
missing = [k for k in signature.input_fields if k not in kwargs]
print(f"WARNING: Not all input fields were provided to module. Present: {present}. Missing: {missing}.")

completions = v2_5_generate(lm, config, signature, demos, kwargs, _parse_values=self._parse_values)
import dspy
adapter = dspy.settings.adapter or dspy.ChatAdapter()
completions = adapter(lm, lm_kwargs=config, signature=signature, demos=demos, inputs=kwargs)

pred = Prediction.from_completions(completions, signature=signature)

Expand All @@ -135,16 +128,6 @@ def __repr__(self):
return f"{self.__class__.__name__}({self.signature})"


def v2_5_generate(lm, lm_kwargs, signature, demos, inputs, _parse_values=True):
import dspy

adapter = dspy.settings.adapter or dspy.ChatAdapter()

return adapter(
lm, lm_kwargs=lm_kwargs, signature=signature, demos=demos, inputs=inputs, _parse_values=_parse_values
)


# TODO: get some defaults during init from the context window?
# # TODO: FIXME: Hmm, I guess expected behavior is that contexts can
# affect execution. Well, we need to determine whether context dominates, __init__ dominates, or forward dominates.
Expand Down
6 changes: 1 addition & 5 deletions dspy/teleprompt/bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,10 +260,6 @@ def _train(self):
sample_size = max(0, sample_size)

raw_demos = rng.sample(raw_demos, sample_size)

if dspy.settings.release >= 20230928:
predictor.demos = raw_demos + augmented_demos
else:
predictor.demos = augmented_demos + raw_demos
predictor.demos = augmented_demos + raw_demos

return self.student

0 comments on commit 741ac10

Please sign in to comment.