From cb35beae9f1bf8133f840de1ea5ad840e37c1e07 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hynek=20Kydl=C3=AD=C4=8Dek?= Date: Wed, 5 Feb 2025 12:34:18 +0100 Subject: [PATCH 1/4] Sync Math-verify (#535) * update extraction match to reflect newest math-verify * revert symbols, improve sets handling * rm todo * fmt + remove empty excepts + bump l2s * fmt * docstring --- pyproject.toml | 2 +- src/lighteval/metrics/dynamic_metrics.py | 19 +- .../metrics/utils/extractive_match_utils.py | 237 ++++++++++++---- .../metrics/utils/math_comparison.py | 268 ++++++++++++++---- .../templates/utils/translation_literals.py | 6 + tests/metrics/test_extractive_match.py | 171 ++++++++++- 6 files changed, 577 insertions(+), 126 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index df4ff39da..9ff7050f6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -109,7 +109,7 @@ multilingual = [ "jieba", # for chinese tokenizer "pyvi", # for vietnamese tokenizer ] -math = ["latex2sympy2_extended>=0.9.3"] +math = ["latex2sympy2_extended==1.0.4"] [project.urls] Homepage = "https://github.com/huggingface/lighteval" diff --git a/src/lighteval/metrics/dynamic_metrics.py b/src/lighteval/metrics/dynamic_metrics.py index 51f749d0b..34f69a8bd 100644 --- a/src/lighteval/metrics/dynamic_metrics.py +++ b/src/lighteval/metrics/dynamic_metrics.py @@ -193,6 +193,7 @@ def multilingual_extractive_match_metric( fallback_mode: Literal["no_fallback", "first_match"] = "first_match", extraction_mode: Literal["first_match", "any_match"] = "any_match", precision: int = 6, + timeout_seconds: int = 5, ) -> SampleLevelMetric: """Creates a language-aware extractive match metric that extracts answers from the model's output. @@ -222,6 +223,8 @@ def multilingual_extractive_match_metric( precision: int Number of decimal places to use when comparing numerical values. Defaults to 6. + timeout_seconds: int + Timeout for the extraction (each attempt) and comparison. Defaults to 5. 
Returns: A sample level metric that extracts and compares mathematical expressions. @@ -245,11 +248,12 @@ def sample_level_fn(golds: list[str], predictions: list[str], formatted_doc: Doc pred_extraction_regexes = get_extraction_regexes(formatted_doc, pred_extraction_target, language) extracted_predictions = [ - extract_target_from_pred(pred, pred_extraction_regexes, fallback_mode, extraction_mode) + extract_target_from_pred(pred, pred_extraction_regexes, fallback_mode, extraction_mode, timeout_seconds) for pred in predictions ] extracted_golds = [ - extract_target_from_pred(gold, gold_extraction_regexes, fallback_mode, extraction_mode) for gold in golds + extract_target_from_pred(gold, gold_extraction_regexes, fallback_mode, extraction_mode, timeout_seconds) + for gold in golds ] # Assert on empty gold and warn on empty pred @@ -265,12 +269,19 @@ def sample_level_fn(golds: list[str], predictions: list[str], formatted_doc: Doc # We have to use timeout because the sypmy to str conversion can be very slow try: add_to_specifics_with_timeout(formatted_doc, extracted_predictions, extracted_golds) - except: # noqa: E722 + except Exception: # noqa: E722 logger.warning("Timeout when adding extracted predictions and golds to specific") return aggregation_function( [ - (1.0 if any(compare_gold_target(gold, pred, precision) for gold in extracted_golds) else 0.0) + ( + 1.0 + if any( + compare_gold_target(gold, pred, precision, timeout_seconds=timeout_seconds) + for gold in extracted_golds + ) + else 0.0 + ) for pred in extracted_predictions ] ) diff --git a/src/lighteval/metrics/utils/extractive_match_utils.py b/src/lighteval/metrics/utils/extractive_match_utils.py index 77540c175..444704108 100644 --- a/src/lighteval/metrics/utils/extractive_match_utils.py +++ b/src/lighteval/metrics/utils/extractive_match_utils.py @@ -21,13 +21,13 @@ # SOFTWARE. 
import re -from dataclasses import dataclass, field +from dataclasses import dataclass, field, replace from functools import lru_cache from itertools import groupby from typing import Any, Literal, Sequence import sympy -from sympy import Basic, MatrixBase, Number +from sympy import Basic, FiniteSet, MatrixBase, Number from sympy.parsing import parse_expr from lighteval.metrics.utils.math_comparison import should_treat_as_complex @@ -48,7 +48,7 @@ def latex_normalization_config_default_factory(): units=True, malformed_operators=True, nits=True, - boxed=True, + boxed="all", equations=True, ) @@ -159,34 +159,95 @@ def lazy_expr_regex(expr_config: ExprExtractionConfig, language: Language) -> li return [(re.compile(pattern), priority) for pattern, priority in regexes] +def make_latex_env_pattern(prefix: str = "", context: Literal["boxed", "plain"] = "plain") -> str: + """Creates a LaTeX environment pattern with uniquely prefixed group names. + + Args: + prefix (str): Prefix to add to group names to make them unique + context (Literal["boxed", "plain"]): Type of content to match inside the environments + - "boxed": Match environments containing \boxed{...} + - "plain": Match any LaTeX content + + Returns: + str: Regex pattern for matching LaTeX environments with percent suffix + """ + percent_re_group = rf"(?P<{prefix}percent>(?:\\?%|[Pp]ercent|[Pp]ercentage|[Pp]ct))" + + # Define base content patterns + display_dollar_content = r"(?:[^$]|\$(?!\$))" + # Either \ not followed by ] or everything but \ + display_content_bracket = r"(?:[^\\]|\\(?!\]))" + inline_dollar_content = r"(?:\\[$]|[^\n$])" + inline_content_parenthesis = r"(?:[^\\\n]|\\(?!\)))" + inline_content_bracket = r"[^\n\]\[]" + + if context == "boxed": + # Rewrite patterns to optionally include boxed content + display_dollar_content = ( + rf"{display_dollar_content}*?\\boxed{{{display_dollar_content}+?}}{display_dollar_content}*?" 
+ ) + display_content_bracket = ( + rf"{display_content_bracket}*?\\boxed{{{display_content_bracket}+?}}{display_content_bracket}*?" + ) + inline_dollar_content = ( + rf"{inline_dollar_content}*?\\boxed{{{inline_dollar_content}+?}}{inline_dollar_content}*?" + ) + inline_content_parenthesis = ( + rf"{inline_content_parenthesis}*?\\boxed{{{inline_content_parenthesis}+?}}{inline_content_parenthesis}*?" + ) + inline_content_bracket = ( + rf"{inline_content_bracket}*?\\boxed{{{inline_content_bracket}+?}}{inline_content_bracket}*?" + ) + else: + display_dollar_content = rf"{display_dollar_content}+?" + display_content_bracket = rf"{display_content_bracket}+?" + inline_dollar_content = rf"{inline_dollar_content}+?" + inline_content_parenthesis = rf"{inline_content_parenthesis}+?" + inline_content_bracket = rf"{inline_content_bracket}+?" + + # Build list of regex patterns + patterns = [ + # Display math environments (allow multiline) + rf"(?{display_dollar_content})(?{display_content_bracket})(?{inline_dollar_content})(?{inline_content_parenthesis})(?{inline_content_bracket})\]\s", + ] + if context == "boxed": + # allow also matching plain boxed + patterns.append(rf"(?P<{prefix}latexBoxed>\\boxed{{.+}})") + elif context == "plain": + simple_number = r"-?\d+(?:[.,]\d+)?" + patterns.append(rf"(?P<{prefix}latexFraction>-?\\frac{{{simple_number}}}{{{simple_number}}})") + + # Join patterns with | and wrap in parentheses + latex_env_re = rf"(?:(?:{'|'.join(patterns)})\s*{percent_re_group}?)" + + return latex_env_re + + @lru_cache(maxsize=1) def lazy_latex_regex(latex_config: LatexExtractionConfig, language: Language) -> list[tuple[re.Pattern[str], int]]: - # Only LaTeX expressions between delimiters - percent_re_group = r"(?P\s*(?:\\?%|[Pp]ercent|[Pp]ercentage|[Pp]ct))" - latex_envs_re = ( - r"(" - r"(?[\s\S]+?)(?[\s\S]+?)(?(?:\\[$]|[^\n$])+?)(?[^\n]+?)(?[^\n$]+?)(?\\boxed{{.+}})\$?{percent_re_group}?" 
# Boxed number, it's fine to be as greedy as possible as we will find the correct end afterwards - simple_number = r"-?\d+(?:[.,]\d+)?" - latex_fraction = rf"(?P-?\\frac{{{simple_number}}}{{{simple_number}}})\$?{percent_re_group}?" - - translation_literal = TRANSLATION_LITERALS[language] + latex_envs_re = rf"(?:{first_latex_group}{next_groups})" colon_re = rf"[{re.escape(translation_literal.colon)}\:]" - answer_prefix_re = rf"(?i:{translation_literal.answer})" # We first match boxed env, for some reason that's the most common case of output # Then we match the latex with environments, then we try to match the fraction regexes: list[tuple[str, int]] = [] - for latex_re in [latex_envs_re, latex_fraction]: + for latex_re in [latex_envs_re]: if language == Language.ENGLISH: final_answer_prefixed_re = rf"(?i:final answer is)\:?\s*{latex_re}\.?\s?I hope" final_answer_prefixed_just_is = rf"(?i:final answer.{{0,100}}?)\s+is\:?\s*{latex_re}" @@ -203,8 +264,20 @@ def lazy_latex_regex(latex_config: LatexExtractionConfig, language: Language) -> if latex_config.try_extract_without_anchor: regexes.append((latex_re, 300)) + # This ensures that boxed is matched right after the final answer xxxx if latex_config.boxed_match_priority >= 0: - regexes.append((latex_boxed, latex_config.boxed_match_priority)) + latex_re_boxed = make_latex_env_pattern(prefix="first_", context="boxed") + next_groups = "".join( + [ + rf"(?:\s*(?:{and_word}|{or_word})\s*{make_latex_env_pattern(f'next{i}_', context='boxed')})?" + for i in range(1, 6) + ] + ) + latex_re_boxed = rf"{latex_re_boxed}{next_groups}" + regexes.append((latex_re_boxed, latex_config.boxed_match_priority)) + # Match plain boxed, the issue with plain boxed is that it's impossible to know where it stops, so if there are + # till last }. We do the actuall extraction in the normalization step. 
+ regexes.append((r"(?P\\boxed{.+})", latex_config.boxed_match_priority)) return [(re.compile(pattern, re.DOTALL), priority) for pattern, priority in regexes] @@ -296,21 +369,21 @@ def get_target_type_order(target_type: ExtractionTarget) -> int: # Small cache, to catche repeated calls invalid parsing @lru_cache(maxsize=20) -@timeout(timeout_seconds=5) @requires_latex2sympy2_extended -def parse_latex_with_timeout(latex: str): +def parse_latex_with_timeout(latex: str, timeout_seconds: int): from latex2sympy2_extended.latex2sympy2 import latex2sympy - return latex2sympy(latex, is_real=not should_treat_as_complex(latex), convert_degrees=False) + return timeout(timeout_seconds)(latex2sympy)( + latex, is_real=not should_treat_as_complex(latex), convert_degrees=False, normalization_config=None + ) @lru_cache(maxsize=20) -@timeout(timeout_seconds=5) -def parse_expr_with_timeout(expr: str): - return parse_expr(expr, evaluate=False) +def parse_expr_with_timeout(expr: str, timeout_seconds: int): + return timeout(timeout_seconds)(parse_expr)(expr, evaluate=False) -def extract_expr(match: re.Match) -> tuple[str | sympy.Expr | None, str]: +def extract_expr(match: re.Match, timeout_seconds: int) -> tuple[str | sympy.Expr | None, str]: # First combine the number groups = match.groupdict() # Expr group will always exist because every regex has it @@ -338,8 +411,8 @@ def extract_expr(match: re.Match) -> tuple[str | sympy.Expr | None, str]: # Remove new lines and spaces if expr: try: - return parse_expr_with_timeout(expr.replace("\n", " ").replace("^", "**")), expr - except: # noqa: E722 + return parse_expr_with_timeout(expr.replace("\n", " ").replace("^", "**"), timeout_seconds), expr + except Exception: # noqa: E722 pass return None, expr @@ -348,42 +421,84 @@ def convert_to_pct(number: Number): return sympy.Mul(number, sympy.Rational(1, 100), evaluate=False) -@lru_cache(maxsize=1000) -@timeout(timeout_seconds=5) @requires_latex2sympy2_extended -def extract_latex(match: re.Match) 
-> tuple[sympy.Expr | str | None, str]: - from latex2sympy2_extended.latex2sympy2 import NormalizationConfig, normalize_latex - - latex = next((val for name, val in match.groupdict().items() if name.startswith("latex") and val), "") - is_percentage = True if match.group("percent") else False - - normalized_latex = normalize_latex( - latex, - NormalizationConfig( - basic_latex=True, - units=True, - malformed_operators=True, - nits=True, - boxed=True, - equations=True, - ), +@lru_cache(maxsize=20) +def extract_latex( + match: re.Match, latex_config: LatexExtractionConfig, timeout_seconds: int +) -> tuple[sympy.Expr | str | None, str]: + from latex2sympy2_extended.latex2sympy2 import FiniteSet as L2SFiniteSet + from latex2sympy2_extended.latex2sympy2 import normalize_latex + + latex_exprs = [] + latex_strs = [] + + # Get all latex groups (both first_ and nextN_ prefixes) + first_latex_group = next( + ((val, name) for name, val in match.groupdict().items() if name.startswith("first_latex") and val), None ) - try: - parsed_latex = parse_latex_with_timeout(normalized_latex) - if is_percentage: - parsed_latex = convert_to_pct(parsed_latex) - except: # noqa: E722 - return None, normalized_latex - return parsed_latex, normalized_latex + # Get all nextN_ groups + next_latex_groups = [ + next( + ((val, name) for name, val in match.groupdict().items() if name.startswith(f"next{i}_latex") and val), None + ) + for i in range(1, 6) + ] + + all_latex = list(filter(lambda x: x is not None, [first_latex_group] + next_latex_groups)) + + for latex, name in all_latex: + name_without_prefix = name.split("_")[0] + group_name = name.split("_")[1] if len(name.split("_")) > 1 else None + is_percentage = True if match.groupdict().get(f"{name_without_prefix}_percent") else False + + # Use modified config if group name is 'boxed' + config = latex_config.normalization_config + if group_name == "latexBoxed": + config = replace(config, boxed="last") # Use replace to modify single field + + 
normalized_latex = normalize_latex( + latex, + config=config, + ) + latex_strs.append(normalized_latex) + + try: + parsed_latex = parse_latex_with_timeout(normalized_latex, timeout_seconds=timeout_seconds) + if is_percentage: + parsed_latex = convert_to_pct(parsed_latex) + latex_exprs.append(parsed_latex) + except Exception: # noqa: E722 + latex_exprs.append(None) + pass + + if not latex_exprs: + return None, "" + + # If we have multiple expressions and all of them are parsed, wrap them in a Tuple + if len(latex_exprs) > 1 and all(expr is not None for expr in latex_exprs): + # To handle solution is: 1,2 and 3 + all_elements = [] + for expr in latex_exprs: + if isinstance(expr, FiniteSet): + all_elements.extend(expr.args) + else: + all_elements.append(expr) + return L2SFiniteSet(*all_elements), " and ".join(latex_strs) + + # Otherwise return the single expression + return latex_exprs[0], latex_strs[0] -def extract_match(match: re.Match, target_type: ExtractionTarget) -> tuple[Basic | MatrixBase | str | None, str]: +def extract_match( + match: re.Match, target_type: ExtractionTarget, timeout_seconds: int +) -> tuple[Basic | MatrixBase | str | None, str]: """Extracts the match from the regex match. 
Args: match (re.Match): The regex match object containing the extracted text target_type (ExtractionTarget): The type of extraction to perform (latex, expression, or indices) + timeout_seconds (int): Maximum time in seconds to spend parsing expressions Returns: tuple[Basic | MatrixBase | str | None, str]: A tuple containing: @@ -391,9 +506,9 @@ def extract_match(match: re.Match, target_type: ExtractionTarget) -> tuple[Basic - The string representation of the extracted text """ if isinstance(target_type, LatexExtractionConfig): - return extract_latex(match) + return extract_latex(match, target_type, timeout_seconds=timeout_seconds) elif isinstance(target_type, ExprExtractionConfig): - return extract_expr(match) + return extract_expr(match, timeout_seconds=timeout_seconds) elif isinstance(target_type, IndicesExtractionConfig): return match.group("indices"), match.group("indices") @@ -403,6 +518,7 @@ def extract_target_from_pred( target_res: list[tuple[list[tuple[re.Pattern[str], int]], ExtractionTarget]], fallback_mode: Literal["no_fallback", "first_match"] = "no_fallback", extraction_mode: Literal["first_match", "any_match"] = "any_match", + timeout_seconds: int = 5, ): """Extracts targets from a prediction string using regex patterns. Returns first sucesffuly extracted match. @@ -416,6 +532,7 @@ def extract_target_from_pred( extraction_mode (Literal["first_match", "any_match"], optional): How to handle extraction failures. Defaults to "any_match". - "first_match": Only tries to extract the first match - "any_match": Tries to extract any match + timeout_seconds (int, optional): Maximum time in seconds to spend parsing each expression. Defaults to 5. 
Returns: list: List of extracted predictions, with first fallbac string appended if fallback_mode is "first_match" @@ -445,7 +562,7 @@ def extract_target_from_pred( # Try to extract from each match, starting from rightmost for match, _, _, target_type in matches_with_pos: - extracted_match, str_fallback = extract_match(match, target_type) + extracted_match, str_fallback = extract_match(match, target_type, timeout_seconds) match_found = True if str_fallback: diff --git a/src/lighteval/metrics/utils/math_comparison.py b/src/lighteval/metrics/utils/math_comparison.py index 483d1d450..e90f53f7b 100644 --- a/src/lighteval/metrics/utils/math_comparison.py +++ b/src/lighteval/metrics/utils/math_comparison.py @@ -25,7 +25,9 @@ from itertools import product from sympy import ( + And, Basic, + E, Eq, FiniteSet, Float, @@ -41,10 +43,15 @@ StrictGreaterThan, StrictLessThan, Symbol, + Tuple, + default_sort_key, + ordered, simplify, ) +from sympy.core.function import UndefinedFunction from sympy.core.relational import Relational +from lighteval.utils.imports import requires_latex2sympy2_extended from lighteval.utils.timeout import timeout @@ -66,7 +73,7 @@ def safe_sympy_doit(a: Basic | MatrixBase): return a.doit() except TimeoutError: raise - except: # noqa: E722 + except Exception: # noqa: E722 pass return a @@ -129,7 +136,7 @@ def sympy_numeric_eq(a: Basic | MatrixBase, b: Basic | MatrixBase, precision: in return (a - b).evalf(chop=True) == 0 # type: ignore except TimeoutError: raise - except: # noqa: E722 + except Exception: # noqa: E722 pass return False @@ -153,13 +160,13 @@ def sympy_symbolic_eq(a: Basic | MatrixBase, b: Basic | MatrixBase) -> bool: return True except TimeoutError: raise - except: # noqa: E722 + except Exception: # noqa: E722 pass return False -def sympy_deep_compare_finite_set(a: FiniteSet, b: FiniteSet, precision: int) -> bool: +def sympy_deep_compare_set_and_tuple(gold: FiniteSet | Tuple, pred: FiniteSet | Tuple, precision: int) -> bool: """Compare two 
finite sets by comparing each element with given precision. Args: @@ -169,28 +176,154 @@ def sympy_deep_compare_finite_set(a: FiniteSet, b: FiniteSet, precision: int) -> Returns: True if sets contain equal elements within precision, False otherwise + + Note: in order to fully support finite sets, we should ideally do kartesian product comparison + but this is not implemented yet. We kinda hope sympy will order the elements. """ + from latex2sympy2_extended.sets import FiniteSet as L2SFiniteSet + + def unwrap_eq(s): + if is_assignment_relation(s): + return take_last_relation(s).rhs + return s + + def sort_key(x): + try: + return default_sort_key(unwrap_eq(x).evalf()) + except TimeoutError: + raise + except Exception: # noqa: E722 + return default_sort_key(unwrap_eq(x)) + # This ensures it works for {1/3} and {0.333333} - if len(a) == len(b) and all(sympy_expr_eq(a, b, precision) for a, b in zip(a, b)): + if len(gold) == len(pred): + if isinstance(gold, FiniteSet): + gold_args = list(ordered(gold.args, keys=sort_key, default=False)) + pred_args = list(ordered(pred.args, keys=sort_key, default=False)) + + elif isinstance(gold, Tuple) and isinstance(pred, L2SFiniteSet): + # We treat the pred as tuple too + pred_args = pred._unsorted_args + gold_args = gold.args + + elif isinstance(pred, FiniteSet): + pred_args = list(ordered(pred.args, keys=sort_key, default=False)) + gold_args = gold.args + else: + gold_args = gold.args + pred_args = pred.args + + return all(sympy_expr_eq(a, b, precision) for a, b in zip(gold_args, pred_args)) + + return False + + +def sympy_compare_symbols(gold: Basic | MatrixBase, pred: Basic | MatrixBase) -> bool: + """Compare two sympy expressions where at least one is a Symbol. 
+ + Handles special cases: + - One is Symbol and other is E (limitation of parsed expressions) + - One is multiplication of symbols and other is single symbol (concatenated comparison) + """ + # Handle E vs symbol case + if (isinstance(gold, Symbol) and gold.name.lower() == "e" and pred == E) or ( + isinstance(pred, Symbol) and pred.name.lower() == "e" and gold == E + ): + return True + + # Handle multiplication of symbols vs single symbol + if ( + isinstance(gold, Symbol) + and isinstance(pred, Mul) + and all(arg == E or isinstance(arg, (Symbol)) for arg in pred.args) + ): + concat_pred = "".join(arg.name if isinstance(arg, Symbol) else "e" for arg in pred.args) + return gold.name.lower() == concat_pred.lower() + + if ( + isinstance(pred, Symbol) + and isinstance(gold, Mul) + and all(arg == E or isinstance(arg, (Symbol)) for arg in gold.args) + ): + concat_gold = "".join(arg.name if isinstance(arg, Symbol) else "e" for arg in gold.args) + return pred.name.lower() == concat_gold.lower() + + return gold == pred + + +def is_relation(expr: Basic | MatrixBase) -> bool: + """Check if an expression is a relational expression.""" + if isinstance(expr, Relational): return True + if isinstance(expr, And): + return all(isinstance(arg, Relational) for arg in expr.args) + return False -def sympy_compare_set_interval(a: FiniteSet, b: Interval, precision: int) -> bool: - """Compare a finite set with an interval. 
+def take_last_relation(expr: And | Relational) -> Relational: + """Take the last relation from an And expression.""" + if isinstance(expr, And): + return take_last_relation(expr.args[-1]) + return expr + + +def unwrap_fcs(expr: Basic | MatrixBase) -> Basic | MatrixBase: + """Unwrap function calls to their arguments.""" + if not isinstance(expr, Basic): + return expr + + if hasattr(expr, "func") and isinstance(expr.func, UndefinedFunction): + func_name = expr.func.__name__ + unwrapped_args = [str(unwrap_fcs(arg)) for arg in expr.args] + return Symbol(f"{func_name}_{'_'.join(unwrapped_args)}") + + try: + new_args = [unwrap_fcs(arg) for arg in expr.args] + if new_args: + return expr.func(*new_args) + except TimeoutError: + raise + except Exception: # noqa: E722 + pass + + return expr + + +def is_equation(expr: Basic | MatrixBase) -> bool: + """Check if an expression is an equation. Args: - a: Finite set to compare - b: Interval to compare - precision: Number of decimal places to compare + expr: The expression to check + Returns: + bool: True if expr is an equation, False otherwise + """ + if isinstance(expr, Eq): + return True + + if isinstance(expr, And) and len(expr.args) > 0: + return all(isinstance(arg, Eq) for arg in expr.args) + return False + + +@requires_latex2sympy2_extended +def is_assignment_relation(expr: Basic | MatrixBase) -> bool: + from latex2sympy2_extended.latex2sympy2 import is_expr_of_only_symbols + + """Check if an expression is an assignment relation. 
E.g a=1 + + Args: + expr: The expression to check Returns: - True if set and interval are equivalent, False otherwise + bool: True if expr is a relational expression or And of relations, False otherwise """ - # Only compare if it's the special case of 2 elements - if len(a) == 2 and b.is_open: - return sympy_deep_compare_finite_set(a, FiniteSet(b.start, b.end), precision) + if isinstance(expr, Eq) and is_expr_of_only_symbols(expr.lhs): + return True + + if isinstance(expr, And) and len(expr.args) > 0: + return all(isinstance(arg, Eq) for arg in expr.args) and is_expr_of_only_symbols(expr.args[0].lhs) return False @@ -214,7 +347,7 @@ def sympy_compare_interval(a: Interval, b: Interval, precision: int) -> bool: ) -def sympy_compare_relational(gold: Relational, pred: Relational, precision: int) -> bool: +def sympy_compare_relational(gold: Relational | And, pred: Relational | And, precision: int) -> bool: """Compare two relational expressions. Args: @@ -225,6 +358,9 @@ def sympy_compare_relational(gold: Relational, pred: Relational, precision: int) Returns: True if relations are equivalent, False otherwise """ + # Handle And expressions by comparing each relation + if isinstance(gold, And): + return all(sympy_compare_relational(g, p, precision) for g, p in zip(gold.args, pred.args)) # Helper to check if expressions are equivalent when flipped def are_flipped_inequalities_equal(a: Relational, b: Relational) -> bool: @@ -232,18 +368,17 @@ def are_flipped_inequalities_equal(a: Relational, b: Relational) -> bool: return sympy_expr_eq(a.lhs - a.rhs, b.rhs - b.lhs, precision) # type: ignore except TimeoutError: raise - except: # noqa: E722 + except Exception: # noqa: E722 pass return False # Same type of relation (e.g. 
both <= or both >=) - try: if type(gold) == type(pred) and sympy_expr_eq(gold.lhs - gold.rhs, pred.lhs - pred.rhs, precision): # type: ignore return True except TimeoutError: raise - except: # noqa: E722 + except Exception: # noqa: E722 pass # Check flipped inequalities (a <= b equals b >= a) @@ -286,12 +421,14 @@ def sympy_str_eq(a: Basic | MatrixBase, b: Basic | MatrixBase) -> bool: return True except TimeoutError: raise - except: # noqa: E722 + except Exception: # noqa: E722 pass return False -def sympy_compare_sets(gold: Set | Basic | MatrixBase, pred: Set | Basic | MatrixBase, precision: int) -> bool: +def sympy_compare_sets( + gold: Set | Basic | MatrixBase | Tuple, pred: Set | Basic | MatrixBase | Tuple, precision: int +) -> bool: """Compare two sympy sets for equality using multiple methods. Args: @@ -303,8 +440,8 @@ def sympy_compare_sets(gold: Set | Basic | MatrixBase, pred: Set | Basic | Matri True if sets are equal by any comparison method, False otherwise """ # Convert non-sets to singleton sets - a_set = gold if isinstance(gold, Set) else FiniteSet(gold) - b_set = pred if isinstance(pred, Set) else FiniteSet(pred) + a_set = gold if isinstance(gold, (Set, Tuple)) else FiniteSet(gold) + b_set = pred if isinstance(pred, (Set, Tuple)) else FiniteSet(pred) # If both are intervals, use interval comparison if isinstance(a_set, Interval) and isinstance(b_set, Interval): @@ -313,57 +450,88 @@ def sympy_compare_sets(gold: Set | Basic | MatrixBase, pred: Set | Basic | Matri # Try direct set equality if a_set == b_set: return True - if a_set.symmetric_difference(b_set).is_empty: + + # If both are sets, check if they are equal + if isinstance(a_set, Set) and isinstance(b_set, Set) and a_set.symmetric_difference(b_set).is_empty: return True # For finite sets, compare elements - if isinstance(a_set, FiniteSet) and isinstance(b_set, FiniteSet): - return sympy_deep_compare_finite_set(a_set, b_set, precision) + if isinstance(a_set, (FiniteSet, Tuple)) and 
isinstance(b_set, (FiniteSet, Tuple)): + return sympy_deep_compare_set_and_tuple(a_set, b_set, precision) + + # Because (1,2) is parsed as Interval(1,2,left_open=True,right_open=True), it could have that the + # correct is (1,2) and predicted is 1,2, which is parsed as Set(1,2) + if isinstance(a_set, Interval) and isinstance(b_set, (FiniteSet, Tuple)): + if a_set.is_open and len(b_set) == 2: + return sympy_deep_compare_set_and_tuple(Tuple(a_set.start, a_set.end), b_set, precision) - # Handle interval vs finite set cases - if isinstance(a_set, Interval) and isinstance(b_set, FiniteSet): - return sympy_compare_set_interval(b_set, a_set, precision) - if isinstance(a_set, FiniteSet) and isinstance(b_set, Interval): - return sympy_compare_set_interval(a_set, b_set, precision) + if isinstance(b_set, Interval) and isinstance(a_set, (FiniteSet, Tuple)): + if b_set.is_open and len(a_set) == 2: + return sympy_deep_compare_set_and_tuple(a_set, Tuple(b_set.start, b_set.end), precision) return False -def sympy_expr_eq(gold: Basic | MatrixBase, pred: Basic | MatrixBase, precision: int) -> bool: +def sympy_expr_eq(gold: Basic | MatrixBase, pred: Basic | MatrixBase, precision: int, strict: bool = True) -> bool: # noqa: C901 """Compare two sympy expressions for equality using multiple methods. Args: gold: First sympy expression (expected) pred: Second sympy expression (predicted) precision: Number of decimal places to compare + strict: If true, variables do matter otherwise they don't Returns: True if expressions are equal by any comparison method, False otherwise """ - # If the reference is relational, but the target is not, it's possible it's a case of answer=x+1+z, so we just take x+1+z - # We assume that the gold never needs to be simplified, so we don't handle that case - # e.g 1+1+1=3 will never be simplified to 3; it would be possible to do so with lhs-rhs == 0, but we assume the gold is at its most simplified form. 
- # The new latex2sympy2 will actually convert such cases automatically, but so this is in theory not needed - if isinstance(gold, Eq) and not isinstance(pred, Relational) and isinstance(gold.lhs, Symbol): - gold = gold.rhs + # This ensures that f(x) == f(y) is true + if not strict: + try: + gold_variables = gold.free_symbols + pred_variables = pred.free_symbols + if len(gold_variables) == len(pred_variables): + pred = pred.subs(list(zip(pred_variables, gold_variables))) + except TimeoutError: + raise + except Exception: # noqa: E722 + pass + + # If the target is relational, but the refernce is not, it's possible it's a case of a=x+1+z, so we just take x+1+z + # We only do this if the lhs of the first equation is fully symbolic, to prevent simplifying x+y+2z = 1 + if is_assignment_relation(gold) and not is_equation(pred): + gold = take_last_relation(gold).rhs # Here we respect the gold and simplify accordingly, thus any of # k=x+1+z or 1+1+1=3 will be simplified to rhs - if isinstance(pred, Eq) and not isinstance(gold, Eq): - pred = pred.rhs + if is_equation(pred) and not is_equation(gold): + pred = take_last_relation(pred).rhs + + if is_relation(gold) and isinstance(pred, Set): + # This is to ensure that 1 < x < 2 equals (-oo, 1) U (2, oo) + # We also unwrap the functions because otherwise it creates some conditional set based on the function name + try: + gold = unwrap_fcs(gold).as_set() + except TimeoutError: + raise + except Exception: # noqa: E722 + pass - # Start with simple str and expr comparisson as it's the fastest - # str comparison is better, than simple eq, because it will also handle missarangments + # Start with simple str and expr comparison as it's the fastest + # str comparison is better than simple eq, because it will also handle misarrangements if sympy_str_eq(gold, pred): return True # Support for equations - if isinstance(gold, Relational) and isinstance(pred, Relational): + if is_relation(gold) and is_relation(pred): return 
sympy_compare_relational(gold, pred, precision) - elif isinstance(gold, Set) or isinstance(pred, Set): + elif isinstance(gold, (Set, Tuple)) or isinstance(pred, (Set, Tuple)): return sympy_compare_sets(gold, pred, precision) + # Handles $\text{answer}$ == $answer$, one is symbol, is multiplication of symbols (a*n*s*w*e*r) + elif isinstance(gold, Symbol) or isinstance(pred, Symbol): + return sympy_compare_symbols(gold, pred) + elif isinstance(gold, (Basic, MatrixBase)) and isinstance(pred, (Basic, MatrixBase)): # Mostly so that 0.333333 = 1/3 if sympy_numeric_eq(gold, pred, precision): @@ -411,16 +579,20 @@ def should_treat_as_complex(latex_str: str) -> bool: def compare_gold_target( - gold: list[Basic | MatrixBase | str], target: list[Basic | MatrixBase | str], precision: int + gold: list[Basic | MatrixBase | str] | Basic | MatrixBase | str, + target: list[Basic | MatrixBase | str] | Basic | MatrixBase | str, + precision: int = 6, + strict: bool = True, + timeout_seconds: int = 3, ) -> bool: - @timeout(timeout_seconds=10) + @timeout(timeout_seconds=timeout_seconds) def compare_single_extraction(gold: Basic | MatrixBase | str, target: Basic | MatrixBase | str) -> bool: # If both are sympy expressions, we can use sympy to compare them if isinstance(gold, (Basic, MatrixBase)) and isinstance(target, (Basic, MatrixBase)): - return sympy_expr_eq(gold, target, precision) + return sympy_expr_eq(gold, target, precision, strict) # We don't support str / sympy.Expr comparison. Imo there is no point in doing this, as chances - # of this happening are very low. The only why one of them is not converted to sympy expression + # of this happening are very low. The only reason why one of them is not converted to sympy expression # is usually because the parsing logic failed in this case we should improve the parsing logic # instead of somehow fixing adhoc. 
elif isinstance(gold, str) and isinstance(target, str): @@ -436,7 +608,7 @@ def compare_single_extraction(gold: Basic | MatrixBase | str, target: Basic | Ma def compare_single_extraction_wrapper(g, t): try: return compare_single_extraction(g, t) - except TimeoutError: + except Exception: # noqa: E722 return False return any(compare_single_extraction_wrapper(g, t) for g, t in product(gold, target)) diff --git a/src/lighteval/tasks/templates/utils/translation_literals.py b/src/lighteval/tasks/templates/utils/translation_literals.py index fbac41c4d..faf47bdf7 100644 --- a/src/lighteval/tasks/templates/utils/translation_literals.py +++ b/src/lighteval/tasks/templates/utils/translation_literals.py @@ -42,6 +42,7 @@ class TranslationLiterals: cause_word: str = None # type: ignore effect_word: str = None # type: ignore or_word: str = None # type: ignore + and_word: str = None # type: ignore # NLI true: str = None # type: ignore @@ -91,6 +92,7 @@ def __getattribute__(self, name: str) -> str: false="خاطئ", neither="لا هذا ولا ذاك", or_word="أو", + and_word="و", full_stop=".", comma="،", question_mark="؟", @@ -216,6 +218,7 @@ def __getattribute__(self, name: str) -> str: false="假", neither="都不是", or_word="或", + and_word="和", full_stop="。", comma=",", question_mark="?", @@ -317,6 +320,7 @@ def __getattribute__(self, name: str) -> str: sentence_space=" ", colon=":", or_word="or", + and_word="and", ), Language.ESPERANTO: TranslationLiterals(language=Language.ESPERANTO), Language.ESTONIAN: TranslationLiterals( @@ -359,6 +363,7 @@ def __getattribute__(self, name: str) -> str: cause_word="parce que", effect_word="donc", or_word="ou", + and_word="et", true="vrai", false="faux", neither="aucun des deux", @@ -462,6 +467,7 @@ def __getattribute__(self, name: str) -> str: false="असत्य", neither="न तो यह, न वह", or_word="या", + and_word="और", full_stop="।", comma=",", question_mark="?", diff --git a/tests/metrics/test_extractive_match.py b/tests/metrics/test_extractive_match.py index 
78e7fdae2..fb163bf3e 100644 --- a/tests/metrics/test_extractive_match.py +++ b/tests/metrics/test_extractive_match.py @@ -336,7 +336,7 @@ def test_sets_handling(gold, pred, expected): ("$1/3$", "$\\frac{1}{3} \\text{meters}$", 1), ("$1/3$", "$\\frac{1}{3} \\textbf{meters}$", 1), # Last = is considered - ("$1/3$", "$\\k = \\frac{1}{3}$", 1), + ("$1/3$", "$k = \\frac{1}{3}$", 1), ("$1/3$", "$\\frac{1}{3} \\textbf{meters}$", 1), ], ) @@ -604,12 +604,6 @@ def test_latex_notation_math(gold, pred, expected): "$-x >= -1$", 1, ), - # Test incomplete equation - ( - "$a +z = 0$", - "$0$", - 0, - ), ], ) def test_relations_math(gold, pred, expected): @@ -864,7 +858,8 @@ def test_math_extraction_edge_cases(gold, pred, expected): r"Since $AP:PB = 1:4,$ we can write \[\frac{\overrightarrow{A} - \overrightarrow{P}}{1} = \frac{\overrightarrow{B} - \overrightarrow{P}}{4}.\]Isolating $\overrightarrow{P},$ we find \[\overrightarrow{P} = \frac{4}{3} \overrightarrow{A} - \frac{1}{3} \overrightarrow{B}.\]Thus, $(t,u) = \boxed{\left( \frac{4}{3}, -\frac{1}{3} \right)}.$", 1, ), - (r"$(3,1)$", r"${1,3}$", 1), + # Shouldn't work as it's ordered tuple vs set + # (r"$(3,1)$", r"${1,3}$", 1), (r"$(1,3)$", r"${1,3}$", 1), # Issue: Therefore preference ( @@ -975,8 +970,158 @@ def test_math_extraction_additional_cases(gold, pred, expected): assert compare_strings(gold, pred, match_types=["latex", "expr"]) == expected -# text{C} Qwen correct -# 11111111100 Qwen correct -# Interval(2, oo) qwen incorrect -# text{west} qwen incorrect -# 32349, 32,\!348 qwen incorrect +@pytest.mark.parametrize( + "gold, pred, expected", + [ + ( + r"$(37,3,3,13),(17,3,3,7),(3,37,3,13),(3,17,3,7),(3,3,2,3)$", + r"$\boxed{(3, 37, 3, 13), (3, 17, 3, 7), (3, 3, 2, 3), (3,17,3,7), (17,3,3,7), (37,3,3,13)}$", + 1, + ), + ( + r"$(p,q)=(3,2)$", + r"$(3,2)$", + 1, + ), + ( + r"$(0;0;0),(0;-2;0),(0;0;6),(0;-2;6),(4;0;0),(4;-2;0),(4;0;6),(4;-2;6)$", + r"\boxed{(4, 0, 6), (4, -2, 6), (0, 0, 6), (0, -2, 6), (4, 0, 0), (4, -2, 0), 
(0, 0, 0), (0, -2, 0)}", + 1, + ), + ( + r"$1\leq|z|\leq \frac{3}{2}$", + r"$z \in \left[-\frac{3}{2}, -1\right] \cup \left[1, \frac{3}{2}\right]$", + 1, + ), + ( + r"$-12;-11;-10;-8;-7;-6$", + r"$\boxed{\{-12, -11, -10, -8, -7, -6\}}$", + 1, + ), + ( + r"$AB=4,CD=5$", + r"$\boxed{4, 5}$", + 1, + ), + ( + r"$(11,7)or(7,11)$", + r"$\boxed{(7,11),\ (11,7)}$", + 1, + ), + ( + r"$S_{MBCN}:S=7:32$", + r"$\boxed{7:32}$", + 1, + ), + ( + r"$\frac{NO}{BO}=\frac{1}{\sqrt{6}}$", + r"$\frac{1}{\sqrt{6}}$", + 1, + ), + ( + r"$p=5,q=2;p=7,q=2$", + r"$(5,2),(7,2)$", + 1, + ), + ( + r"$(p,q,r)=(3,2,7)$", + r"$(3,2,7)$", + 1, + ), + ( + r"$V_{1}:V_{2}=11:21$", + r"$11:21$", + 1, + ), + ( + r"$(2,1),(1,2),(-1,-20),(-20,-1)$", + r"solutions are:\n\n\\[\n\\boxed{(1, 2)}, \\boxed{(2, 1)}, \\boxed{(-1, -20)}, \\boxed{(-20, -1)}\n\\]", + 1, + ), + ( + r"\(\boxed{1}\) and \(\boxed{-2}\).", + r"$\boxed{-2,1}$.", + 1, + ), + ( + r"$\text{odd}$", + r"$odd$", + 1, + ), + ( + r"$\text{e}$", + r"$e$", + 1, + ), + ( + r"$\text{E}$", + r"$E$", + 1, + ), + (r"$d$", r"$\text{E}$", 0), + (r"$1$ and $2$ and $3$", r"$\boxed{1,2,3}$", 1), + ( + r"$(37,3,3,13),(17,3,3,7),(3,37,3,13),(3,17,3,7),(3,3,2,3)$", + r"$\boxed{(3, 37, 3, 13), (3, 17, 3, 7), (3, 3, 2, 3), (3,17,3,7), (17,3,3,7), (37,3,3,13)}$", + 1, + ), + ( + r"$(37,3,3),(17,3,3,7),(3,37,3,13),(3,17,3,7),(3,3,2,3)$", + r"$\boxed{(3, 37, 3, 13), (3, 17, 3, 7), (3, 3, 2, 3), (3,17,3,7), (17,3,3,7), (37,3,3,13)}$", + 0, + ), + ( + r"$(p,q)=(3,2)$", + r"$\boxed{(3, 2)}$", + 1, + ), + ( + r"\boxed{x = -5,\ p = \frac{14}{3}} ", + r"$\boxed{-5, \frac{14}{3}}$", + 1, + ), + ( + r"\boxed{a=4,\,-8,\,-10}", + r"$\boxed{-10,-8,4}$", + 1, + ), + ( + r"\\boxed{W(n) = 1 \\text{ and } W(n) = -1", + r"W(x)=1orW(x)=-1", + 1, + ), + ("$21,16$ or $11$", "$21,16,11$", 1), + (r"\boxed{ p = 5, q = 2 \quad \text{and} \quad p = 7, q = 2}", r"$p=5,q=2;p=7,q=2$", 1), + (r"\n\n\[ \boxed{p = -1 \text{ and } p = \dfrac{15}{8}} \]", r"$p=-1,p=\frac{15}{8}$", 1), + ("$0 Date: 
Wed, 5 Feb 2025 16:39:47 +0100 Subject: [PATCH 2/4] Add GPQA for instruct models (#534) * Add GPQA for instruct models * Add ref * Refactor * Tune prompt * Tune max tokens * Use simple-eval template --- src/lighteval/logging/evaluation_tracker.py | 3 +- src/lighteval/metrics/metrics.py | 11 +++++ src/lighteval/tasks/default_prompts.py | 17 ++++++++ src/lighteval/tasks/default_tasks.py | 48 +++++++++++++++++++++ 4 files changed, 78 insertions(+), 1 deletion(-) diff --git a/src/lighteval/logging/evaluation_tracker.py b/src/lighteval/logging/evaluation_tracker.py index 2364b470e..b404ffe4a 100644 --- a/src/lighteval/logging/evaluation_tracker.py +++ b/src/lighteval/logging/evaluation_tracker.py @@ -325,12 +325,13 @@ def push_to_hub( # We upload it both as a json and a parquet file result_file_base_name = f"results_{date_id}" results_json = json.dumps(results_dict, cls=EnhancedJSONEncoder, indent=2, ensure_ascii=False) - self.api.upload_file( + url = self.api.upload_file( repo_id=repo_id, path_or_fileobj=BytesIO(results_json.encode("utf-8")), path_in_repo=f"{result_file_base_name}.json", repo_type="dataset", ) + logger.info(f"Uploaded evaluation details to {url}") results_dataset = Dataset.from_dict( {key: [json.dumps(v, cls=EnhancedJSONEncoder, indent=2)] for key, v in results_dict.items()} diff --git a/src/lighteval/metrics/metrics.py b/src/lighteval/metrics/metrics.py index ff4b6b059..171219197 100644 --- a/src/lighteval/metrics/metrics.py +++ b/src/lighteval/metrics/metrics.py @@ -24,6 +24,10 @@ import numpy as np from aenum import Enum +from lighteval.metrics.dynamic_metrics import ( + IndicesExtractionConfig, + multilingual_extractive_match_metric, +) from lighteval.metrics.harness_compatibility.drop import drop_metrics from lighteval.metrics.harness_compatibility.truthful_qa import truthfulqa_mc_metrics from lighteval.metrics.metrics_corpus import ( @@ -69,6 +73,7 @@ SampleLevelMetric, SampleLevelMetricGrouping, ) +from lighteval.utils.language import Language 
from lighteval.utils.utils import as_list @@ -549,6 +554,12 @@ class Metrics(Enum): corpus_level_fn=CorpusLevelPerplexityMetric("weighted_perplexity").compute, higher_is_better=False, ) + gpqa_instruct_metric = multilingual_extractive_match_metric( + language=Language.ENGLISH, + gold_extraction_target=[IndicesExtractionConfig(prefix_for_extraction="NativeLetters")], + pred_extraction_target=[IndicesExtractionConfig(prefix_for_extraction="NativeLetters")], + precision=6, + ) def __str__(self): return self.name.replace("_at_", "@") diff --git a/src/lighteval/tasks/default_prompts.py b/src/lighteval/tasks/default_prompts.py index 66c3d53b4..ace39ed70 100644 --- a/src/lighteval/tasks/default_prompts.py +++ b/src/lighteval/tasks/default_prompts.py @@ -729,6 +729,23 @@ def gpqa(line, task_name: str = None): ) +def gpqa_instruct(line, task_name: str = None): + """Prompt template adapted from simple-evals: https://github.com/openai/simple-evals/blob/83ed7640a7d9cd26849bcb3340125002ef14abbe/common.py#L14""" + gold_index = random.randint(0, 3) + choices = [line["Incorrect Answer 1"], line["Incorrect Answer 2"], line["Incorrect Answer 3"]] + choices.insert(gold_index, line["Correct Answer"]) + query_template = "Answer the following multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. 
Think step by step before answering.\n\n{Question}\n\nA) {A}\nB) {B}\nC) {C}\nD) {D}" + query = query_template.format(A=choices[0], B=choices[1], C=choices[2], D=choices[3], Question=line["Question"]) + + return Doc( + task_name=task_name, + query=query, + choices=LETTER_INDICES[: len(choices)], + gold_index=gold_index, + instruction=query, + ) + + def gsm8k(line, task_name: str = None): # Has special analysis in metric for number decomposition return Doc( diff --git a/src/lighteval/tasks/default_tasks.py b/src/lighteval/tasks/default_tasks.py index d6a7ec498..92f481e51 100644 --- a/src/lighteval/tasks/default_tasks.py +++ b/src/lighteval/tasks/default_tasks.py @@ -7720,6 +7720,54 @@ trust_dataset=True, version=0, ) +gpqa_diamond_instruct_lighteval = LightevalTaskConfig( + name="gpqa:diamond", + suite=["lighteval"], + prompt_function=prompt.gpqa_instruct, + hf_repo="Idavidrein/gpqa", + hf_subset="gpqa_diamond", + hf_avail_splits=["train"], + evaluation_splits=["train"], + few_shots_split=None, + few_shots_select=None, + generation_size=32768, # needed for reasoning models like R1 + metric=[Metrics.gpqa_instruct_metric], + stop_sequence=[], # no stop sequence, will use eos token + trust_dataset=True, + version=0, +) +gpqa_extended_instruct_lighteval = LightevalTaskConfig( + name="gpqa:extended", + suite=["lighteval"], + prompt_function=prompt.gpqa_instruct, + hf_repo="Idavidrein/gpqa", + hf_subset="gpqa_extended", + hf_avail_splits=["train"], + evaluation_splits=["train"], + few_shots_split=None, + few_shots_select=None, + generation_size=32768, # needed for reasoning models like R1 + metric=[Metrics.gpqa_instruct_metric], + stop_sequence=[], # no stop sequence, will use eos token + trust_dataset=True, + version=0, +) +gpqa_main_instruct_lighteval = LightevalTaskConfig( + name="gpqa:main", + suite=["lighteval"], + prompt_function=prompt.gpqa_instruct, + hf_repo="Idavidrein/gpqa", + hf_subset="gpqa_main", + hf_avail_splits=["train"], + evaluation_splits=["train"], + 
few_shots_split=None, + few_shots_select=None, + generation_size=32768, # needed for reasoning models like R1 + metric=[Metrics.gpqa_instruct_metric], + stop_sequence=[], # no stop sequence, will use eos token + trust_dataset=True, + version=0, +) gre_reading_comprehension_bigbench = LightevalTaskConfig( name="gre_reading_comprehension", suite=["bigbench", "bigbench_json"], From 15bdbb8125f7cd4f8379756e95c033b5006245b1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hynek=20Kydl=C3=AD=C4=8Dek?= Date: Thu, 6 Feb 2025 08:57:24 +0100 Subject: [PATCH 3/4] Make BLEURT lazy (#536) * make bleur lazy * make tokenizer lazy too --- src/lighteval/metrics/metrics_sample.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/src/lighteval/metrics/metrics_sample.py b/src/lighteval/metrics/metrics_sample.py index 352c2b98e..225903c08 100644 --- a/src/lighteval/metrics/metrics_sample.py +++ b/src/lighteval/metrics/metrics_sample.py @@ -708,9 +708,21 @@ def __init__(self): """Creates a BLEURT scorer using a light bleurt-tiny-512 model. For more complex use cases, could also be Elron/bleurt-base-128 """ - self.tokenizer = AutoTokenizer.from_pretrained("Elron/bleurt-tiny-512") - self.model = AutoModelForSequenceClassification.from_pretrained("Elron/bleurt-tiny-512") - self.model.eval() + self._tokenizer = None + self._model = None + + @property + def tokenizer(self): + if self._tokenizer is None: + self._tokenizer = AutoTokenizer.from_pretrained("Elron/bleurt-tiny-512") + return self._tokenizer + + @property + def model(self): + if self._model is None: + self._model = AutoModelForSequenceClassification.from_pretrained("Elron/bleurt-tiny-512") + self._model.eval() + return self._model def compute(self, golds: list[str], predictions: list[str], **kwargs) -> float: """Uses the stored BLEURT scorer to compute the score on the current sample. 
From 441d7a4a83fdf27d0e362f77ed15295908be4cf2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9mentine=20Fourrier?= <22726840+clefourrier@users.noreply.github.com> Date: Thu, 6 Feb 2025 08:57:38 +0100 Subject: [PATCH 4/4] Pass@k (#519) * init * correct typing * added defaults * small fix --- src/lighteval/metrics/metrics.py | 25 +++++ src/lighteval/metrics/metrics_sample.py | 117 +++++++++++++++++++++++- 2 files changed, 141 insertions(+), 1 deletion(-) diff --git a/src/lighteval/metrics/metrics.py b/src/lighteval/metrics/metrics.py index 171219197..290114783 100644 --- a/src/lighteval/metrics/metrics.py +++ b/src/lighteval/metrics/metrics.py @@ -48,6 +48,7 @@ Faithfulness, LoglikelihoodAcc, MajAtK, + PassAtK, Recall, StringDistance, acc_golds_likelihood, @@ -369,6 +370,30 @@ class Metrics(Enum): corpus_level_fn=CorpusLevelF1Score(average=None, num_classes=3).compute, higher_is_better=True, ) + pass_at_1 = SampleLevelMetric( + metric_name="pass@1:32_samples", + sample_level_fn=PassAtK(k=1, n=32, strip_strings=True).compute, + category=MetricCategory.GENERATIVE_SAMPLING, + use_case=MetricUseCase.REASONING, + corpus_level_fn=np.mean, + higher_is_better=True, + ) + pass_at_10 = SampleLevelMetric( + metric_name="pass@10:32_samples", + sample_level_fn=PassAtK(k=10, n=32, strip_strings=True).compute, + category=MetricCategory.GENERATIVE_SAMPLING, + use_case=MetricUseCase.REASONING, + corpus_level_fn=np.mean, + higher_is_better=True, + ) + pass_at_100 = SampleLevelMetric( + metric_name="pass@100:32_samples", + sample_level_fn=PassAtK(k=100, n=32, strip_strings=True).compute, + category=MetricCategory.GENERATIVE_SAMPLING, + use_case=MetricUseCase.REASONING, + corpus_level_fn=np.mean, + higher_is_better=True, + ) perfect_exact_match = SampleLevelMetric( metric_name="perfect_em", sample_level_fn=ExactMatches().compute, diff --git a/src/lighteval/metrics/metrics_sample.py b/src/lighteval/metrics/metrics_sample.py index 225903c08..5d89ec9e9 100644 --- 
a/src/lighteval/metrics/metrics_sample.py +++ b/src/lighteval/metrics/metrics_sample.py @@ -26,7 +26,7 @@ import logging import os -from typing import Callable, Literal +from typing import Callable, Literal, Union import nltk import numpy as np @@ -1055,3 +1055,118 @@ def compute_score(self, pred: str, gold: str) -> int: if self.type_exact_match == "suffix": return 1 if pred.endswith(gold) else 0 return 1 if gold == pred else 0 + + +class PassAtK: + def __init__( + self, + k: int, + n: int = None, + normalize_gold: Callable = None, + normalize_pred: Callable = None, + strip_strings: bool = False, + sample_scoring_function: Union[Callable[[str, str], float], str] = None, + ): + """Computing pass at k + + Args: + k (int): Threshold for the number of successful attempts. + n (int): Number of samples to generate + normalize_gold (callable, optional): Function to use to normalize the reference strings. + Defaults to None if no normalization is applied. + normalize_pred (callable, optional): Function to use to normalize the predicted strings. + Defaults to None if no normalization is applied. + strip_strings (bool, optional): Whether to strip both reference and predictions. Defaults to False. + sample_scoring_function (callable or str, optional): Function to use to score each sample. + Either pass the full function (should take a string prediction and a string gold, and return a score between 0 and 1) + a string (any of `prefix`, `suffix` or `full`) to define the type of exact match that you want, or nothing to defaults to "full". 
+ `prefix` checks if the prediction starts with the gold, + `suffix` if the prediction ends with the gold, + `full` if the prediction and gold are equal + """ + self.k = k + self.n = n + self.normalize_gold = normalize_gold + self.normalize_pred = normalize_pred + self.strip_strings = strip_strings + + # Manage the logic of per-prediction sample scoring + if callable(sample_scoring_function): + self.score_sample = sample_scoring_function + self.type_exact_match = None + else: + if isinstance(sample_scoring_function, str): + if sample_scoring_function not in ["prefix", "suffix", "full"]: + raise ValueError( + f"type_exact_match (used in parametrized_exact_match) must be one of prefix, suffix, or full. Was {sample_scoring_function} instead." + ) + self.type_exact_match = sample_scoring_function + else: + self.type_exact_match = "full" + self.score_sample = self.default_sample_scoring + + def compute(self, golds: list[str], predictions: list[str], **kwargs) -> dict[str, float]: + """Computes the metric over a list of golds and predictions for one single item with possibly many samples. + It applies normalisation (if needed) to model prediction and gold, computes their per prediction score, + then aggregates the scores over the samples using a pass@k. + + Args: + golds (list[str]): Reference targets + predictions (list[str]): k predicted strings + + Returns: + float: Aggregated score over the current sample's items. + """ + if len(golds) > 1: + raise Exception("Cannot compute pass@k with several golds") + + if self.n is None: + self.n = len(predictions) + logger.warning("n undefined in the pass@k. 
We assume it's the same as the sample's number of predictions.") + elif len(predictions) < self.n: + logger.warning(f"Number of predictions is less than {self.n} for pass@k.") + + gold = self.get_processed_gold(golds[0]) + + all_scores = [] + for pred in predictions[: self.n]: + cur_pred = self.get_processed_pred(pred=pred) + all_scores.append(self.score_sample(cur_pred, gold)) + + return self.pass_at_k(all_scores) + + def get_processed_gold(self, gold: str) -> float: + if self.strip_strings: + gold = gold.strip() + + if self.normalize_gold: + gold = self.normalize_gold(gold) + + return gold + + def get_processed_pred(self, pred: str) -> float: + if not pred: + return "" + + if self.strip_strings: + pred = pred.strip() + + if self.normalize_pred: + pred = self.normalize_pred(pred) + + return pred + + def default_sample_scoring(self, pred: str, gold: str) -> int: + if self.type_exact_match == "prefix": + return 1 if pred.startswith(gold) else 0 + if self.type_exact_match == "suffix": + return 1 if pred.endswith(gold) else 0 + return 1 if gold == pred else 0 + + def pass_at_k(self, all_scores: list[int]) -> float: + """Algo from https://arxiv.org/pdf/2107.03374""" + c: int = all_scores.count(1) + if self.n - c < self.k: + return 1.0 + + return 1.0 - np.prod(1.0 - self.k / np.arange(self.n - c + 1, self.n + 1))