Skip to content

Commit

Permalink
Add new add_raw_input argument to _Task so we can automatically i…
Browse files Browse the repository at this point in the history
…nclude the formatted input (#903)

* Add attribute to include raw formatted input to distilabel_metadata field

* Update tests to take into account add_raw_input attribute of tasks

* Add reference to add_raw_input in the documentation

* Update tests to control for the add_raw_input of the _Task
  • Loading branch information
plaguss authored Aug 14, 2024
1 parent c8df5a9 commit 3d772c5
Show file tree
Hide file tree
Showing 13 changed files with 331 additions and 29 deletions.
26 changes: 24 additions & 2 deletions docs/sections/how_to_guides/basic/task/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,12 @@ next(task.process([{"instruction": "What's the capital of Spain?"}]))
# {
# 'instruction': "What's the capital of Spain?",
# 'generation': 'The capital of Spain is Madrid.',
# 'distilabel_metadata': {'raw_output_text-generation': 'The capital of Spain is Madrid.'},
# 'distilabel_metadata': {
# 'raw_output_text-generation': 'The capital of Spain is Madrid.',
# 'raw_input_text-generation': [
# {'role': 'user', 'content': "What's the capital of Spain?"}
# ]
# },
# 'model_name': 'meta-llama/Meta-Llama-3-70B-Instruct'
# }
# ]
Expand All @@ -33,7 +38,24 @@ next(task.process([{"instruction": "What's the capital of Spain?"}]))
!!! NOTE
The `Step.load()` always needs to be executed when being used as a standalone. Within a pipeline, this will be done automatically during pipeline execution.

As shown above, the [`TextGeneration`][distilabel.steps.tasks.TextGeneration] task adds a `generation` based on the `instruction`. Additionally, it provides some metadata about the LLM call through `distilabel_metadata`. This can be disabled by setting the `add_raw_output` attribute to `False` when creating the task.
As shown above, the [`TextGeneration`][distilabel.steps.tasks.TextGeneration] task adds a `generation` based on the `instruction`.

!!! Tip
Since version `1.2.0`, we provide some metadata about the LLM call through `distilabel_metadata`. This can be disabled by setting the `add_raw_output` attribute to `False` when creating the task.

Additionally, since version `1.4.0`, the formatted input can also be included, which can be helpful when testing
custom templates (testing the pipeline using the [`dry_run`][distilabel.pipeline.local.Pipeline.dry_run] method).

```python title="disable raw input and output"
task = TextGeneration(
llm=InferenceEndpointsLLM(
model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
tokenizer_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
),
add_raw_output=False,
add_raw_input=False
)
```

## Specifying the number of generations and grouping generations

Expand Down
32 changes: 26 additions & 6 deletions src/distilabel/steps/tasks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,13 @@ class _Task(_Step, ABC):
" of the `distilabel_metadata` dictionary output column"
),
)
add_raw_input: RuntimeParameter[bool] = Field(
default=True,
description=(
"Whether to include the raw input of the LLM in the key `raw_input_<TASK_NAME>`"
" of the `distilabel_metadata` dictionary column"
),
)
num_generations: RuntimeParameter[int] = Field(
default=1, description="The number of generations to be produced per input."
)
Expand Down Expand Up @@ -113,10 +120,12 @@ def _format_outputs(
for output, input in zip(outputs, inputs * len(outputs)): # type: ignore
try:
formatted_output = self.format_output(output, input)
formatted_output = self._maybe_add_raw_output(
formatted_output = self._maybe_add_raw_input_output(
formatted_output,
output,
input,
add_raw_output=self.add_raw_output, # type: ignore
add_raw_input=self.add_raw_input, # type: ignore
)
formatted_outputs.append(formatted_output)
except Exception as e:
Expand All @@ -135,24 +144,35 @@ def _output_on_failure(
# Create a dictionary with the outputs of the task (every output set to None)
outputs = {output: None for output in self.outputs}
outputs["model_name"] = self.llm.model_name # type: ignore
outputs = self._maybe_add_raw_output(
outputs = self._maybe_add_raw_input_output(
outputs,
output,
input,
add_raw_output=self.add_raw_output, # type: ignore
add_raw_input=self.add_raw_input, # type: ignore
)
return outputs

def _maybe_add_raw_output(
def _maybe_add_raw_input_output(
self,
output: Dict[str, Any],
raw_output: Union[str, None],
input: Union[str, None],
add_raw_output: bool = True,
) -> Dict[str, Any]:
"""Adds the raw output of the LLM to the output dictionary if `add_raw_output` is True."""
add_raw_input: bool = True,
):
"""Adds the raw output and or the formatted input of the LLM to the output dictionary
if `add_raw_output` is True or `add_raw_input` is True.
"""
meta = output.get(DISTILABEL_METADATA_KEY, {})

if add_raw_output:
meta = output.get(DISTILABEL_METADATA_KEY, {})
meta[f"raw_output_{self.name}"] = raw_output
if add_raw_input:
meta[f"raw_input_{self.name}"] = self.format_input(input)
if meta:
output[DISTILABEL_METADATA_KEY] = meta

return output

def _set_default_structured_output(self) -> None:
Expand Down
6 changes: 6 additions & 0 deletions tests/unit/steps/tasks/evol_instruct/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def test_serialization(self, dummy_llm: LLM) -> None:
assert task.dump() == {
"name": "task",
"add_raw_output": True,
"add_raw_input": True,
"input_mappings": task.input_mappings,
"output_mappings": task.output_mappings,
"resources": {
Expand Down Expand Up @@ -206,6 +207,11 @@ def test_serialization(self, dummy_llm: LLM) -> None:
"name": "add_raw_output",
"optional": True,
},
{
"description": "Whether to include the raw input of the LLM in the key `raw_input_<TASK_NAME>` of the `distilabel_metadata` dictionary column",
"name": "add_raw_input",
"optional": True,
},
{
"name": "num_generations",
"optional": True,
Expand Down
6 changes: 6 additions & 0 deletions tests/unit/steps/tasks/evol_instruct/test_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ def test_serialization(self, dummy_llm: LLM) -> None:
},
},
"add_raw_output": True,
"add_raw_input": True,
"input_mappings": task.input_mappings,
"output_mappings": task.output_mappings,
"resources": {
Expand Down Expand Up @@ -201,6 +202,11 @@ def test_serialization(self, dummy_llm: LLM) -> None:
"name": "add_raw_output",
"optional": True,
},
{
"description": "Whether to include the raw input of the LLM in the key `raw_input_<TASK_NAME>` of the `distilabel_metadata` dictionary column",
"name": "add_raw_input",
"optional": True,
},
{
"name": "num_generations",
"optional": True,
Expand Down
6 changes: 6 additions & 0 deletions tests/unit/steps/tasks/evol_quality/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def test_serialization(self, dummy_llm: LLM) -> None:
assert task.dump() == {
"name": "task",
"add_raw_output": True,
"add_raw_input": True,
"input_mappings": task.input_mappings,
"output_mappings": task.output_mappings,
"resources": {
Expand Down Expand Up @@ -170,6 +171,11 @@ def test_serialization(self, dummy_llm: LLM) -> None:
"name": "add_raw_output",
"optional": True,
},
{
"description": "Whether to include the raw input of the LLM in the key `raw_input_<TASK_NAME>` of the `distilabel_metadata` dictionary column",
"name": "add_raw_input",
"optional": True,
},
{
"name": "num_generations",
"optional": True,
Expand Down
6 changes: 6 additions & 0 deletions tests/unit/steps/tasks/magpie/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,7 @@ def test_serialization(self) -> None:
"input_batch_size": 50,
"group_generations": False,
"add_raw_output": True,
"add_raw_input": True,
"num_generations": 1,
"use_default_structured_output": False,
"runtime_parameters_info": [
Expand Down Expand Up @@ -500,6 +501,11 @@ def test_serialization(self) -> None:
"optional": True,
"description": "Whether to include the raw output of the LLM in the key `raw_output_<TASK_NAME>` of the `distilabel_metadata` dictionary output column",
},
{
"description": "Whether to include the raw input of the LLM in the key `raw_input_<TASK_NAME>` of the `distilabel_metadata` dictionary column",
"name": "add_raw_input",
"optional": True,
},
{
"name": "num_generations",
"optional": True,
Expand Down
6 changes: 6 additions & 0 deletions tests/unit/steps/tasks/magpie/test_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def test_serialization(self) -> None:
"batch_size": 50,
"group_generations": False,
"add_raw_output": True,
"add_raw_input": True,
"num_generations": 1,
"num_rows": None,
"use_default_structured_output": False,
Expand Down Expand Up @@ -157,6 +158,11 @@ def test_serialization(self) -> None:
"optional": True,
"description": "Whether to include the raw output of the LLM in the key `raw_output_<TASK_NAME>` of the `distilabel_metadata` dictionary output column",
},
{
"description": "Whether to include the raw input of the LLM in the key `raw_input_<TASK_NAME>` of the `distilabel_metadata` dictionary column",
"name": "add_raw_input",
"optional": True,
},
{
"name": "num_generations",
"optional": True,
Expand Down
Loading

0 comments on commit 3d772c5

Please sign in to comment.