diff --git a/src/helm/benchmark/annotation/omni_math_annotator.py b/src/helm/benchmark/annotation/omni_math_annotator.py
index 149402ac393..77c81aaeeee 100644
--- a/src/helm/benchmark/annotation/omni_math_annotator.py
+++ b/src/helm/benchmark/annotation/omni_math_annotator.py
@@ -61,9 +61,12 @@ def annotate(self, request_state: RequestState) -> Any:
         info = parse_report(annotator_response_text)
 
-        correctness = info.get("Equivalence Judgement", "FALSE")
+        equivalence_judgement = info.get("Equivalence Judgement", "")
+        student_final_answer = info.get("Student Final Answer", "")
+        justification = info.get("Justification", "").strip().removesuffix("=== report over ===").strip()
 
-        if correctness == "TRUE":
-            return {"prompt_text": annotator_prompt, "correctness": 1.0}
-        else:
-            return {"prompt_text": annotator_prompt, "correctness": 0.0}
+        return {
+            "student_final_answer": student_final_answer,
+            "equivalence_judgement": equivalence_judgement,
+            "justification": justification,
+        }
 
diff --git a/src/helm/benchmark/metrics/omni_math_metrics.py b/src/helm/benchmark/metrics/omni_math_metrics.py
index c63c9f4020b..48ff67ac33c 100644
--- a/src/helm/benchmark/metrics/omni_math_metrics.py
+++ b/src/helm/benchmark/metrics/omni_math_metrics.py
@@ -19,7 +19,7 @@ def evaluate_generation(
         eval_cache_path: str,
     ) -> List[Stat]:
         assert request_state.annotations
-        score = request_state.annotations["omni_math"]["correctness"]
+        score = request_state.annotations["omni_math"]["equivalence_judgement"].strip().upper() == "TRUE"
         return [
             Stat(MetricName("omni_math_accuracy")).add(score),
         ]
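
A minimal standalone sketch (not part of the diff) of the reshaped data flow: the annotator now passes through the raw judge output as strings, and the metric derives the boolean accuracy itself. The parsed_report dict below is a hypothetical stand-in for the real output of parse_report.

# Sketch only; parsed_report is a made-up example of a parsed judge report.
parsed_report = {
    "Student Final Answer": "42",
    "Equivalence Judgement": "TRUE",
    "Justification": "The answers match.\n=== report over ===",
}

# Annotator side: keep the raw strings instead of collapsing to 0.0/1.0.
annotation = {
    "student_final_answer": parsed_report.get("Student Final Answer", ""),
    "equivalence_judgement": parsed_report.get("Equivalence Judgement", ""),
    "justification": parsed_report.get("Justification", "")
    .strip()
    .removesuffix("=== report over ===")  # requires Python 3.9+
    .strip(),
}

# Metric side: derive the accuracy score from the judgement string.
score = annotation["equivalence_judgement"].strip().upper() == "TRUE"
assert score is True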