Offline batch generation (#923)
* Initial work for `offline_batch_generate`

* Add code for uploading Batch API files to OpenAI

* Add `offline_batch_inference` attribute

* `offline_batch_generate` finished for `OpenAILLM`

* Add attributes for checking task compatibility with
`offline_batch_generation`

* Extend `is_global` property for `offline_batch_generation`

* Move `job_ids` responsibility to `LLM`

* And remember... `unload` everything before pickling

* Move `BASE_CACHE_DIR` to constants

* Recover for offline batch generation

* Polling sleep

* Store input data for recovering offline batch generation

* Lint

* Add checking no offline batch generation with `RayPipeline`

* Update `LLM`s unit tests

* Update `Task`s unit tests

* Add `OpenAILLM.offline_batch_generate` unit tests

* Fix unit test

* Add unit tests for adding recovery batch for offline generation

* Update tasks that can be used with offline batch generation

* Move aux functions to utils

* Handle `_SecretField` and excluded attributes when refreshing pipeline
from cache

* Fix checking inner type

* Add simple integration test

* Remove unit test

* Fix formatting exception

* Update type hint

* Handle stopping offline batch generation polling

* Use `_stop_called_lock` everywhere

* Fix deadlock

* Fix load

* Add Batch API example

* Update examples

* How to offline batch generation

* Add FAQ about OpenAI Batch API

* Update links

* Add `envs` module

* Add setting pipeline running env variables in child process

* Update OpenAI file upload to assign custom name

* Download nltk every time

* Add missing arguments

* Update logging message

* Add section about offline batch generation

* Add errors and exceptions API docs

* Fix unit test

* Update mkdocs.yaml
gabrielmbmb authored Sep 2, 2024
1 parent a2a8e86 commit 28485d0
Showing 84 changed files with 2,077 additions and 350 deletions.
8 changes: 8 additions & 0 deletions docs/api/errors.md
@@ -0,0 +1,8 @@
# Errors

This section contains the `distilabel` custom errors. Unlike [exceptions](exceptions.md), errors in `distilabel` are used to handle unexpected situations that can't be anticipated and that can't be handled in a controlled way.

:::distilabel.errors.DistilabelError
:::distilabel.errors.DistilabelUserError
:::distilabel.errors.DistilabelTypeError
:::distilabel.errors.DistilabelNotImplementedError
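
A minimal usage sketch (not taken from the library source; the `page` keyword and the docs path are assumptions based on the "redirect to a given page in the documentation" wording of the class docstrings): a validation helper can raise `DistilabelUserError` so the error message points the user to the relevant documentation page.

```python
# Sketch only: assumes `DistilabelUserError` accepts an optional `page` keyword
# that is appended to the distilabel docs URL.
from distilabel.errors import DistilabelUserError


def validate_num_generations(num_generations: int) -> None:
    if num_generations < 1:
        raise DistilabelUserError(
            "`num_generations` must be greater than or equal to 1.",
            page="sections/how_to_guides/basic/llm/",  # hypothetical docs path
        )


validate_num_generations(1)  # OK
# validate_num_generations(0)  # raises DistilabelUserError
```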
7 changes: 7 additions & 0 deletions docs/api/exceptions.md
@@ -0,0 +1,7 @@
# Exceptions

This section contains the `distilabel` custom exceptions. Unlike [errors](errors.md), exceptions in `distilabel` are used to handle specific situations that can be anticipated and that can be handled in a controlled way internally by the library.

:::distilabel.exceptions.DistilabelException
:::distilabel.exceptions.DistilabelGenerationException
:::distilabel.exceptions.DistilabelOfflineBatchGenerationNotFinishedException
3 changes: 3 additions & 0 deletions docs/sections/getting_started/faq.md
@@ -42,3 +42,6 @@ hide:

??? faq "How can I use the same `LLM` across several tasks without having to load it several times?"
You can serve the LLM using a solution like TGI or vLLM, and then connect to it using an `AsyncLLM` client like `InferenceEndpointsLLM` or `OpenAILLM`. Please refer to [Serving LLMs guide](../how_to_guides/advanced/serving_an_llm_for_reuse.md) for more information.

??? faq "Can `distilabel` be used with [OpenAI Batch API](https://platform.openai.com/docs/guides/batch)?"
Yes, `distilabel` is integrated with OpenAI Batch API via [OpenAILLM][distilabel.llms.openai.OpenAILLM]. Check [LLMs - Offline Batch Generation](../how_to_guides/basic/llm/index.md#offline-batch-generation) for a small example on how to use it and [Advanced - Offline Batch Generation](../how_to_guides/advanced/offline_batch_generation.md) for a more detailed guide.
47 changes: 47 additions & 0 deletions docs/sections/how_to_guides/advanced/offline_batch_generation.md
@@ -0,0 +1,47 @@
[Offline batch generation](../basic/llm/index.md#offline-batch-generation) is a feature that some of the `LLM`s implemented in `distilabel` offer, allowing you to send the inputs to an LLM-as-a-service platform and wait for the outputs asynchronously. LLM-as-a-service platforms offer this feature because it allows them to gather many inputs and create batches as big as the hardware allows, maximizing hardware utilization and reducing the cost of the service. In exchange, the user has to wait a certain amount of time for the outputs to be ready, but the cost per token is usually much lower.

`distilabel` pipelines are able to handle `LLM`s that offer this feature in the following way:

* The first time the pipeline is executed, the `LLM` will send the inputs to the platform. The platform will return job ids that can be used later to check the status of the jobs and retrieve the results. The `LLM` will save these job ids in its `jobs_ids` attribute and raise the special exception [DistilabelOfflineBatchGenerationNotFinishedException][distilabel.exceptions.DistilabelOfflineBatchGenerationNotFinishedException], which will be handled by the `Pipeline`. The job ids will be saved in the pipeline cache, so they can be used in subsequent calls.
* The second and subsequent executions will recover the pipeline execution and the `LLM` won't send the inputs to the platform again. This time, as it has the `jobs_ids`, it will check whether the jobs have finished; if they have, it will retrieve the results and return the outputs. If they haven't finished, it will raise `DistilabelOfflineBatchGenerationNotFinishedException` again.
* In addition, `LLM`s with offline batch generation can be configured to poll until the jobs have finished, blocking the pipeline until they are done. If for some reason the polling needs to be stopped, press ++ctrl+c++ or ++cmd+c++ depending on your OS (or send a `SIGINT` to the main process), which will stop the polling and raise `DistilabelOfflineBatchGenerationNotFinishedException`, which will be handled by the pipeline as described above.

!!! WARNING

In order to recover the pipeline execution and retrieve the results, the pipeline cache must be enabled. If the pipeline cache is disabled, the inputs will be sent again and new jobs will be created, incurring extra costs.


## Example pipeline using `OpenAILLM` with offline batch generation

```python
from distilabel.llms import OpenAILLM
from distilabel.pipeline import Pipeline
from distilabel.steps import LoadDataFromHub
from distilabel.steps.tasks import TextGeneration

with Pipeline() as pipeline:
    load_data = LoadDataFromHub(output_mappings={"prompt": "instruction"})

    text_generation = TextGeneration(
        llm=OpenAILLM(
            model="gpt-3.5-turbo",
            use_offline_batch_generation=True,  # (1)
        )
    )

    load_data >> text_generation


if __name__ == "__main__":
    distiset = pipeline.run(
        parameters={
            load_data.name: {
                "repo_id": "distilabel-internal-testing/instruction-dataset",
                "split": "test",
                "batch_size": 500,
            },
        }
    )
```

1. Indicate that the `OpenAILLM` should use offline batch generation.
@@ -1,6 +1,6 @@
# Saving step generated artifacts

Some `Step`s might need to produce an auxiliary artifact that is not a result of the computation, but is needed for the computation. For example, the [`FaissNearestNeighbour`](/distilabel/components-gallery/steps/faissnearestneighbour/) needs to create a Faiss index to compute the output of the step, which is the top `k` nearest neighbours for each input. Generating the Faiss index takes time and it could potentially be reused outside of the `distilabel` pipeline, so it would be a shame not to save it.
Some `Step`s might need to produce an auxiliary artifact that is not a result of the computation, but is needed for the computation. For example, the [`FaissNearestNeighbour`](../../../components-gallery/steps/faissnearestneighbour.md) needs to create a Faiss index to compute the output of the step, which is the top `k` nearest neighbours for each input. Generating the Faiss index takes time and it could potentially be reused outside of the `distilabel` pipeline, so it would be a shame not to save it.

For this reason, `Step`s have a method called `save_artifact` that allows saving artifacts that will be included along with the outputs of the pipeline in the generated [`Distiset`][distilabel.distiset.Distiset]. The generated artifacts will be uploaded or saved when using `Distiset.push_to_hub` or `Distiset.save_to_disk`, respectively. Let's see how to use it with a simple example.

2 changes: 1 addition & 1 deletion docs/sections/how_to_guides/advanced/scaling_with_ray.md
@@ -85,7 +85,7 @@ if __name__ == "__main__":
1. We're setting [resources](assigning_resources_to_step.md) for the `text_generation` step and defining that we want two replicas and one GPU per replica. `distilabel` will create two replicas of the step, i.e. two actors in the Ray cluster, and each actor will request to be allocated on a node of the cluster that has at least one GPU. You can read more about how Ray manages the resources [here](https://docs.ray.io/en/latest/ray-core/scheduling/resources.html#resources).
2. You should modify this and add your user or organization on the Hugging Face Hub.

It's a basic pipeline with just two steps: one to load a dataset from the Hub with an `instruction` column and one to generate a `response` for that instruction using Llama 3 8B Instruct with [vLLM](/distilabel/components-gallery/llms/vllm/). Simple but enough to demonstrate how to distribute and scale the workload using a Ray cluster!
It's a basic pipeline with just two steps: one to load a dataset from the Hub with an `instruction` column and one to generate a `response` for that instruction using Llama 3 8B Instruct with [vLLM](../../../components-gallery/llms/vllm.md). Simple but enough to demonstrate how to distribute and scale the workload using a Ray cluster!

### Using Ray Jobs API

73 changes: 68 additions & 5 deletions docs/sections/how_to_guides/basic/llm/index.md
@@ -5,12 +5,12 @@
LLM subclasses are designed to be used within a [Task][distilabel.steps.tasks.Task], but they can also be used standalone.

```python
from distilabel.llms import OpenAILLM
from distilabel.llms import InferenceEndpointsLLM

llm = OpenAILLM(model="gpt-4")
llm = InferenceEndpointsLLM(model="meta-llama/Meta-Llama-3.1-70B-Instruct")
llm.load()

llm.generate(
llm.generate_outputs(
    inputs=[
        [{"role": "user", "content": "What's the capital of Spain?"}],
    ],
@@ -21,6 +21,69 @@ llm.generate(
!!! NOTE
Always call the `LLM.load` or `Task.load` method when using LLMs standalone or as part of a `Task`. If using a `Pipeline`, this is done automatically in `Pipeline.run()`.

### Offline Batch Generation

By default, all `LLM`s generate text synchronously, i.e. sending inputs through the `generate_outputs` method blocks until the outputs are generated. Some `LLM`s (such as [OpenAILLM][distilabel.llms.openai.OpenAILLM]) implement what we call _offline batch generation_, which allows sending the inputs to an LLM-as-a-service platform that generates the outputs asynchronously and gives us job ids we can use later to check the status and retrieve the generated outputs once they are ready. LLM-as-a-service platforms offer this feature as a way to save costs, in exchange for waiting for the outputs to be generated.

To use this feature in `distilabel`, the only thing we need to do is set the `use_offline_batch_generation` attribute to `True` when creating the `LLM` instance:

```python
from distilabel.llms import OpenAILLM

llm = OpenAILLM(
    model="gpt-4o",
    use_offline_batch_generation=True,
)

llm.load()

llm.jobs_ids  # (1)
# None

llm.generate_outputs(  # (2)
    inputs=[
        [{"role": "user", "content": "What's the capital of Spain?"}],
    ],
)
# DistilabelOfflineBatchGenerationNotFinishedException: Batch generation with jobs_ids=('batch_OGB4VjKpu2ay9nz3iiFJxt5H',) is not finished

llm.jobs_ids  # (3)
# ('batch_OGB4VjKpu2ay9nz3iiFJxt5H',)


llm.generate_outputs(  # (4)
    inputs=[
        [{"role": "user", "content": "What's the capital of Spain?"}],
    ],
)
# "The capital of Spain is Madrid."
```

1. At first, the `jobs_ids` attribute is `None`.
2. The first call to `generate_outputs` will send the inputs to the LLM-as-a-service and raise a `DistilabelOfflineBatchGenerationNotFinishedException`, since the outputs are not ready yet.
3. After the first call to `generate_outputs`, the `jobs_ids` attribute will contain the ids of the jobs created to generate the outputs.
4. The second and subsequent calls to `generate_outputs` will return the outputs if they are ready, or raise a `DistilabelOfflineBatchGenerationNotFinishedException` if they are not ready yet.

The `offline_batch_generation_block_until_done` attribute can be used to block the `generate_outputs` method until the outputs are ready, polling the platform at the specified interval in seconds.

```python
from distilabel.llms import OpenAILLM

llm = OpenAILLM(
    model="gpt-4o",
    use_offline_batch_generation=True,
    offline_batch_generation_block_until_done=5,  # poll for results every 5 seconds
)
llm.load()

llm.generate_outputs(
    inputs=[
        [{"role": "user", "content": "What's the capital of Spain?"}],
    ],
)
# "The capital of Spain is Madrid."
```

### Within a Task

Pass the LLM as an argument to the [`Task`][distilabel.steps.tasks.Task], and the task will handle the rest.
@@ -81,7 +144,7 @@ To create custom LLMs, subclass either [`LLM`][distilabel.llms.LLM] for synchron
* `generate`: A method that takes a list of prompts and returns generated texts.

* `agenerate`: A method that takes a single prompt and returns generated texts. This method is used within the `generate` method of the `AsyncLLM` class.
*

* (optional) `get_last_hidden_state`: a method that takes a list of prompts and returns a list of hidden states. This method is optional and will be used by some tasks such as the [`GenerateEmbeddings`][distilabel.steps.tasks.GenerateEmbeddings] task.


@@ -142,4 +205,4 @@ To create custom LLMs, subclass either [`LLM`][distilabel.llms.LLM] for synchron

## Available LLMs

[Our LLM gallery](/distilabel/components-gallery/llms/) shows a list of the available LLMs that can be used within the `distilabel` library.
[Our LLM gallery](../../../../components-gallery/llms/index.md) shows a list of the available LLMs that can be used within the `distilabel` library.
3 changes: 3 additions & 0 deletions mkdocs.yml
@@ -188,6 +188,7 @@ nav:
- Cache and recover pipeline executions: "sections/how_to_guides/advanced/caching.md"
- Export data to Argilla: "sections/how_to_guides/advanced/argilla.md"
- Structured data generation: "sections/how_to_guides/advanced/structured_generation.md"
- Offline Batch Generation: "sections/how_to_guides/advanced/offline_batch_generation.md"
- Specify requirements for pipelines and steps: "sections/how_to_guides/advanced/pipeline_requirements.md"
- Using CLI to explore and re-run existing Pipelines: "sections/how_to_guides/advanced/cli/index.md"
- Using a file system to pass data of batches between steps: "sections/how_to_guides/advanced/fs_to_pass_data.md"
@@ -243,6 +244,8 @@ nav:
- Mixins:
- RuntimeParametersMixin: "api/mixins/runtime_parameters.md"
- RequirementsMixin: "api/mixins/requirements.md"
- Exceptions: "api/exceptions.md"
- Errors: "api/errors.md"
- Distiset: "api/distiset.md"
- CLI: "api/cli.md"
- Community:
3 changes: 0 additions & 3 deletions scripts/install_dependencies.sh
@@ -8,9 +8,6 @@ python -m pip install uv

uv pip install --system -e ".[anthropic,argilla,cohere,groq,hf-inference-endpoints,hf-transformers,litellm,llama-cpp,ollama,openai,outlines,vertexai,mistralai,instructor,sentence-transformers,faiss-cpu,minhash]"

# For the tests of minhash
python -c "import nltk; nltk.download('punkt_tab')"

if [ "${python_version}" != "(3, 12)" ]; then
uv pip install --system -e .[ray]
fi
24 changes: 22 additions & 2 deletions src/distilabel/constants.py
@@ -12,19 +12,29 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from pathlib import Path
from typing import Final

# Steps related constants
DISTILABEL_METADATA_KEY: Final[str] = "distilabel_metadata"

# Pipeline related constants
# Cache
BASE_CACHE_DIR = Path.home() / ".cache" / "distilabel"
PIPELINES_CACHE_DIR = BASE_CACHE_DIR / "pipelines"

# Pipeline dag related constants
STEP_ATTR_NAME: Final[str] = "step"
INPUT_QUEUE_ATTR_NAME: Final[str] = "input_queue"
RECEIVES_ROUTED_BATCHES_ATTR_NAME: Final[str] = "receives_routed_batches"
ROUTING_BATCH_FUNCTION_ATTR_NAME: Final[str] = "routing_batch_function"
CONVERGENCE_STEP_ATTR_NAME: Final[str] = "convergence_step"
LAST_BATCH_SENT_FLAG: Final[str] = "last_batch_sent"

# Pipeline execution related constants
PIPELINE_NAME_ENV_NAME = "DISTILABEL_PIPELINE_NAME"
PIPELINE_CACHE_ID_ENV_NAME = "DISTILABEL_PIPELINE_CACHE_ID"
SIGINT_HANDLER_CALLED_ENV_NAME = "sigint_handler_called"

# Data paths constants
STEPS_OUTPUTS_PATH = "steps_outputs"
STEPS_ARTIFACTS_PATH = "steps_artifacts"
@@ -40,11 +50,21 @@


__all__ = [
"DISTILABEL_METADATA_KEY",
"BASE_CACHE_DIR",
"PIPELINES_CACHE_DIR",
"STEP_ATTR_NAME",
"INPUT_QUEUE_ATTR_NAME",
"RECEIVES_ROUTED_BATCHES_ATTR_NAME",
"ROUTING_BATCH_FUNCTION_ATTR_NAME",
"CONVERGENCE_STEP_ATTR_NAME",
"LAST_BATCH_SENT_FLAG",
"DISTILABEL_METADATA_KEY",
"SIGINT_HANDLER_CALLED_ENV_NAME",
"STEPS_OUTPUTS_PATH",
"STEPS_ARTIFACTS_PATH",
"DISTISET_CONFIG_FOLDER",
"DISTISET_ARTIFACTS_FOLDER",
"PIPELINE_CONFIG_FILENAME",
"PIPELINE_LOG_FILENAME",
"DISTILABEL_DOCS_URL",
]
52 changes: 52 additions & 0 deletions src/distilabel/envs.py
@@ -0,0 +1,52 @@
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Idea from: https://github.com/vllm-project/vllm/blob/main/vllm/envs.py

import os
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional

from distilabel import constants

if TYPE_CHECKING:
    DISTILABEL_LOG_LEVEL: str = "INFO"
    DISTILABEL_PIPELINE_NAME: Optional[str] = None
    DISTILABEL_PIPELINE_CACHE_ID: Optional[str] = None
    DISTILABEL_CACHE_DIR: Optional[str] = None

ENVIRONMENT_VARIABLES: Dict[str, Callable[[], Any]] = {
    # `distilabel` logging level.
    "DISTILABEL_LOG_LEVEL": lambda: os.getenv("DISTILABEL_LOG_LEVEL", "INFO").upper(),
    # The name of the `distilabel` pipeline currently running.
    constants.PIPELINE_NAME_ENV_NAME: lambda: os.getenv(
        constants.PIPELINE_NAME_ENV_NAME, None
    ),
    # The cache ID of the `distilabel` pipeline currently running.
    constants.PIPELINE_CACHE_ID_ENV_NAME: lambda: os.getenv(
        constants.PIPELINE_CACHE_ID_ENV_NAME, None
    ),
    # The cache directory used by `distilabel`.
    "DISTILABEL_CACHE_DIR": lambda: os.getenv("DISTILABEL_CACHE_DIR", None),
}


def __getattr__(name: str) -> Any:
    # lazy evaluation of environment variables
    if name in ENVIRONMENT_VARIABLES:
        return ENVIRONMENT_VARIABLES[name]()
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")


def __dir__() -> List[str]:
    return list(ENVIRONMENT_VARIABLES.keys())
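
A quick usage sketch of the module above: attribute access goes through `__getattr__`, so each environment variable is resolved lazily at access time rather than at import time (the values below are illustrative).

```python
import os

from distilabel import envs

os.environ["DISTILABEL_LOG_LEVEL"] = "debug"

# Resolved lazily via `envs.__getattr__`; the lambda upper-cases the value.
print(envs.DISTILABEL_LOG_LEVEL)  # "DEBUG"

# Variables that are not set fall back to their defaults (`None` here).
print(envs.DISTILABEL_CACHE_DIR)  # None
```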
6 changes: 6 additions & 0 deletions src/distilabel/errors.py
@@ -59,3 +59,9 @@ class DistilabelTypeError(DistilabelError, TypeError):
"""TypeError that we can redirect to a given page in the documentation."""

pass


class DistilabelNotImplementedError(DistilabelError, NotImplementedError):
"""NotImplementedError that we can redirect to a given page in the documentation."""

pass
40 changes: 40 additions & 0 deletions src/distilabel/exceptions.py
@@ -0,0 +1,40 @@
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import Tuple


class DistilabelException(Exception):
    """Base exception (can be gracefully handled) for `distilabel` framework."""

    pass


class DistilabelGenerationException(DistilabelException):
    """Base exception for `LLM` generation errors."""

    pass


class DistilabelOfflineBatchGenerationNotFinishedException(
    DistilabelGenerationException
):
    """Exception raised when a batch generation is not finished."""

    jobs_ids: Tuple[str, ...]

    def __init__(self, jobs_ids: Tuple[str, ...]) -> None:
        self.jobs_ids = jobs_ids
        super().__init__(f"Batch generation with jobs_ids={jobs_ids} is not finished")
8 changes: 2 additions & 6 deletions src/distilabel/llms/anthropic.py
@@ -84,11 +84,7 @@ class AnthropicLLM(AsyncLLM):
llm.load()
# Synchronous request
output = llm.generate(inputs=[[{"role": "user", "content": "Hello world!"}]])
# Asynchronous request
output = await llm.agenerate(input=[{"role": "user", "content": "Hello world!"}])
output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
```
Generate structured data:
@@ -110,7 +106,7 @@ class User(BaseModel):
llm.load()
output = llm.generate(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])
output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])
```
"""

