diff --git a/pyproject.toml b/pyproject.toml index 126d66244..df4ff39da 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -109,7 +109,7 @@ multilingual = [ "jieba", # for chinese tokenizer "pyvi", # for vietnamese tokenizer ] -math = ["latex2sympy2_extended>=0.9.1"] +math = ["latex2sympy2_extended>=0.9.3"] [project.urls] Homepage = "https://github.com/huggingface/lighteval" diff --git a/src/lighteval/logging/evaluation_tracker.py b/src/lighteval/logging/evaluation_tracker.py index af5784bc2..2364b470e 100644 --- a/src/lighteval/logging/evaluation_tracker.py +++ b/src/lighteval/logging/evaluation_tracker.py @@ -34,7 +34,6 @@ import torch from datasets import Dataset, load_dataset from datasets.utils.metadata import MetadataConfigs -from fsspec import url_to_fs from huggingface_hub import DatasetCard, DatasetCardData, HfApi, HFSummaryWriter, hf_hub_url from lighteval.logging.info_loggers import ( @@ -53,6 +52,11 @@ if is_nanotron_available(): from nanotron.config import GeneralArgs # type: ignore +try: + from fsspec import url_to_fs +except ImportError: + from fsspec.core import url_to_fs + class EnhancedJSONEncoder(json.JSONEncoder): """ @@ -231,9 +235,45 @@ def save_results(self, date_id: str, results_dict: dict): with self.fs.open(output_results_file, "w") as f: f.write(json.dumps(results_dict, cls=EnhancedJSONEncoder, indent=2, ensure_ascii=False)) - def save_details(self, date_id: str, details_datasets: dict[str, Dataset]): + def _get_details_sub_folder(self, date_id: str): output_dir_details = Path(self.output_dir) / "details" / self.general_config_logger.model_name - output_dir_details_sub_folder = output_dir_details / date_id + if date_id in ["first", "last"]: + # Get all folders in output_dir_details + if not self.fs.exists(output_dir_details): + raise FileNotFoundError(f"Details directory {output_dir_details} does not exist") + + # List all folders and filter out files + folders = [f["name"] for f in self.fs.listdir(output_dir_details) if f["type"] == "directory"] + + if not folders: + raise FileNotFoundError(f"No timestamp folders found in {output_dir_details}") + + # Parse timestamps and get first or last + date_id = max(folders) if date_id == "last" else min(folders) + return output_dir_details / date_id + + def load_details_datasets(self, date_id: str, task_names: list[str]) -> dict[str, Dataset]: + output_dir_details_sub_folder = self._get_details_sub_folder(date_id) + logger.info(f"Loading details from {output_dir_details_sub_folder}") + date_id = output_dir_details_sub_folder.name # Overwrite date_id in case of latest + details_datasets = {} + for file in self.fs.glob(str(output_dir_details_sub_folder / f"details_*_{date_id}.parquet")): + task_name = Path(file).stem.replace("details_", "").replace(f"_{date_id}", "") + if "|".join(task_name.split("|")[:-1]) not in task_names: + logger.info(f"Skipping {task_name} because it is not in the task_names list") + continue + dataset = load_dataset("parquet", data_files=file, split="train") + details_datasets[task_name] = dataset + + for task_name in task_names: + if not any(task_name.startswith(task_name) for task_name in details_datasets.keys()): + raise ValueError( + f"Task {task_name} not found in details datasets. Check the tasks to be evaluated or the date_id used to load the details ({date_id})." + ) + return details_datasets + + def save_details(self, date_id: str, details_datasets: dict[str, Dataset]): + output_dir_details_sub_folder = self._get_details_sub_folder(date_id) self.fs.mkdirs(output_dir_details_sub_folder, exist_ok=True) logger.info(f"Saving details to {output_dir_details_sub_folder}") for task_name, dataset in details_datasets.items(): diff --git a/src/lighteval/main_accelerate.py b/src/lighteval/main_accelerate.py index fe7f98d6f..d8d69f30f 100644 --- a/src/lighteval/main_accelerate.py +++ b/src/lighteval/main_accelerate.py @@ -67,6 +67,9 @@ def accelerate( # noqa C901 num_fewshot_seeds: Annotated[ int, Option(help="Number of seeds to use for few-shot evaluation.", rich_help_panel=HELP_PANEL_NAME_1) ] = 1, + load_responses_from_details_date_id: Annotated[ + Optional[str], Option(help="Load responses from details directory.", rich_help_panel=HELP_PANEL_NAME_1) + ] = None, # === saving === output_dir: Annotated[ str, Option(help="Output directory for evaluation results.", rich_help_panel=HELP_PANEL_NAME_2) @@ -137,6 +140,7 @@ def accelerate( # noqa C901 max_samples=max_samples, use_chat_template=use_chat_template, system_prompt=system_prompt, + load_responses_from_details_date_id=load_responses_from_details_date_id, ) # TODO (nathan): better handling of model_args diff --git a/src/lighteval/main_endpoint.py b/src/lighteval/main_endpoint.py index 2c51fe15f..858cdcde3 100644 --- a/src/lighteval/main_endpoint.py +++ b/src/lighteval/main_endpoint.py @@ -179,6 +179,9 @@ def inference_endpoint( num_fewshot_seeds: Annotated[ int, Option(help="Number of seeds to use for few-shot evaluation.", rich_help_panel=HELP_PANEL_NAME_1) ] = 1, + load_responses_from_details_date_id: Annotated[ + Optional[str], Option(help="Load responses from details directory.", rich_help_panel=HELP_PANEL_NAME_1) + ] = None, # === saving === output_dir: Annotated[ str, Option(help="Output directory for evaluation results.", rich_help_panel=HELP_PANEL_NAME_2) @@ -247,6 +250,7 @@ def inference_endpoint( max_samples=max_samples, use_chat_template=use_chat_template, system_prompt=system_prompt, + load_responses_from_details_date_id=load_responses_from_details_date_id, ) pipeline = Pipeline( tasks=tasks, @@ -292,6 +296,9 @@ def tgi( num_fewshot_seeds: Annotated[ int, Option(help="Number of seeds to use for few-shot evaluation.", rich_help_panel=HELP_PANEL_NAME_1) ] = 1, + load_responses_from_details_date_id: Annotated[ + Optional[str], Option(help="Load responses from details directory.", rich_help_panel=HELP_PANEL_NAME_1) + ] = None, # === saving === output_dir: Annotated[ str, Option(help="Output directory for evaluation results.", rich_help_panel=HELP_PANEL_NAME_2) @@ -355,6 +362,7 @@ def tgi( max_samples=max_samples, use_chat_template=use_chat_template, system_prompt=system_prompt, + load_responses_from_details_date_id=load_responses_from_details_date_id, ) pipeline = Pipeline( tasks=tasks, @@ -400,6 +408,9 @@ def litellm( num_fewshot_seeds: Annotated[ int, Option(help="Number of seeds to use for few-shot evaluation.", rich_help_panel=HELP_PANEL_NAME_1) ] = 1, + load_responses_from_details_date_id: Annotated[ + Optional[str], Option(help="Load responses from details directory.", rich_help_panel=HELP_PANEL_NAME_1) + ] = None, # === saving === output_dir: Annotated[ str, Option(help="Output directory for evaluation results.", rich_help_panel=HELP_PANEL_NAME_2) @@ -464,6 +475,7 @@ def litellm( max_samples=max_samples, use_chat_template=use_chat_template, system_prompt=system_prompt, + load_responses_from_details_date_id=load_responses_from_details_date_id, ) pipeline = Pipeline( tasks=tasks, diff --git a/src/lighteval/main_vllm.py b/src/lighteval/main_vllm.py index 89311b5ae..d063c3fa8 100644 --- a/src/lighteval/main_vllm.py +++ b/src/lighteval/main_vllm.py @@ -63,6 +63,9 @@ def vllm( num_fewshot_seeds: Annotated[ int, Option(help="Number of seeds to use for few-shot evaluation.", rich_help_panel=HELP_PANEL_NAME_1) ] = 1, + load_responses_from_details_date_id: Annotated[ + Optional[str], Option(help="Load responses from details directory.", rich_help_panel=HELP_PANEL_NAME_1) + ] = None, # === saving === output_dir: Annotated[ str, Option(help="Output directory for evaluation results.", rich_help_panel=HELP_PANEL_NAME_2) @@ -124,6 +127,7 @@ def vllm( max_samples=max_samples, use_chat_template=use_chat_template, system_prompt=system_prompt, + load_responses_from_details_date_id=load_responses_from_details_date_id, ) if model_args.endswith(".yaml"): diff --git a/src/lighteval/metrics/dynamic_metrics.py b/src/lighteval/metrics/dynamic_metrics.py index 577934e9d..4462ffac1 100644 --- a/src/lighteval/metrics/dynamic_metrics.py +++ b/src/lighteval/metrics/dynamic_metrics.py @@ -191,6 +191,7 @@ def multilingual_extractive_match_metric( pred_extraction_target: Sequence[ExtractionTarget] = (ExprExtractionConfig(),), aggregation_function: Callable[[list[float]], float] = max, fallback_mode: Literal["no_fallback", "first_match"] = "first_match", + extraction_mode: Literal["first_match", "any_match"] = "any_match", precision: int = 6, ) -> SampleLevelMetric: """Creates a language-aware extractive match metric that extracts answers from the model's output. @@ -215,6 +216,10 @@ def multilingual_extractive_match_metric( How to perform extraction. Defaults to "first_match". - "no_fallback": Only use first successfully parsed matches - "first_match": Use the first successfully parsed match + first match irregardless the parsing success + extraction_mode: Literal["first_match", "any_match"] + - "first_match": Only tries to extract the first regex match if it fails no other matches are tried + - "any_match": Tries to extract any regex match + precision: int Number of decimal places to use when comparing numerical values. Defaults to 6. @@ -240,9 +245,12 @@ def sample_level_fn(golds: list[str], predictions: list[str], formatted_doc: Doc pred_extraction_regexes = get_extraction_regexes(formatted_doc, pred_extraction_target, language) extracted_predictions = [ - extract_target_from_pred(pred, pred_extraction_regexes, fallback_mode) for pred in predictions + extract_target_from_pred(pred, pred_extraction_regexes, fallback_mode, extraction_mode) + for pred in predictions + ] + extracted_golds = [ + extract_target_from_pred(gold, gold_extraction_regexes, fallback_mode, extraction_mode) for gold in golds ] - extracted_golds = [extract_target_from_pred(gold, gold_extraction_regexes, fallback_mode) for gold in golds] # Assert on empty gold and warn on empty pred if any(len(g) == 0 for g in extracted_golds): diff --git a/src/lighteval/metrics/imports/__init__.py b/src/lighteval/metrics/imports/__init__.py new file mode 100644 index 000000000..a732db8d0 --- /dev/null +++ b/src/lighteval/metrics/imports/__init__.py @@ -0,0 +1,21 @@ +# MIT License + +# Copyright (c) 2024 The HuggingFace Team + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. diff --git a/src/lighteval/metrics/utils/__init__.py b/src/lighteval/metrics/utils/__init__.py new file mode 100644 index 000000000..a732db8d0 --- /dev/null +++ b/src/lighteval/metrics/utils/__init__.py @@ -0,0 +1,21 @@ +# MIT License + +# Copyright (c) 2024 The HuggingFace Team + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. diff --git a/src/lighteval/metrics/utils/extractive_match_utils.py b/src/lighteval/metrics/utils/extractive_match_utils.py index 01d2fc102..77540c175 100644 --- a/src/lighteval/metrics/utils/extractive_match_utils.py +++ b/src/lighteval/metrics/utils/extractive_match_utils.py @@ -21,10 +21,10 @@ # SOFTWARE. import re -from dataclasses import dataclass +from dataclasses import dataclass, field from functools import lru_cache from itertools import groupby -from typing import Literal, Sequence +from typing import Any, Literal, Sequence import sympy from sympy import Basic, MatrixBase, Number @@ -39,17 +39,33 @@ from lighteval.utils.timeout import timeout +@requires_latex2sympy2_extended +def latex_normalization_config_default_factory(): + from latex2sympy2_extended.latex2sympy2 import NormalizationConfig + + return NormalizationConfig( + basic_latex=True, + units=True, + malformed_operators=True, + nits=True, + boxed=True, + equations=True, + ) + + @dataclass(frozen=True) class LatexExtractionConfig: """Config for extracting latex from the prediction. Attributes: try_extract_without_anchor (bool): Whether to try extracting latex without requiring specific anchors like "answer:" or "final answer is" - enforce_boxed_match (bool): Whether to also consider extracting from plain \boxed{...} expressions + boxed_match_priority (int): Priority of the boxed match regex (-1 never, 0 first, 55 after final answer: anchor, etc...) + normalization_config (latex2sympy2_extended.latex2sympy2.NormalizationConfig): Normalization config to use for latex extraction """ try_extract_without_anchor: bool = True - enforce_boxed_match: bool = True + boxed_match_priority: int = 55 + normalization_config: Any = field(default_factory=latex_normalization_config_default_factory) @dataclass(frozen=True) @@ -187,9 +203,8 @@ def lazy_latex_regex(latex_config: LatexExtractionConfig, language: Language) -> if latex_config.try_extract_without_anchor: regexes.append((latex_re, 300)) - # This ensures that boxed is matched right after the final answer xxxx - if latex_config.enforce_boxed_match: - regexes.append((latex_boxed, 55)) + if latex_config.boxed_match_priority >= 0: + regexes.append((latex_boxed, latex_config.boxed_match_priority)) return [(re.compile(pattern, re.DOTALL), priority) for pattern, priority in regexes] @@ -387,6 +402,7 @@ def extract_target_from_pred( pred: str, target_res: list[tuple[list[tuple[re.Pattern[str], int]], ExtractionTarget]], fallback_mode: Literal["no_fallback", "first_match"] = "no_fallback", + extraction_mode: Literal["first_match", "any_match"] = "any_match", ): """Extracts targets from a prediction string using regex patterns. Returns first sucesffuly extracted match. @@ -397,6 +413,9 @@ def extract_target_from_pred( fallback_mode (Literal["no_fallback", "first_match"], optional): How to handle extraction failures. Defaults to "no_fallback". - "no_fallback": Return only successfully parsed match - "first_match": Additionaly Include the first string match no matter how parsing finished + extraction_mode (Literal["first_match", "any_match"], optional): How to handle extraction failures. Defaults to "any_match". + - "first_match": Only tries to extract the first match + - "any_match": Tries to extract any match Returns: list: List of extracted predictions, with first fallbac string appended if fallback_mode is "first_match" @@ -410,6 +429,7 @@ def extract_target_from_pred( for target_patterns, target_type in target_res for pattern, priority in target_patterns ] + match_found = False # Group patterns by priority using itertools.groupby for _, patterns_group in groupby(sorted(all_patterns, key=lambda x: x[2]), key=lambda x: x[2]): @@ -426,6 +446,7 @@ def extract_target_from_pred( # Try to extract from each match, starting from rightmost for match, _, _, target_type in matches_with_pos: extracted_match, str_fallback = extract_match(match, target_type) + match_found = True if str_fallback: fallbacks.append(str_fallback) @@ -434,8 +455,11 @@ def extract_target_from_pred( extracted_predictions.append(extracted_match) break + if extraction_mode == "first_match": + break + # If we found something and we're in first_match mode, stop processing other priorities - if extracted_predictions: + if extracted_predictions or (match_found and extraction_mode == "first_match"): break if fallback_mode == "first_match" and fallbacks: diff --git a/src/lighteval/pipeline.py b/src/lighteval/pipeline.py index 6a40d2801..0e6282ef5 100644 --- a/src/lighteval/pipeline.py +++ b/src/lighteval/pipeline.py @@ -20,9 +20,11 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. +import ast import collections import os import random +import re import shutil from contextlib import nullcontext from dataclasses import dataclass, field @@ -30,14 +32,21 @@ from enum import Enum, auto import numpy as np +from tqdm import tqdm from lighteval.logging.evaluation_tracker import EvaluationTracker from lighteval.metrics.utils.metric_utils import MetricCategory from lighteval.models.model_loader import TransformersModel, load_model -from lighteval.models.model_output import ModelResponse +from lighteval.models.model_output import ( + GenerativeMultiturnResponse, + GenerativeResponse, + LoglikelihoodResponse, + LoglikelihoodSingleTokenResponse, + ModelResponse, +) from lighteval.tasks.lighteval_task import LightevalTask, create_requests_from_tasks from lighteval.tasks.registry import Registry, taskinfo_selector -from lighteval.tasks.requests import SampleUid +from lighteval.tasks.requests import RequestType, SampleUid from lighteval.utils.imports import ( NO_ACCELERATE_ERROR_MSG, NO_NANOTRON_ERROR_MSG, @@ -95,6 +104,7 @@ class PipelineParameters: max_samples: int | None = None use_chat_template: bool = False system_prompt: str | None = None + load_responses_from_details_date_id: str | None = None def __post_init__(self): # noqa C901 if self.launcher_type == ParallelismManager.ACCELERATE: @@ -245,7 +255,17 @@ def evaluate(self): config=self.model_config, ) - sample_id_to_responses = self._run_model() + if self.pipeline_parameters.load_responses_from_details_date_id: + try: + sample_id_to_responses = self._load_responses_from_details() + except FileNotFoundError as e: + logger.warning( + f"No responses found for {self.pipeline_parameters.load_responses_from_details_date_id} in details directory: {e}. Running model instead." + ) + sample_id_to_responses = self._run_model() + else: + sample_id_to_responses = self._run_model() + self._compute_metrics(sample_id_to_responses) if self.is_main_process(): @@ -261,6 +281,158 @@ def evaluate(self): except OSError: pass + def _unpack(self, x): + if isinstance(x, str): + return x + elif isinstance(x, (list, tuple)): + return self._unpack(x[0]) + else: + raise ValueError(f"Unknown type {type(x)} of prediction {x}") + + def _parse_tensor_string(self, tensor_string): + """ + Convert a string containing PyTorch-like `tensor([...], device='cuda:0', ...)` + into a Python list (or nested lists) of numbers. + + Example: + "[tensor([1, 2, 3], device='cuda:0'), tensor([[4,5],[6,7]], dtype=torch.int64)]" + -> [[1, 2, 3], [[4, 5], [6, 7]]] + """ + + # Regex explanation: + # - tensor\(\s*: Matches "tensor(" (possibly with spaces after), literally. + # - (.*?): Captures everything lazily into group(1), until the first subsequent part matches. + # We rely on the next pattern to anchor the end of this capture. + # - \): The literal closing parenthesis, but we anchor the match by ignoring + # further arguments (device=..., dtype=..., etc.) inside. + # + # The tricky part: a tensor might look like + # tensor([ ... ], device='cuda:0', dtype=torch.int64) + # so the bracket portion is `[ ... ]`, but it can have newlines, etc. + # + # We'll handle that by first capturing the entire content up to the final parenthesis, + # then parse out the bracket portion. This can be done in a function-based re.sub. + + pattern = re.compile( + r"tensor\s*\(\s*(.*?)\s*\)", # capture everything inside tensor(...) + flags=re.DOTALL, + ) + + def tensor_replacer(match): + inside = match.group(1).strip() + # `inside` might look like: [1, 2, 3], device='cuda:0' + # or: + # [ + # 1, 2, 3, + # 4, 5, ... + # ], device='cuda:0', dtype=torch.int64 + # + # 1) Extract the bracketed array portion: the first [ ... ] block + # which might be multi-line. We'll use another regex for that. + + # We look for the bracketed portion from the first '[' to its matching ']'. + # Because the inside can be multi-line, we use DOTALL. But we still need + # to ensure we don't accidentally go beyond the matching bracket. + # + # A robust approach to properly match brackets can be done with a small parser, + # but for typical well-formed strings, a lazy match of the form + # r"\[.*?\]" DOTALL often suffices, assuming no nested brackets inside. + + bracket_pattern = re.compile(r"\[.*?\]", re.DOTALL) + bracket_match = bracket_pattern.search(inside) + if not bracket_match: + # If we fail to find a bracket, just return something safe. + # This means the string didn't match the expected format. + return "[]" + + # The bracketed portion (e.g. "[1, 2, 3\n, 4]"). + bracketed_content = bracket_match.group(0) + + # Return just the bracketed content, + # effectively replacing "tensor(...)" with "[...]". + return bracketed_content + + # Step 1: Replace every `tensor(...)` occurrence with just the bracketed list. + processed = pattern.sub(tensor_replacer, tensor_string) + + # Step 2: Now we can safely parse the result with literal_eval. + # If there's still something weird, it may throw ValueError. + try: + return ast.literal_eval(processed) + except Exception as e: + raise ValueError(f"Failed to parse after preprocessing. " f"Processed string:\n{processed}\n\nError: {e}") + + def _load_responses_from_details(self): + logger.info("--- LOADING RESPONSES FROM DETAILS ---") + sample_id_to_responses: dict[(SampleUid, MetricCategory), list[ModelResponse]] = collections.defaultdict(list) + + request_types = list(self.requests.keys()) + if len(request_types) > 1: + raise ValueError( + "Loading responses from details when there are multiple request types is currently not supported" + ) + model_response_type = self._get_model_response_type(request_types[0]) + + details_datasets = self.evaluation_tracker.load_details_datasets( + self.pipeline_parameters.load_responses_from_details_date_id, self.task_names_list + ) + + for task_name, dataset in tqdm(details_datasets.items(), desc="Loading responses from details for tasks"): + task: LightevalTask = self._get_task(task_name) + num_samples = len(set(dataset["specifics"])) + max_samples = self.pipeline_parameters.max_samples if self.pipeline_parameters.max_samples else num_samples + if num_samples > max_samples: + logger.warning( + f"Skipping {num_samples - max_samples} samples for {task_name} when loading responses from details because max_samples is set to {max_samples}" + ) + num_samples = self.pipeline_parameters.max_samples + + predictions = [self._unpack(ast.literal_eval(p)) for p in dataset["predictions"][:num_samples]] + input_tokens = [self._parse_tensor_string(t) for t in dataset["input_tokens"][:num_samples]] + cont_tokens = [self._parse_tensor_string(t) for t in dataset["cont_tokens"][:num_samples]] + truncated = [ast.literal_eval(t)[0] for t in dataset["truncated"][:num_samples]] + padded = [ast.literal_eval(p)[0] for p in dataset["padded"][:num_samples]] + + if model_response_type == GenerativeResponse: + logits = [ast.literal_eval(p) for p in dataset["pred_logits"][:num_samples]] + + for metric_category, has_metric_category in task.has_metric_category.items(): + if not has_metric_category: + continue + + for idx in range(num_samples): + kwargs = { + "result": predictions[idx], + "input_tokens": input_tokens[idx], + "generated_tokens": cont_tokens[idx], + "truncated_tokens_count": truncated[idx], + "padded_tokens_count": padded[idx], + } + if model_response_type == GenerativeResponse: + kwargs["logits"] = logits[idx] + + response = model_response_type(**kwargs) + sample_id_to_responses[(SampleUid(task_name, f"{idx}_{0}"), metric_category)] = [response] + return sample_id_to_responses + + def _get_model_response_type(self, request_type): + if request_type == RequestType.LOGLIKELIHOOD: + model_response_type = LoglikelihoodResponse + elif request_type == RequestType.LOGLIKELIHOOD_SINGLE_TOKEN: + model_response_type = LoglikelihoodSingleTokenResponse + elif request_type == RequestType.LOGLIKELIHOOD_ROLLING: + model_response_type = LoglikelihoodResponse + elif request_type == RequestType.GREEDY_UNTIL_MULTI_TURN: + model_response_type = GenerativeMultiturnResponse + elif request_type == RequestType.GREEDY_UNTIL: + model_response_type = GenerativeResponse + else: + raise ValueError( + f"Loading responses from details for request type {request_type} is currently not supported" + ) + + return model_response_type + def _run_model(self): # Running all requests depending on the model call type (log likelihood, generative, ...) # to be able to batch them @@ -283,6 +455,10 @@ def _run_model(self): return sample_id_to_responses + def _get_task(self, task_name: str): + short_task_name = task_name.rsplit("|", 1)[0] + return self.task_dict[short_task_name] + def _compute_metrics(self, sample_id_to_responses): # To compute the metrics we first group the samples and task and then by metrics. # This way we can batch the metrics computation for each task and metric category @@ -307,8 +483,7 @@ def _compute_metrics(self, sample_id_to_responses): task_metric_category_groups[sample_id.task_name][metric_category]["docs"].append(self.docs[sample_id]) for task_name, samples_per_metric in task_metric_category_groups.items(): - short_task_name = task_name.rsplit("|", 1)[0] - task: LightevalTask = self.task_dict[short_task_name] + task: LightevalTask = self._get_task(task_name) for metric_category, samples in samples_per_metric.items(): sample_ids = samples["ids"] diff --git a/src/lighteval/tasks/__initi__.py b/src/lighteval/tasks/__initi__.py new file mode 100644 index 000000000..a732db8d0 --- /dev/null +++ b/src/lighteval/tasks/__initi__.py @@ -0,0 +1,21 @@ +# MIT License + +# Copyright (c) 2024 The HuggingFace Team + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. diff --git a/src/lighteval/tasks/default_tasks.py b/src/lighteval/tasks/default_tasks.py index 59254a971..d6a7ec498 100644 --- a/src/lighteval/tasks/default_tasks.py +++ b/src/lighteval/tasks/default_tasks.py @@ -6210,7 +6210,7 @@ evaluation_splits=["validation"], few_shots_split=None, few_shots_select=None, - generation_size=-1, + generation_size=1, metric=[ Metrics.exact_match, Metrics.quasi_exact_match, diff --git a/src/lighteval/tasks/multilingual/__init__.py b/src/lighteval/tasks/multilingual/__init__.py new file mode 100644 index 000000000..a732db8d0 --- /dev/null +++ b/src/lighteval/tasks/multilingual/__init__.py @@ -0,0 +1,21 @@ +# MIT License + +# Copyright (c) 2024 The HuggingFace Team + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. diff --git a/src/lighteval/tasks/multilingual/utils/__init__.py b/src/lighteval/tasks/multilingual/utils/__init__.py new file mode 100644 index 000000000..a732db8d0 --- /dev/null +++ b/src/lighteval/tasks/multilingual/utils/__init__.py @@ -0,0 +1,21 @@ +# MIT License + +# Copyright (c) 2024 The HuggingFace Team + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. diff --git a/src/lighteval/tasks/templates/__init__.py b/src/lighteval/tasks/templates/__init__.py new file mode 100644 index 000000000..a732db8d0 --- /dev/null +++ b/src/lighteval/tasks/templates/__init__.py @@ -0,0 +1,21 @@ +# MIT License + +# Copyright (c) 2024 The HuggingFace Team + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. diff --git a/src/lighteval/tasks/templates/utils/__init__.py b/src/lighteval/tasks/templates/utils/__init__.py new file mode 100644 index 000000000..a732db8d0 --- /dev/null +++ b/src/lighteval/tasks/templates/utils/__init__.py @@ -0,0 +1,21 @@ +# MIT License + +# Copyright (c) 2024 The HuggingFace Team + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. diff --git a/src/lighteval/tasks/templates/utils/translation_literals.py b/src/lighteval/tasks/templates/utils/translation_literals.py index 756285e62..fbac41c4d 100644 --- a/src/lighteval/tasks/templates/utils/translation_literals.py +++ b/src/lighteval/tasks/templates/utils/translation_literals.py @@ -980,7 +980,7 @@ def __getattribute__(self, name: str) -> str: language=Language.UKRAINIAN, question_word="питання", answer="відповідь", - confirmation_word="вірно", + confirmation_word="правильно", yes="так", no="ні", also="також", @@ -998,6 +998,7 @@ def __getattribute__(self, name: str) -> str: sentence_space=" ", colon=":", semicolon=";", + indices=["А", "Б", "В", "Г", "Д", "Е"], ), Language.URDU: TranslationLiterals( language=Language.URDU, diff --git a/src/lighteval/utils/__init__.py b/src/lighteval/utils/__init__.py new file mode 100644 index 000000000..a732db8d0 --- /dev/null +++ b/src/lighteval/utils/__init__.py @@ -0,0 +1,21 @@ +# MIT License + +# Copyright (c) 2024 The HuggingFace Team + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE.