Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Spider 1.0 scenario #3300

Merged
merged 3 commits into from
Feb 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion src/helm/benchmark/run_specs/sql_run_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@


@run_spec_function("bird_sql")
def get_bird_sql_dev() -> RunSpec:
def get_bird_sql_dev_run_spec() -> RunSpec:
scenario_spec = ScenarioSpec(class_name="helm.benchmark.scenarios.bird_sql_scenario.BIRDSQLScenario")

adapter_spec = get_generation_adapter_spec(
Expand All @@ -22,3 +22,23 @@ def get_bird_sql_dev() -> RunSpec:
metric_specs=get_exact_match_metric_specs(),
groups=["bird_sql"],
)


@run_spec_function("spider")
def get_spider_run_spec() -> RunSpec:
scenario_spec = ScenarioSpec(class_name="helm.benchmark.scenarios.spider_scenario.SpiderScenario")

adapter_spec = get_generation_adapter_spec(
input_noun=None,
output_noun=None,
max_tokens=1024,
stop_sequences=[],
)

return RunSpec(
name="spider",
scenario_spec=scenario_spec,
adapter_spec=adapter_spec,
metric_specs=get_exact_match_metric_specs(),
groups=["spider"],
)
81 changes: 81 additions & 0 deletions src/helm/benchmark/scenarios/spider_scenario.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import json
import os
from typing import Dict, List

from filelock import FileLock

from helm.common.general import ensure_directory_exists, ensure_file_downloaded, shell
from helm.common.hierarchical_logger import hlog
from helm.benchmark.scenarios.bird_sql_scenario_helper import ( # type: ignore
cot_wizard,
generate_comment_prompt,
generate_schema_prompt,
)
from helm.benchmark.scenarios.scenario import (
CORRECT_TAG,
Scenario,
Instance,
Reference,
VALID_SPLIT,
Input,
Output,
)


def _ensure_file_unzipped(source_path: str, target_path: str):
with FileLock(f"{target_path}.lock"):
if os.path.exists(target_path):
hlog(f"Not decompressing {source_path} because {target_path} already exists")
return
tmp_path = target_path + ".tmp"
ensure_directory_exists(tmp_path)
shell(["unzip", source_path, "-d", tmp_path])
shell(["mv", tmp_path, target_path])


class SpiderScenario(Scenario):
"""Spider 1.0"""

name = "spider"
description = "spider"
tags = ["sql"]

def get_instances(self, output_path: str) -> List[Instance]:
data_parent_path = os.path.join(output_path, "data")
ensure_file_downloaded(
"https://drive.google.com/uc?id=1403EGqzIDoHMdQF4c9Bkyl7dZLZ5Wt6J&export=download&confirm=t",
data_parent_path,
unpack=True,
unpack_type="unzip",
)
data_root_path = os.path.join(data_parent_path, "spider_data")
databases_root_path = os.path.join(data_root_path, "test_database")

database_schema_prompts: Dict[str, str] = {}
for database_name in os.listdir(databases_root_path):
database_path = os.path.join(databases_root_path, database_name, f"{database_name}.sqlite")
if not os.path.exists(database_path):
# Ignore stray ".DS_Store" directory
continue

database_schema_prompt = generate_schema_prompt(database_path, num_rows=None)
database_schema_prompts[database_name] = database_schema_prompt

instances: List[Instance] = []
dataset_path = os.path.join(data_root_path, "test.json")
dataset = json.load(open(dataset_path, "r"))
for row in dataset:
database_id: str = row["db_id"]
question: str = row["question"]
gold_sql: str = row["query"]

schema_prompt = database_schema_prompts[database_id]
comment_prompt = generate_comment_prompt(question, None)
combined_prompt = schema_prompt + "\n\n" + comment_prompt + cot_wizard() + "\nSELECT "
instance = Instance(
input=Input(text=combined_prompt),
references=[Reference(output=Output(text=gold_sql), tags=[CORRECT_TAG])],
split=VALID_SPLIT,
)
instances.append(instance)
return instances
25 changes: 13 additions & 12 deletions src/helm/benchmark/static/schema_sql.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -123,28 +123,29 @@ metric_groups:

############################################################
run_groups:
- name: financial_scenarios
display_name: Financial Scenarios
description: Scenarios for the financial domain
- name: text_to_sql_scenarios
display_name: Text-to-SQL Scenarios
description: Text-to-SQL Scenarios
category: All scenarios
subgroups:
- czech_bank_qa
- spider
- bird_sql

- name: czech_bank_qa
display_name: CzechBankQA
description: The CzechBankQA
- name: spider
display_name: Spider 1.0 (Test)
description: Spider 1.0 (Test)
metric_groups:
- accuracy
- efficiency
- general_information
environment:
main_name: error_rate
main_split: test
main_name: quasi_exact_match
main_split: valid
taxonomy:
task: text-to-SQL
what: queries from financial experts
who: financial experts
when: "1999"
what: databases from various domains
who: expert data scientists
when: "?"
language: English

- name: bird_sql
Expand Down