From 905c81ee53045376e0837f042189f7a825167752 Mon Sep 17 00:00:00 2001 From: Roni Friedman-Melamed Date: Thu, 7 Nov 2024 10:37:57 +0200 Subject: [PATCH] simplify run_mmlu return value Signed-off-by: Roni Friedman-Melamed --- src/instructlab/eval/mmlu.py | 10 +++------- src/instructlab/eval/unitxt.py | 2 +- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/src/instructlab/eval/mmlu.py b/src/instructlab/eval/mmlu.py index acde8de5..66053568 100644 --- a/src/instructlab/eval/mmlu.py +++ b/src/instructlab/eval/mmlu.py @@ -142,7 +142,7 @@ def run(self, server_url: str | None = None) -> tuple: agg_score: float = 0.0 results = self._run_mmlu(server_url) - for task, result in results.items(): + for task, result in results['results'].items(): agg_score += float(result["acc,none"]) individual_scores[task] = { "score": float(result["acc,none"]), @@ -154,7 +154,7 @@ def run(self, server_url: str | None = None) -> tuple: return overall_score, individual_scores def _run_mmlu( - self, server_url: str | None = None, return_all_results: bool = False + self, server_url: str | None = None ) -> dict: if server_url is not None: # Requires lm_eval >= 0.4.4 @@ -179,11 +179,7 @@ def _run_mmlu( device=self.device, task_manager=tm, ) - if return_all_results: - results = mmlu_output - else: - results = mmlu_output["results"] - return results + return mmlu_output # This method converts general errors from simple_evaluate # into a more user-understandable error diff --git a/src/instructlab/eval/unitxt.py b/src/instructlab/eval/unitxt.py index 55426667..dade3021 100644 --- a/src/instructlab/eval/unitxt.py +++ b/src/instructlab/eval/unitxt.py @@ -90,7 +90,7 @@ def run(self, server_url: str | None = None) -> tuple: self.prepare_unitxt_files() logger.debug(locals()) os.environ["TOKENIZERS_PARALLELISM"] = "true" - results = self._run_mmlu(server_url=server_url, return_all_results=True) + results = self._run_mmlu(server_url=server_url) taskname = self.tasks[0] global_scores = results["results"][taskname] global_scores.pop("alias")