
Commit

Merge branch 'main' into Document-Custom-Model-Files
ParagEkbote authored Feb 6, 2025
2 parents ab38bd5 + 86f6225 commit d809e39
Showing 18 changed files with 1,134 additions and 159 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -123,10 +123,10 @@ pre-commit run --all-files

```bibtex
@misc{lighteval,
author = {Fourrier, Clémentine and Habib, Nathan and Wolf, Thomas and Tunstall, Lewis},
author = {Fourrier, Clémentine and Habib, Nathan and Kydlíček, Hynek and Wolf, Thomas and Tunstall, Lewis},
title = {LightEval: A lightweight framework for LLM evaluation},
year = {2023},
version = {0.5.0},
version = {0.7.0},
url = {https://github.com/huggingface/lighteval}
}
```
39 changes: 37 additions & 2 deletions community_tasks/french_evals.py
@@ -46,6 +46,7 @@
from lighteval.tasks.extended.ifeval.main import ifeval_metrics
from lighteval.tasks.lighteval_task import LightevalTaskConfig
from lighteval.tasks.requests import Doc
from lighteval.utils.utils import as_list


# Ifeval-fr prompt function
@@ -70,7 +71,7 @@ def prompt_gpqa_fr(line, task_name: str = None):

query = f"Question: {line['Question']}\n"
query += "".join([f"{key}. {choice}\n" for key, choice in zip(LETTER_INDICES, choices)])
query += "Answer: "
query += "Réponse: "
return Doc(
task_name=task_name,
query=f"{instruction}{query}",
@@ -80,6 +81,22 @@ def prompt_gpqa_fr(line, task_name: str = None):
)


# BAC-fr prompt function
def prompt_bac_fr(line, task_name: str = None):
prompt = f"Enoncé: {line['enonce']}\n{line['instruction']}\n"
if line["choix"] is not None: # Multichoice evaluation
# prompt += "\n".join([f"{LETTER_INDICES[ix]}.{choix}" for ix, choix in enumerate(line["choix"])])
return Doc(
task_name=task_name,
query=prompt,
choices=as_list(line["choix"]),
gold_index=line["choix"].index(line["choix correct"]),
instruction="",
)
else:
return Doc(task_name=task_name, query=prompt, choices=[line["reponse"]], gold_index=0, instruction="")


# IFEVal-fr task


@@ -117,5 +134,23 @@ def prompt_gpqa_fr(line, task_name: str = None):
version=0,
)

# BAC-fr task
bac_fr_task = LightevalTaskConfig(
name="bac-fr",
suite=["community"],
prompt_function=prompt_bac_fr,
hf_repo="fr-gouv-coordination-ia/bac-fr",
hf_subset="default",
hf_avail_splits=["train"],
evaluation_splits=["train"],
few_shots_split=None,
few_shots_select="random_sampling",
generation_size=1,
metric=[Metrics.quasi_exact_match_math, Metrics.exact_match],
stop_sequence=["\n"],
trust_dataset=True,
version=0,
)

# STORE YOUR EVALS
TASKS_TABLE = [ifeval_fr_task, gpqa_fr_task]
TASKS_TABLE = [ifeval_fr_task, gpqa_fr_task, bac_fr_task]
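For illustration, here is a hedged sketch of what the new `prompt_bac_fr` function produces for a free-form (non-multichoice) row. The field names come from the function above; the row values and the import path are hypothetical:

```python
# Hypothetical row; real rows come from the fr-gouv-coordination-ia/bac-fr dataset.
from community_tasks.french_evals import prompt_bac_fr  # assumes the repo root is on sys.path

line = {
    "enonce": "Résoudre l'équation x + 2 = 5.",
    "instruction": "Donnez uniquement la valeur de x.",
    "choix": None,            # no answer options -> free-form branch
    "choix correct": None,
    "reponse": "3",
}

doc = prompt_bac_fr(line, task_name="community|bac-fr")
# doc.query      -> "Enoncé: Résoudre l'équation x + 2 = 5.\nDonnez uniquement la valeur de x.\n"
# doc.choices    -> ["3"]
# doc.gold_index -> 0
```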
49 changes: 33 additions & 16 deletions docs/source/quicktour.mdx
@@ -20,34 +20,51 @@ Lighteval can be used with a few different commands.
- `tgi`: evaluate models on one or more GPUs using [🔗 Text Generation Inference](https://huggingface.co/docs/text-generation-inference/en/index)
- `openai`: evaluate models on one or more GPUs using [🔗 OpenAI API](https://platform.openai.com/)

## Accelerate
## Basic usage

### Evaluate a model on a GPU

To evaluate `GPT-2` on the Truthful QA benchmark, run:
To evaluate `GPT-2` on the Truthful QA benchmark with [🤗
Accelerate](https://github.com/huggingface/accelerate), run:

```bash
lighteval accelerate \
"pretrained=gpt2" \
"leaderboard|truthfulqa:mc|0|0"
```

Here, `--tasks` refers to either a comma-separated list of supported tasks from
the [tasks_list](available-tasks) in the format:
Here, we first choose a backend (either `accelerate`, `nanotron`, or `vllm`), and then specify the model and task(s) to run.

```bash
{suite}|{task}|{num_few_shot}|{0 or 1 to automatically reduce `num_few_shot` if prompt is too long}
```
The syntax for the model arguments is `key1=value1,key2=value2,etc`.
Valid key-value pairs correspond to the backend configuration and are detailed [below](#backend-configuration).
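As a rough illustration of that syntax (a sketch only, not how lighteval parses it internally), the string splits into comma-separated key/value pairs; the keys used here are taken from the Accelerate list below:

```python
# Illustration of the model-args format: comma-separated key=value pairs.
model_args = "pretrained=gpt2,trust_remote_code=True"
parsed = dict(pair.split("=", 1) for pair in model_args.split(","))
print(parsed)  # {'pretrained': 'gpt2', 'trust_remote_code': 'True'}
```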

The syntax for the task specification might be a bit hard to grasp at first. The format is as follows:

```txt
{suite}|{task}|{num_few_shot}|{0 for strict `num_few_shots`, or 1 to allow a truncation if context size is too small}
```

or a file path like
[examples/tasks/recommended_set.txt](https://github.com/huggingface/lighteval/blob/main/examples/tasks/recommended_set.txt)
which specifies multiple task configurations.
If the fourth value is set to 1, lighteval will check if the prompt (including the few-shot examples) is too long for the context size of the task or the model.
If so, the number of few-shot examples is automatically reduced.
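To make the four fields concrete, here is a small sketch (the specification string is the same one used in the example further down; lighteval does this parsing for you):

```python
spec = "leaderboard|gsm8k|3|1"
suite, task, num_few_shot, truncate_few_shots = spec.split("|")
# suite              -> "leaderboard"
# task               -> "gsm8k"
# num_few_shot       -> "3"  (three in-context examples)
# truncate_few_shots -> "1"  (reduce the few-shot count if the prompt is too long)
```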

Tasks details can be found in the
All officially supported tasks can be found at the [tasks_list](available-tasks) and in the
[extended folder](https://github.com/huggingface/lighteval/tree/main/src/lighteval/tasks/extended).
Moreover, community-provided tasks can be found in the
[community](https://github.com/huggingface/lighteval/tree/main/community_tasks) folder.
For more details on the implementation of the tasks, such as how prompts are constructed, or which metrics are used, you can have a look at the
[file](https://github.com/huggingface/lighteval/blob/main/src/lighteval/tasks/default_tasks.py)
implementing them.

### Evaluate a model on one or more GPUs
Running multiple tasks is supported, either with a comma-separated list, or by specifying a file path.
The file should be structured like [examples/tasks/recommended_set.txt](https://github.com/huggingface/lighteval/blob/main/examples/tasks/recommended_set.txt).
When specifying a path to a file, it should start with `./`.

```bash
lighteval accelerate \
"pretrained=gpt2" \
./path/to/lighteval/examples/tasks/recommended_set.txt
# or, e.g., "leaderboard|truthfulqa:mc|0|0,leaderboard|gsm8k|3|1"
```

## Evaluate a model on one or more GPUs

#### Data parallelism

@@ -86,13 +103,13 @@ This will automatically use accelerate to distribute the model across the GPUs.
> `model_parallel=True` and using accelerate to distribute the data across the
GPUs.

### Model Arguments
## Backend configuration

The `model-args` argument takes a string representing a list of model
arguments. The arguments allowed vary depending on the backend you use (vllm or
accelerate).

#### Accelerate
### Accelerate

- **pretrained** (str):
HuggingFace Hub model ID name or the path to a pre-trained
Expand Down Expand Up @@ -128,7 +145,7 @@ accelerate).
- **trust_remote_code** (bool): Whether to trust remote code during model
loading.

#### VLLM
### VLLM

- **pretrained** (str): HuggingFace Hub model ID name or the path to a pre-trained model to load.
- **gpu_memory_utilisation** (float): The fraction of GPU memory to use.
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -92,7 +92,7 @@ nanotron = [
"tensorboardX"
]
tensorboardX = ["tensorboardX"]
vllm = ["vllm", "ray", "more_itertools"]
vllm = ["vllm>=0.7.0", "ray", "more_itertools"]
quality = ["ruff==v0.2.2","pre-commit"]
tests = ["pytest==7.4.0"]
dev = ["lighteval[accelerate,quality,tests,multilingual,math]"]
@@ -109,7 +109,7 @@ multilingual = [
"jieba", # for chinese tokenizer
"pyvi", # for vietnamese tokenizer
]
math = ["latex2sympy2_extended>=0.9.3"]
math = ["latex2sympy2_extended==1.0.6"]

[project.urls]
Homepage = "https://github.com/huggingface/lighteval"
3 changes: 2 additions & 1 deletion src/lighteval/logging/evaluation_tracker.py
@@ -325,12 +325,13 @@ def push_to_hub(
# We upload it both as a json and a parquet file
result_file_base_name = f"results_{date_id}"
results_json = json.dumps(results_dict, cls=EnhancedJSONEncoder, indent=2, ensure_ascii=False)
self.api.upload_file(
url = self.api.upload_file(
repo_id=repo_id,
path_or_fileobj=BytesIO(results_json.encode("utf-8")),
path_in_repo=f"{result_file_base_name}.json",
repo_type="dataset",
)
logger.info(f"Uploaded evaluation details to {url}")

results_dataset = Dataset.from_dict(
{key: [json.dumps(v, cls=EnhancedJSONEncoder, indent=2)] for key, v in results_dict.items()}
22 changes: 17 additions & 5 deletions src/lighteval/metrics/dynamic_metrics.py
@@ -193,6 +193,7 @@ def multilingual_extractive_match_metric(
fallback_mode: Literal["no_fallback", "first_match"] = "first_match",
extraction_mode: Literal["first_match", "any_match"] = "any_match",
precision: int = 6,
timeout_seconds: int = 5,
) -> SampleLevelMetric:
"""Creates a language-aware extractive match metric that extracts answers from the model's output.
@@ -222,6 +223,8 @@ def multilingual_extractive_match_metric(
precision: int
Number of decimal places to use when comparing numerical values. Defaults to 6.
timeout_seconds: int
Timeout for the extraction (each attempt) and comparison. Defaults to 5.
Returns:
A sample level metric that extracts and compares mathematical expressions.
@@ -245,16 +248,18 @@ def sample_level_fn(golds: list[str], predictions: list[str], formatted_doc: Doc
pred_extraction_regexes = get_extraction_regexes(formatted_doc, pred_extraction_target, language)

extracted_predictions = [
extract_target_from_pred(pred, pred_extraction_regexes, fallback_mode, extraction_mode)
extract_target_from_pred(pred, pred_extraction_regexes, fallback_mode, extraction_mode, timeout_seconds)
for pred in predictions
]
extracted_golds = [
extract_target_from_pred(gold, gold_extraction_regexes, fallback_mode, extraction_mode) for gold in golds
extract_target_from_pred(gold, gold_extraction_regexes, fallback_mode, extraction_mode, timeout_seconds)
for gold in golds
]

# Warn on empty gold (falling back to the raw gold string) and on empty pred
if any(len(g) == 0 for g in extracted_golds):
raise ValueError(f"No gold targets found for at least one gold. Gold: {golds}, Pred: {predictions}")
logger.warning(f"We did not manage to extract a gold in the correct format. Gold: {golds}")
extracted_golds = [[gold] for gold in golds]

if all(len(p) == 0 for p in extracted_predictions):
logger.warning(
@@ -264,12 +269,19 @@ def sample_level_fn(golds: list[str], predictions: list[str], formatted_doc: Doc
# We have to use a timeout because the sympy-to-str conversion can be very slow
try:
add_to_specifics_with_timeout(formatted_doc, extracted_predictions, extracted_golds)
except: # noqa: E722
except Exception: # noqa: E722
logger.warning("Timeout when adding extracted predictions and golds to specific")

return aggregation_function(
[
(1.0 if any(compare_gold_target(gold, pred, precision) for gold in extracted_golds) else 0.0)
(
1.0
if any(
compare_gold_target(gold, pred, precision, timeout_seconds=timeout_seconds)
for gold in extracted_golds
)
else 0.0
)
for pred in extracted_predictions
]
)
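For reference, a minimal sketch of how the new `timeout_seconds` argument can be supplied when building an extractive metric; the call mirrors the `gpqa_instruct_metric` definition added to `metrics.py` below, and the 10-second value is only an example:

```python
from lighteval.metrics.dynamic_metrics import (
    IndicesExtractionConfig,
    multilingual_extractive_match_metric,
)
from lighteval.utils.language import Language

# Same shape as gpqa_instruct_metric, with an explicit per-attempt timeout
# for both answer extraction and gold/prediction comparison.
custom_extractive_metric = multilingual_extractive_match_metric(
    language=Language.ENGLISH,
    gold_extraction_target=[IndicesExtractionConfig(prefix_for_extraction="NativeLetters")],
    pred_extraction_target=[IndicesExtractionConfig(prefix_for_extraction="NativeLetters")],
    precision=6,
    timeout_seconds=10,
)
```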
36 changes: 36 additions & 0 deletions src/lighteval/metrics/metrics.py
@@ -24,6 +24,10 @@
import numpy as np
from aenum import Enum

from lighteval.metrics.dynamic_metrics import (
IndicesExtractionConfig,
multilingual_extractive_match_metric,
)
from lighteval.metrics.harness_compatibility.drop import drop_metrics
from lighteval.metrics.harness_compatibility.truthful_qa import truthfulqa_mc_metrics
from lighteval.metrics.metrics_corpus import (
@@ -44,6 +48,7 @@
Faithfulness,
LoglikelihoodAcc,
MajAtK,
PassAtK,
Recall,
StringDistance,
acc_golds_likelihood,
@@ -69,6 +74,7 @@
SampleLevelMetric,
SampleLevelMetricGrouping,
)
from lighteval.utils.language import Language
from lighteval.utils.utils import as_list


@@ -364,6 +370,30 @@ class Metrics(Enum):
corpus_level_fn=CorpusLevelF1Score(average=None, num_classes=3).compute,
higher_is_better=True,
)
pass_at_1 = SampleLevelMetric(
metric_name="pass@1:32_samples",
sample_level_fn=PassAtK(k=1, n=32, strip_strings=True).compute,
category=MetricCategory.GENERATIVE_SAMPLING,
use_case=MetricUseCase.REASONING,
corpus_level_fn=np.mean,
higher_is_better=True,
)
pass_at_10 = SampleLevelMetric(
metric_name="pass@10:32_samples",
sample_level_fn=PassAtK(k=10, n=32, strip_strings=True).compute,
category=MetricCategory.GENERATIVE_SAMPLING,
use_case=MetricUseCase.REASONING,
corpus_level_fn=np.mean,
higher_is_better=True,
)
pass_at_100 = SampleLevelMetric(
metric_name="pass@100:32_samples",
sample_level_fn=PassAtK(k=100, n=32, strip_strings=True).compute,
category=MetricCategory.GENERATIVE_SAMPLING,
use_case=MetricUseCase.REASONING,
corpus_level_fn=np.mean,
higher_is_better=True,
)
perfect_exact_match = SampleLevelMetric(
metric_name="perfect_em",
sample_level_fn=ExactMatches().compute,
@@ -549,6 +579,12 @@ class Metrics(Enum):
corpus_level_fn=CorpusLevelPerplexityMetric("weighted_perplexity").compute,
higher_is_better=False,
)
gpqa_instruct_metric = multilingual_extractive_match_metric(
language=Language.ENGLISH,
gold_extraction_target=[IndicesExtractionConfig(prefix_for_extraction="NativeLetters")],
pred_extraction_target=[IndicesExtractionConfig(prefix_for_extraction="NativeLetters")],
precision=6,
)

def __str__(self):
return self.name.replace("_at_", "@")
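The three `pass@k` entries added above report, per problem, the probability that at least one of k attempts drawn from the n=32 generated samples is correct, averaged over the corpus. Assuming `PassAtK` follows the standard unbiased estimator (the exact lighteval implementation may differ), the per-problem score can be sketched as:

```python
from math import comb

def pass_at_k(n: int, c: int, k: int) -> float:
    """Unbiased pass@k estimate given n samples of which c are correct."""
    if n - c < k:
        return 1.0  # any draw of k samples must include a correct one
    return 1.0 - comb(n - c, k) / comb(n, k)

# Example: 32 samples per problem, 5 of which pass the check.
print(pass_at_k(n=32, c=5, k=1))   # 0.15625
print(pass_at_k(n=32, c=5, k=10))  # ≈ 0.87
```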