Skip to content

Commit

Permalink
minor style fixes (stanfordnlp#1921)
Browse files Browse the repository at this point in the history
  • Loading branch information
chenmoneygithub authored Dec 10, 2024
1 parent e690743 commit 474d496
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 31 deletions.
63 changes: 36 additions & 27 deletions dspy/datasets/hotpotqa.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,62 +6,71 @@


class HotPotQA(Dataset):
    """Loader for the HotPotQA multi-hop question-answering dataset.

    Pulls the official ``hotpot_qa`` (fullwiki) splits from the Hugging Face
    hub, keeps only 'hard' training examples, and exposes them via the
    inherited ``Dataset`` interface as ``self._train`` / ``self._dev`` /
    ``self._test``.
    """

    def __init__(
        self,
        *args,
        only_hard_examples=True,
        keep_details="dev_titles",
        unofficial_dev=True,
        **kwargs,
    ) -> None:
        """Build the train/dev/test splits.

        Args:
            only_hard_examples: Must be True. The official dev set contains
                only hard examples, so easy ones are excluded to match it.
            keep_details: Controls which raw fields are retained per training
                example — True keeps everything, "dev_titles" keeps question/
                answer plus supporting-fact titles, anything else keeps only
                question/answer.
            unofficial_dev: When True, holds out the last 25% of the shuffled
                hard training examples as an unofficial dev split.
            *args, **kwargs: Forwarded to the ``Dataset`` base class.
        """
        super().__init__(*args, **kwargs)
        assert only_hard_examples, (
            "Care must be taken when adding support for easy examples. "
            "Dev must be all hard to match official dev, but training can be flexible."
        )

        # NOTE(review): downloads from the HF hub; trust_remote_code runs the
        # dataset's loading script.
        hf_official_train = load_dataset("hotpot_qa", "fullwiki", split="train", trust_remote_code=True)
        hf_official_dev = load_dataset("hotpot_qa", "fullwiki", split="validation", trust_remote_code=True)

        official_train = []
        for raw_example in hf_official_train:
            if raw_example["level"] == "hard":
                if keep_details is True:
                    keys = ["id", "question", "answer", "type", "supporting_facts", "context"]
                elif keep_details == "dev_titles":
                    keys = ["question", "answer", "supporting_facts"]
                else:
                    keys = ["question", "answer"]

                example = {k: raw_example[k] for k in keys}

                # Collapse supporting facts to just the set of gold page titles.
                if "supporting_facts" in example:
                    example["gold_titles"] = set(example["supporting_facts"]["title"])
                    del example["supporting_facts"]

                official_train.append(example)

        # Deterministic shuffle so the 75/25 train/dev split is reproducible.
        rng = random.Random(0)
        rng.shuffle(official_train)

        self._train = official_train[: len(official_train) * 75 // 100]

        if unofficial_dev:
            self._dev = official_train[len(official_train) * 75 // 100 :]
        else:
            self._dev = None

        # With keep_details="dev_titles", gold titles are kept only for the
        # (unofficial) dev split; strip them from training examples.
        for example in self._train:
            if keep_details == "dev_titles":
                del example["gold_titles"]

        # The official validation split becomes the test set; it is all hard.
        test = []
        for raw_example in hf_official_dev:
            assert raw_example["level"] == "hard"
            example = {k: raw_example[k] for k in ["id", "question", "answer", "type", "supporting_facts"]}
            if "supporting_facts" in example:
                example["gold_titles"] = set(example["supporting_facts"]["title"])
                del example["supporting_facts"]
            test.append(example)

        self._test = test


if __name__ == "__main__":
    from dsp.utils import dotdict

    # Quick smoke test: build small splits and print the dataset summary.
    data_args = dotdict(train_seed=1, train_size=16, eval_seed=2023, dev_size=200 * 5, test_size=0)
    dataset = HotPotQA(**data_args)

    print(dataset)
Expand Down
9 changes: 5 additions & 4 deletions dspy/datasets/math.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import re
import random
import re


class MATH:
def __init__(self, subset):
import dspy
from datasets import load_dataset

import dspy

ds = load_dataset("lighteval/MATH", subset)

# NOTE: Defaults to sub-splitting MATH's 'test' split into train/dev/test, presuming that current
Expand All @@ -21,7 +22,7 @@ def __init__(self, subset):

size = min(350, len(dataset) // 3)
random.Random(0).shuffle(dataset)
self.train, self.dev, self.test = dataset[:size], dataset[size:2*size], dataset[2*size:]
self.train, self.dev, self.test = dataset[:size], dataset[size : 2 * size], dataset[2 * size :]

def metric(self, example, pred, trace=None):
try:
Expand All @@ -36,7 +37,7 @@ def extract_answer(s):
start = s.find("\\boxed{")
if start == -1:
return None

idx = start + len("\\boxed{")
brace_level = 1

Expand Down

0 comments on commit 474d496

Please sign in to comment.