Merge branch 'main' into nathan-fix-vllm-from-file
NathanHB authored Feb 6, 2025
2 parents b4c2d77 + 441d7a4 commit 147211c
Showing 11 changed files with 811 additions and 131 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -109,7 +109,7 @@ multilingual = [
"jieba", # for chinese tokenizer
"pyvi", # for vietnamese tokenizer
]
math = ["latex2sympy2_extended>=0.9.3"]
math = ["latex2sympy2_extended==1.0.4"]

[project.urls]
Homepage = "https://github.com/huggingface/lighteval"
3 changes: 2 additions & 1 deletion src/lighteval/logging/evaluation_tracker.py
@@ -325,12 +325,13 @@ def push_to_hub(
# We upload it both as a json and a parquet file
result_file_base_name = f"results_{date_id}"
results_json = json.dumps(results_dict, cls=EnhancedJSONEncoder, indent=2, ensure_ascii=False)
self.api.upload_file(
url = self.api.upload_file(
repo_id=repo_id,
path_or_fileobj=BytesIO(results_json.encode("utf-8")),
path_in_repo=f"{result_file_base_name}.json",
repo_type="dataset",
)
logger.info(f"Uploaded evaluation details to {url}")

results_dataset = Dataset.from_dict(
{key: [json.dumps(v, cls=EnhancedJSONEncoder, indent=2)] for key, v in results_dict.items()}
19 changes: 15 additions & 4 deletions src/lighteval/metrics/dynamic_metrics.py
@@ -193,6 +193,7 @@ def multilingual_extractive_match_metric(
fallback_mode: Literal["no_fallback", "first_match"] = "first_match",
extraction_mode: Literal["first_match", "any_match"] = "any_match",
precision: int = 6,
timeout_seconds: int = 5,
) -> SampleLevelMetric:
"""Creates a language-aware extractive match metric that extracts answers from the model's output.
@@ -222,6 +223,8 @@
precision: int
Number of decimal places to use when comparing numerical values. Defaults to 6.
timeout_seconds: int
Timeout for the extraction (each attempt) and comparison. Defaults to 5.
Returns:
A sample level metric that extracts and compares mathematical expressions.
@@ -245,11 +248,12 @@ def sample_level_fn(golds: list[str], predictions: list[str], formatted_doc: Doc
pred_extraction_regexes = get_extraction_regexes(formatted_doc, pred_extraction_target, language)

extracted_predictions = [
extract_target_from_pred(pred, pred_extraction_regexes, fallback_mode, extraction_mode)
extract_target_from_pred(pred, pred_extraction_regexes, fallback_mode, extraction_mode, timeout_seconds)
for pred in predictions
]
extracted_golds = [
extract_target_from_pred(gold, gold_extraction_regexes, fallback_mode, extraction_mode) for gold in golds
extract_target_from_pred(gold, gold_extraction_regexes, fallback_mode, extraction_mode, timeout_seconds)
for gold in golds
]

# Assert on empty gold and warn on empty pred
@@ -265,12 +269,19 @@ def sample_level_fn(golds: list[str], predictions: list[str], formatted_doc: Doc
# We have to use a timeout because the sympy to str conversion can be very slow
try:
add_to_specifics_with_timeout(formatted_doc, extracted_predictions, extracted_golds)
except: # noqa: E722
except Exception: # noqa: E722
logger.warning("Timeout when adding extracted predictions and golds to specific")

return aggregation_function(
[
(1.0 if any(compare_gold_target(gold, pred, precision) for gold in extracted_golds) else 0.0)
(
1.0
if any(
compare_gold_target(gold, pred, precision, timeout_seconds=timeout_seconds)
for gold in extracted_golds
)
else 0.0
)
for pred in extracted_predictions
]
)
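A minimal usage sketch (illustrative, not part of the commit) of how the new timeout_seconds argument could be passed when building the metric. It reuses only names that appear elsewhere in this diff (multilingual_extractive_match_metric, IndicesExtractionConfig, Language); the specific call site and the value 10 are assumptions.

from lighteval.metrics.dynamic_metrics import (
    IndicesExtractionConfig,
    multilingual_extractive_match_metric,
)
from lighteval.utils.language import Language

# Hypothetical metric that raises the per-attempt extraction/comparison timeout
# above the new default of 5 seconds.
letter_match_metric = multilingual_extractive_match_metric(
    language=Language.ENGLISH,
    gold_extraction_target=[IndicesExtractionConfig(prefix_for_extraction="NativeLetters")],
    pred_extraction_target=[IndicesExtractionConfig(prefix_for_extraction="NativeLetters")],
    precision=6,
    timeout_seconds=10,
)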
36 changes: 36 additions & 0 deletions src/lighteval/metrics/metrics.py
@@ -24,6 +24,10 @@
import numpy as np
from aenum import Enum

from lighteval.metrics.dynamic_metrics import (
IndicesExtractionConfig,
multilingual_extractive_match_metric,
)
from lighteval.metrics.harness_compatibility.drop import drop_metrics
from lighteval.metrics.harness_compatibility.truthful_qa import truthfulqa_mc_metrics
from lighteval.metrics.metrics_corpus import (
@@ -44,6 +48,7 @@
Faithfulness,
LoglikelihoodAcc,
MajAtK,
PassAtK,
Recall,
StringDistance,
acc_golds_likelihood,
@@ -69,6 +74,7 @@
SampleLevelMetric,
SampleLevelMetricGrouping,
)
from lighteval.utils.language import Language
from lighteval.utils.utils import as_list


@@ -364,6 +370,30 @@ class Metrics(Enum):
corpus_level_fn=CorpusLevelF1Score(average=None, num_classes=3).compute,
higher_is_better=True,
)
pass_at_1 = SampleLevelMetric(
metric_name="pass@1:32_samples",
sample_level_fn=PassAtK(k=1, n=32, strip_strings=True).compute,
category=MetricCategory.GENERATIVE_SAMPLING,
use_case=MetricUseCase.REASONING,
corpus_level_fn=np.mean,
higher_is_better=True,
)
pass_at_10 = SampleLevelMetric(
metric_name="pass@10:32_samples",
sample_level_fn=PassAtK(k=10, n=32, strip_strings=True).compute,
category=MetricCategory.GENERATIVE_SAMPLING,
use_case=MetricUseCase.REASONING,
corpus_level_fn=np.mean,
higher_is_better=True,
)
pass_at_100 = SampleLevelMetric(
metric_name="pass@100:32_samples",
sample_level_fn=PassAtK(k=100, n=32, strip_strings=True).compute,
category=MetricCategory.GENERATIVE_SAMPLING,
use_case=MetricUseCase.REASONING,
corpus_level_fn=np.mean,
higher_is_better=True,
)
perfect_exact_match = SampleLevelMetric(
metric_name="perfect_em",
sample_level_fn=ExactMatches().compute,
@@ -549,6 +579,12 @@ class Metrics(Enum):
corpus_level_fn=CorpusLevelPerplexityMetric("weighted_perplexity").compute,
higher_is_better=False,
)
gpqa_instruct_metric = multilingual_extractive_match_metric(
language=Language.ENGLISH,
gold_extraction_target=[IndicesExtractionConfig(prefix_for_extraction="NativeLetters")],
pred_extraction_target=[IndicesExtractionConfig(prefix_for_extraction="NativeLetters")],
precision=6,
)

def __str__(self):
return self.name.replace("_at_", "@")
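A quick sketch (illustrative, not part of the commit) of what the new enum members expose, assuming direct access to the enum values; the gold and predictions are made up.

from lighteval.metrics.metrics import Metrics

metric = Metrics.pass_at_1.value      # SampleLevelMetric wrapping PassAtK(k=1, n=32)
print(metric.metric_name)             # "pass@1:32_samples"
print(str(Metrics.pass_at_1))         # "pass@1", via the __str__ override above
score = metric.sample_level_fn(
    golds=["42"],
    predictions=["42"] * 32,          # 32 sampled generations for one item
)
print(score)                          # 1.0 -- every sample matches the gold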
135 changes: 131 additions & 4 deletions src/lighteval/metrics/metrics_sample.py
@@ -26,7 +26,7 @@

import logging
import os
from typing import Callable, Literal
from typing import Callable, Literal, Union

import nltk
import numpy as np
@@ -708,9 +708,21 @@ def __init__(self):
"""Creates a BLEURT scorer using a light bleurt-tiny-512 model.
For more complex use cases, could also be Elron/bleurt-base-128
"""
self.tokenizer = AutoTokenizer.from_pretrained("Elron/bleurt-tiny-512")
self.model = AutoModelForSequenceClassification.from_pretrained("Elron/bleurt-tiny-512")
self.model.eval()
self._tokenizer = None
self._model = None

@property
def tokenizer(self):
if self._tokenizer is None:
self._tokenizer = AutoTokenizer.from_pretrained("Elron/bleurt-tiny-512")
return self._tokenizer

@property
def model(self):
if self._model is None:
self._model = AutoModelForSequenceClassification.from_pretrained("Elron/bleurt-tiny-512")
self._model.eval()
return self._model

def compute(self, golds: list[str], predictions: list[str], **kwargs) -> float:
"""Uses the stored BLEURT scorer to compute the score on the current sample.
@@ -1043,3 +1055,118 @@ def compute_score(self, pred: str, gold: str) -> int:
if self.type_exact_match == "suffix":
return 1 if pred.endswith(gold) else 0
return 1 if gold == pred else 0


class PassAtK:
def __init__(
self,
k: int,
n: int = None,
normalize_gold: Callable = None,
normalize_pred: Callable = None,
strip_strings: bool = False,
sample_scoring_function: Union[Callable[[str, str], float], str] = None,
):
"""Computing pass at k
Args:
k (int): Threshold for the number of successful attempts.
n (int): Number of samples to generate
normalize_gold (callable, optional): Function to use to normalize the reference strings.
Defaults to None if no normalization is applied.
normalize_pred (callable, optional): Function to use to normalize the predicted strings.
Defaults to None if no normalization is applied.
strip_strings (bool, optional): Whether to strip both reference and predictions. Defaults to False.
sample_scoring_function (callable or str, optional): Function to use to score each sample.
Either pass the full function (it should take a string prediction and a string gold, and return a score between 0 and 1),
a string (any of `prefix`, `suffix` or `full`) to define the type of exact match you want, or nothing to default to "full".
`prefix` checks if the prediction starts with the gold,
`suffix` if the prediction ends with the gold,
`full` if the prediction and gold are equal
"""
self.k = k
self.n = n
self.normalize_gold = normalize_gold
self.normalize_pred = normalize_pred
self.strip_strings = strip_strings

# Manage the logic of the per-prediction sample scoring
if callable(sample_scoring_function):
self.score_sample = sample_scoring_function
self.type_exact_match = None
else:
if isinstance(sample_scoring_function, str):
if sample_scoring_function not in ["prefix", "suffix", "full"]:
raise ValueError(
f"type_exact_match (used in parametrized_exact_match) must be one of prefix, suffix, or full. Was {sample_scoring_function} instead."
)
self.type_exact_match = sample_scoring_function
else:
self.type_exact_match = "full"
self.score_sample = self.default_sample_scoring

def compute(self, golds: list[str], predictions: list[str], **kwargs) -> float:
"""Computes the metric over a list of golds and predictions for one single item with possibly many samples.
It applies normalisation (if needed) to model prediction and gold, computes their per prediction score,
then aggregates the scores over the samples using a pass@k.
Args:
golds (list[str]): Reference targets
predictions (list[str]): k predicted strings
Returns:
float: Aggregated score over the current sample's items.
"""
if len(golds) > 1:
raise Exception("Cannot compute pass@k with several golds")

if self.n is None:
self.n = len(predictions)
logger.warning("n undefined in the pass@k. We assume it's the same as the sample's number of predictions.")
elif len(predictions) < self.n:
logger.warning(f"Number of predictions is less than {self.n} for pass@k.")

gold = self.get_processed_gold(golds[0])

all_scores = []
for pred in predictions[: self.n]:
cur_pred = self.get_processed_pred(pred=pred)
all_scores.append(self.score_sample(cur_pred, gold))

return self.pass_at_k(all_scores)

def get_processed_gold(self, gold: str) -> str:
if self.strip_strings:
gold = gold.strip()

if self.normalize_gold:
gold = self.normalize_gold(gold)

return gold

def get_processed_pred(self, pred: str) -> str:
if not pred:
return ""

if self.strip_strings:
pred = pred.strip()

if self.normalize_pred:
pred = self.normalize_pred(pred)

return pred

def default_sample_scoring(self, pred: str, gold: str) -> int:
if self.type_exact_match == "prefix":
return 1 if pred.startswith(gold) else 0
if self.type_exact_match == "suffix":
return 1 if pred.endswith(gold) else 0
return 1 if gold == pred else 0

def pass_at_k(self, all_scores: list[int]) -> float:
"""Algo from https://arxiv.org/pdf/2107.03374"""
c: int = all_scores.count(1)
if self.n - c < self.k:
return 1.0

return 1.0 - np.prod(1.0 - self.k / np.arange(self.n - c + 1, self.n + 1))
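
A worked check of the estimator (illustrative, not part of the commit): with n = 4 generations, k = 2, and c = 1 correct sample, pass@k = 1 - C(n-c, k)/C(n, k) = 1 - C(3, 2)/C(4, 2) = 1 - 3/6 = 0.5, which is what the product form above computes. A hypothetical call using the default full exact-match scoring:

pass_at_2 = PassAtK(k=2, n=4, strip_strings=True)
score = pass_at_2.compute(
    golds=["42"],
    predictions=["42", "7", "13", "9"],   # one correct sample out of four -> c = 1
)
assert abs(score - 0.5) < 1e-9            # 1 - C(3, 2)/C(4, 2) = 0.5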