Add helpdesk call summarization scenario (#3303)
yifanmai authored Jan 31, 2025
1 parent 59dcfb1 commit f6a9856
Showing 4 changed files with 219 additions and 4 deletions.
@@ -0,0 +1,96 @@
import re
from typing import Any

from helm.benchmark.adaptation.request_state import RequestState
from helm.benchmark.annotation.annotator import Annotator
from helm.clients.auto_client import AutoClient
from helm.common.request import Request
from helm.proxy.retry import NonRetriableException


class AnnotatorResponseParseFailure(NonRetriableException):
def __init__(self, response_text: str, **kwargs):
self.response_text = response_text
super().__init__(kwargs)


class HelpdeskCallSummarizationAnnotator(Annotator):
"""The Helpdesk Call Summarization autograder."""

name = "helpdesk_call_center_summarization"

PROMPT_TEMPLATE = """You are an expert evaluator. Your task is to evaluate the quality of a model-generated summary of a helpdesk call transcript.
The helpdesk call transcript and summary are provided below, delineated with start and end tags:
<call_transcript>
{{QUESTION}}
</call_transcript>
<summary>
{{PRED}}
</summary>
Evaluate the summary based on the following criteria:
- Conciseness: A high-quality summary should effectively convey the most important information from the original source while keeping the length brief.
- Relevance: The information presented in the summary should be relevant to the main topic.
- Coherence: A good summary should have a clear structure and flow of ideas that make it easy to understand and follow.
- Accuracy: The summary's information should be factually correct and should not contain false or misleading information.
Think step by step, then score the summary. Your reasoning should be less than 100 words. The score should be a single number between 1 and 10 inclusive.
Please respond with your output and reasoning in the following format, your reasoning within <reasoning></reasoning> tags and your score within <score></score> tags, without any other output:
<reasoning>INSERT_YOUR_REASONING_HERE</reasoning>
<score>INSERT_YOUR_SCORE_HERE</score>
""" # noqa: E501

PATTERN = r"^\s*reason:(.*)##(.*)"

def __init__(self, auto_client: AutoClient):
self._auto_client = auto_client

def annotate(self, request_state: RequestState) -> Any:
assert request_state.result
assert len(request_state.result.completions) == 1
prediction_text = request_state.result.completions[0].text

question_text = request_state.instance.input.text

annotator_prompt = self.PROMPT_TEMPLATE.replace("{{QUESTION}}", question_text).replace(
"{{PRED}}", prediction_text
)
annotator_request = Request(
model="openai/gpt-4o-2024-05-13",
model_deployment="openai/gpt-4o-2024-05-13",
prompt=annotator_prompt,
temperature=0.0,
max_tokens=512,
)
annotator_response = self._auto_client.make_request(annotator_request)
if not annotator_response.success:
raise Exception(f"Annotation request failed: {annotator_response.error}")
assert len(annotator_response.completions) == 1
annotator_response_text = annotator_response.completions[0].text
        # Fuzzy regex match: allow different casing or a missing / in the closing tag
reasoning_match = re.search(
r"<\s*reasoning\s*>(.*?)<\/?\s*reasoning\s*>", annotator_response_text, re.DOTALL | re.IGNORECASE
)
score_match = re.search(
r"<\s*score\s*>(.*?)<\/?\s*score\s*>", annotator_response_text, re.DOTALL | re.IGNORECASE
)
if not reasoning_match or not score_match:
raise AnnotatorResponseParseFailure(
message=f"Could not parse markup in raw response: '{annotator_response_text}'",
response_text=annotator_response_text,
)
reasoning = reasoning_match.group(1).strip()
try:
score = float(score_match.group(1).strip())
except ValueError:
raise AnnotatorResponseParseFailure(
message=f"Could not parse score as float from raw request: '{annotator_response_text}'",
response_text=annotator_response_text,
)

return {"reasoning": reasoning, "score": score}
55 changes: 55 additions & 0 deletions src/helm/benchmark/run_specs/call_center_run_specs.py
@@ -150,3 +150,58 @@ def get_call_center_summarization_key_points_recall_spec() -> RunSpec:
annotators=annotator_specs,
groups=["call_center_summarization_key_points_recall"],
)


@run_spec_function("helpdesk_call_summarization")
def get_helpdesk_call_summarization_spec() -> RunSpec:
scenario_spec = ScenarioSpec(
class_name="helm.benchmark.scenarios.helpdesk_call_summarization_scenario.HelpdeskCallSummarizationScenario",
)
    annotator_specs = [
AnnotatorSpec(
class_name="helm.benchmark.annotation.helpdesk_call_summarization_annotator.HelpdeskCallSummarizationAnnotator" # noqa: E501
)
]

instructions = "The following is a call transcript of a call between a compnay's employee and the company's IT helpdesk. Summarize the call transcript in under 100 words." # noqa: E501

adapter_spec = AdapterSpec(
method=ADAPT_GENERATION,
instructions=instructions,
input_prefix="### Call Transcript\n",
input_suffix="",
output_prefix="",
output_suffix="",
max_train_instances=0,
temperature=0.0,
max_tokens=512,
num_outputs=1,
)

# annotator_specs = annotator_specs = [
# AnnotatorSpec(class_name="helm.benchmark.annotation.call_center_annotator.CallCenterSummarizationAnnotator")
# ]
annotation_metric_specs = [
MetricSpec(
class_name="helm.benchmark.metrics.annotation_metrics.AnnotationLikertScaleMetric",
args={
"annotator_name": "helpdesk_call_center_summarization",
"key": "score",
"min_score": 1,
"max_score": 10,
},
)
]

metric_specs = get_basic_metric_specs([]) + annotation_metric_specs

group = "helpdesk_call_summarization"

return RunSpec(
name="helpdesk_call_summarization",
scenario_spec=scenario_spec,
adapter_spec=adapter_spec,
metric_specs=metric_specs,
annotators=annotator_specs,
groups=[group],
)
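
Taken together, this run spec registers under the name "helpdesk_call_summarization" and wires a zero-shot generation adapter (no in-context examples, temperature 0, 512 max output tokens) to the GPT-4o annotator defined above, whose "score" key is read by a 1-10 Likert-scale metric. As a rough sanity check, the factory can be exercised directly; this is illustrative only and assumes a local HELM development install.

from helm.benchmark.run_specs.call_center_run_specs import get_helpdesk_call_summarization_spec

# Build the RunSpec and inspect how the pieces are wired together.
spec = get_helpdesk_call_summarization_spec()
print(spec.name)                     # helpdesk_call_summarization
print(spec.adapter_spec.max_tokens)  # 512
print([annotator.class_name for annotator in spec.annotators])
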
@@ -0,0 +1,37 @@
import csv
import os
from typing import List

from helm.benchmark.scenarios.scenario import (
Scenario,
Instance,
TEST_SPLIT,
Input,
)


_DATA_DIRECTORY_PATH = "restricted/helpdesk_call_summarization/HELM Sample Transcripts_20241221_0045"


class HelpdeskCallSummarizationScenario(Scenario):
"""Helpdesk call summarization."""

name = "helpdesk_call_summarization"
description = "Helpdesk call summarization."
tags = ["helpdesk_call_center"]

def get_instances(self, output_path: str) -> List[Instance]:
instances: List[Instance] = []
        for file_name in os.listdir(_DATA_DIRECTORY_PATH):
if not file_name.endswith(".csv") or not file_name.startswith("Call1-"):
continue
            file_path = os.path.join(_DATA_DIRECTORY_PATH, file_name)
with open(file_path) as f:
csv_reader = csv.reader(f)
prompt_lines = [f"{row[0]}: {row[4]}" for row in csv_reader]
prompt = "\n".join(prompt_lines)
instance_id = file_name.removeprefix("Call1-").removesuffix(".csv")
input = Input(text=prompt)
instance = Instance(id=instance_id, input=input, references=[], split=TEST_SPLIT)
instances.append(instance)
return instances
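
The loader above assumes a specific transcript layout: each CSV row holds the speaker label in column 0 and the utterance text in column 4, and only files matching "Call1-*.csv" under the restricted data directory (which is not distributed with HELM) are read. A made-up row, just to illustrate how one prompt line is formatted; the column values are hypothetical:

# Hypothetical transcript row; the restricted source data is not included in the repository.
row = ["Agent", "00:00:12", "00:00:18", "en", "Thank you for calling the IT helpdesk. How can I help you today?"]
print(f"{row[0]}: {row[4]}")  # Agent: Thank you for calling the IT helpdesk. How can I help you today?
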
35 changes: 31 additions & 4 deletions src/helm/benchmark/static/schema_call_center.yaml
@@ -99,6 +99,13 @@ metrics:
short_display_name: Recall
description: How many key items were recalled
lower_is_better: false

- name: annotation_helpdesk_call_center_summarization_score
display_name: Score
short_display_name: Score
    description: Summary quality score (1-10) assigned by the LLM judge
lower_is_better: false

############################################################
perturbations: []

@@ -108,6 +115,8 @@ metric_groups:
display_name: Summarization
hide_win_rates: true
metrics:
- name: annotation_helpdesk_call_center_summarization_score
split: ${main_split}
- name: annotation_call_center_summarization_faithfulness
split: ${main_split}
- name: annotation_call_center_summarization_relevance
@@ -158,10 +167,11 @@ run_groups:
    description: Scenarios representing realistic tasks from the call center.
category: All scenarios
subgroups:
- call_center_summarization
- call_center_summarization_real_call_transcripts
- call_center_summarization_pairwise_comparison
- call_center_summarization_key_points_recall
- helpdesk_call_summarization
# - call_center_summarization
# - call_center_summarization_real_call_transcripts
# - call_center_summarization_pairwise_comparison
# - call_center_summarization_key_points_recall

- name: call_center_summarization
display_name: Summarization
@@ -180,6 +190,23 @@
when: "?"
language: English

- name: helpdesk_call_summarization
    display_name: Helpdesk call summarization
    description: Summarization of IT helpdesk call transcripts
metric_groups:
# - accuracy
- summarization_metrics
- efficiency
- general_information
environment:
main_split: test
taxonomy:
task: summarization
what: n/a
who: n/a
when: "?"
language: English

- name: call_center_summarization_real_call_transcripts
display_name: Summarization (Real)
description: Summarization with real call transcripts
