Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Initial work for `offline_batch_generate`
* Add code for uploading Batch API files to OpenAI
* Add `offline_batch_inference` attribute
* `offline_batch_generate` finished for `OpenAILLM`
* Add attributes for checking task compatibility with `offline_batch_generation`
* Extend `is_global` property for `offline_batch_generation`
* Move `job_ids` responsability to `LLM`
* And remember... `unload` everything before pickling
* Move `BASE_CACHE_DIR` to constants
* Recover for offline batch generation
* Polling sleep
* Store input data for recovering offline batch generation
* Lint
* Add checking no offline batch generation with `RayPipeline`
* Update `LLM`s unit tests
* Update `Task`s unit tests
* Add `OpenAILLM.offline_batch_generate` unit tests
* Fix unit test
* Add unit tests for adding recovery batch for offline generation
* Update tasks that can be used with offline batch generation
* Move aux functions to utils
* Handle `_SecretField` and excluded attributes when refreshing pipeline from cache
* Fix checking inner type
* Add simple integration test
* Remove unit test
* Fix formatting exception
* Update type hint
* Handle stopping offline batch generation polling
* Use `_stop_called_lock` everywhere
* Fix deadlock
* Fix load
* Add Batch API example
* Update examples
* How to offline batch generation
* Add FAQ about OpenAI Batch API
* Update links
* Add `envs` module
* Add setting pipeline running env variables in child process
* Update OpenAI file upload to assign custom name
* Download nltk everytime
* Add missing arguments
* Update logging message
* Add section about offline batch generation
* Add errors and exceptions API docs
* Fix unit test
* Update mkdocs.yaml
1 parent a2a8e86, commit 28485d0
Showing 84 changed files with 2,077 additions and 350 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
# Errors | ||
|
||
This section contains the `distilabel` custom errors. Unlike [exceptions](exceptions.md), errors in `distilabel` are used to handle unexpected situations that can't be anticipated and that can't be handled in a controlled way. | ||
|
||
:::distilabel.errors.DistilabelError | ||
:::distilabel.errors.DistilabelUserError | ||
:::distilabel.errors.DistilabelTypeError | ||
:::distilabel.errors.DistilabelNotImplementedError |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
# Exceptions | ||
|
||
This section contains the `distilabel` custom exceptions. Unlike [errors](errors.md), exceptions in `distilabel` are used to handle specific situations that can be anticipated and that can be handled in a controlled way internally by the library. | ||
|
||
:::distilabel.exceptions.DistilabelException | ||
:::distilabel.exceptions.DistilabelGenerationException | ||
:::distilabel.exceptions.DistilabelOfflineBatchGenerationNotFinishedException |
47 changes: 47 additions & 0 deletions
docs/sections/how_to_guides/advanced/offline_batch_generation.md
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
[Offline batch generation](../basic/llm/index.md#offline-batch-generation) is a feature offered by some of the `LLM`s implemented in `distilabel`. It allows sending the inputs to an LLM-as-a-service platform and waiting for the outputs asynchronously. LLM-as-a-service platforms offer this feature because it lets them gather many inputs and create batches as large as the hardware allows, maximizing hardware utilization and reducing the cost of the service. In exchange, the user has to wait some time for the outputs to be ready, but the cost per token is usually much lower. | ||
|
||
`distilabel` pipelines are able to handle `LLM`s that offer this feature in the following way: | ||
|
||
* The first time the pipeline is executed, the `LLM` sends the inputs to the platform. The platform returns job IDs that can be used later to check the status of the jobs and retrieve the results. The `LLM` saves these job IDs in its `jobs_ids` attribute and raises a special exception, [DistilabelOfflineBatchGenerationNotFinishedException][distilabel.exceptions.DistilabelOfflineBatchGenerationNotFinishedException], that is handled by the `Pipeline`. The job IDs are saved in the pipeline cache so they can be used in subsequent executions. | ||
* The second and subsequent executions recover the pipeline execution, and the `LLM` does not send the inputs to the platform again. Since it already has the `jobs_ids`, it checks whether the jobs have finished; if they have, it retrieves the results and returns the outputs. If they have not finished, it raises `DistilabelOfflineBatchGenerationNotFinishedException` again. | ||
* In addition, `LLM`s with offline batch generation can be configured to poll until the jobs have finished, blocking the pipeline until they are done. If the polling needs to be stopped for some reason, press ++ctrl+c++ or ++cmd+c++ depending on your OS (or send a `SIGINT` to the main process). This stops the polling and raises `DistilabelOfflineBatchGenerationNotFinishedException`, which is handled by the pipeline as described above. | ||
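The submit-then-recover flow described in the list above can be sketched in plain Python. This is a minimal illustration and not the `distilabel` implementation: `FakeBatchPlatform`, `run_once` and the `cache` dictionary are hypothetical stand-ins for the platform, a pipeline execution and the pipeline cache, and the exception class is a local copy so the snippet runs without `distilabel` installed.

```python
from typing import Dict, List, Optional, Tuple


class DistilabelOfflineBatchGenerationNotFinishedException(Exception):
    """Local stand-in for the `distilabel` exception of the same name."""

    def __init__(self, jobs_ids: Tuple[str, ...]) -> None:
        self.jobs_ids = jobs_ids
        super().__init__(f"Batch generation with jobs_ids={jobs_ids} is not finished")


class FakeBatchPlatform:
    """Hypothetical platform: a job is ready once it has been checked `checks_needed` times."""

    def __init__(self, checks_needed: int = 2) -> None:
        self.checks_needed = checks_needed
        self.jobs: Dict[str, dict] = {}

    def submit(self, inputs: List[str]) -> str:
        job_id = f"job-{len(self.jobs)}"
        self.jobs[job_id] = {"inputs": inputs, "checks": 0}
        return job_id

    def retrieve(self, job_id: str) -> Optional[List[str]]:
        job = self.jobs[job_id]
        job["checks"] += 1
        if job["checks"] < self.checks_needed:
            return None  # job still running
        return [text.upper() for text in job["inputs"]]


def run_once(
    platform: FakeBatchPlatform, cache: Dict[str, Tuple[str, ...]], inputs: List[str]
) -> List[str]:
    """One execution: submit on the first run, only check the jobs on later runs."""
    if "jobs_ids" not in cache:
        # First execution: send the inputs and remember the job IDs in the cache.
        cache["jobs_ids"] = (platform.submit(inputs),)
        raise DistilabelOfflineBatchGenerationNotFinishedException(cache["jobs_ids"])
    outputs: List[str] = []
    for job_id in cache["jobs_ids"]:
        result = platform.retrieve(job_id)
        if result is None:
            # Not finished yet: signal it again without resubmitting anything.
            raise DistilabelOfflineBatchGenerationNotFinishedException(cache["jobs_ids"])
        outputs.extend(result)
    return outputs


platform = FakeBatchPlatform()
cache: Dict[str, Tuple[str, ...]] = {}
attempts = 0
while True:
    attempts += 1
    try:
        outputs = run_once(platform, cache, ["hello", "world"])
        break
    except DistilabelOfflineBatchGenerationNotFinishedException:
        continue  # a real user would re-run the pipeline later instead of looping
print(attempts, outputs)
```

Because the job IDs live in `cache`, the inputs are submitted only once no matter how many times `run_once` is called, which is the reason the pipeline cache must stay enabled.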
|
||
!!! WARNING | ||
|
||
    To recover the pipeline execution and retrieve the results, the pipeline cache must be enabled. If the pipeline cache is disabled, the inputs will be sent again and new jobs will be created, incurring extra costs. | ||
|
||
|
||
## Example pipeline using `OpenAILLM` with offline batch generation | ||
|
||
```python | ||
from distilabel.llms import OpenAILLM | ||
from distilabel.pipeline import Pipeline | ||
from distilabel.steps import LoadDataFromHub | ||
from distilabel.steps.tasks import TextGeneration | ||
|
||
with Pipeline() as pipeline: | ||
    load_data = LoadDataFromHub(output_mappings={"prompt": "instruction"}) | ||
|
||
    text_generation = TextGeneration( | ||
        llm=OpenAILLM( | ||
            model="gpt-3.5-turbo", | ||
            use_offline_batch_generation=True,  # (1) | ||
        ) | ||
    ) | ||
|
||
    load_data >> text_generation | ||
|
||
|
||
if __name__ == "__main__": | ||
    distiset = pipeline.run( | ||
        parameters={ | ||
            load_data.name: { | ||
                "repo_id": "distilabel-internal-testing/instruction-dataset", | ||
                "split": "test", | ||
                "batch_size": 500, | ||
            }, | ||
        } | ||
    ) | ||
``` | ||
|
||
1. Indicate that the `OpenAILLM` should use offline batch generation. |
2 changes: 1 addition & 1 deletion
docs/sections/how_to_guides/advanced/saving_step_generated_artifacts.md
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
# Copyright 2023-present, Argilla, Inc. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
# Idea from: https://github.com/vllm-project/vllm/blob/main/vllm/envs.py | ||
|
||
import os | ||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional | ||
|
||
from distilabel import constants | ||
|
||
if TYPE_CHECKING: | ||
    DISTILABEL_LOG_LEVEL: str = "INFO" | ||
    DISTILABEL_PIPELINE_NAME: Optional[str] = None | ||
    DISTILABEL_PIPELINE_CACHE_ID: Optional[str] = None | ||
    DISTILABEL_CACHE_DIR: Optional[str] = None | ||
|
||
ENVIRONMENT_VARIABLES: Dict[str, Callable[[], Any]] = { | ||
    # `distilabel` logging level. | ||
    "DISTILABEL_LOG_LEVEL": lambda: os.getenv("DISTILABEL_LOG_LEVEL", "INFO").upper(), | ||
    # The name of the `distilabel` pipeline currently running. | ||
    constants.PIPELINE_NAME_ENV_NAME: lambda: os.getenv( | ||
        constants.PIPELINE_NAME_ENV_NAME, None | ||
    ), | ||
    # The cache ID of the `distilabel` pipeline currently running. | ||
    constants.PIPELINE_CACHE_ID_ENV_NAME: lambda: os.getenv( | ||
        constants.PIPELINE_CACHE_ID_ENV_NAME, None | ||
    ), | ||
    # The directory where `distilabel` stores its cache. | ||
    "DISTILABEL_CACHE_DIR": lambda: os.getenv("DISTILABEL_CACHE_DIR", None), | ||
} | ||
|
||
|
||
def __getattr__(name: str) -> Any: | ||
    # lazy evaluation of environment variables | ||
    if name in ENVIRONMENT_VARIABLES: | ||
        return ENVIRONMENT_VARIABLES[name]() | ||
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}") | ||
|
||
|
||
def __dir__() -> List[str]: | ||
    return list(ENVIRONMENT_VARIABLES.keys()) |
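The module above relies on a module-level `__getattr__` (PEP 562) so that each environment variable is read at access time rather than at import time. The following sketch reproduces that behaviour in isolation. The module name `demo_envs` and the variable `DEMO_LOG_LEVEL` are made up for this example, and the module is built with `types.ModuleType` only so the snippet fits in a single file.

```python
import os
import types
from typing import Any, Callable, Dict

# Hypothetical variable used only for this demo; it is not a `distilabel` setting.
ENVIRONMENT_VARIABLES: Dict[str, Callable[[], Any]] = {
    "DEMO_LOG_LEVEL": lambda: os.getenv("DEMO_LOG_LEVEL", "INFO").upper(),
}


def _lazy_getattr(name: str) -> Any:
    # Called only when normal attribute lookup on the module fails.
    if name in ENVIRONMENT_VARIABLES:
        return ENVIRONMENT_VARIABLES[name]()
    raise AttributeError(f"module 'demo_envs' has no attribute {name!r}")


# Build a throwaway module and install the hook in its namespace,
# which is equivalent to defining `__getattr__` at the top level of a file.
demo_envs = types.ModuleType("demo_envs")
demo_envs.__getattr__ = _lazy_getattr

os.environ.pop("DEMO_LOG_LEVEL", None)
default_level = demo_envs.DEMO_LOG_LEVEL  # falls back to the default

os.environ["DEMO_LOG_LEVEL"] = "debug"
updated_level = demo_envs.DEMO_LOG_LEVEL  # re-read on every access

print(default_level, updated_level)
```

Since nothing is cached, a variable set after import (for example in a child process, before the value is first read) is still picked up.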
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
# Copyright 2023-present, Argilla, Inc. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
from typing import Tuple | ||
|
||
|
||
class DistilabelException(Exception): | ||
    """Base exception (can be gracefully handled) for `distilabel` framework.""" | ||
|
||
    pass | ||
|
||
|
||
class DistilabelGenerationException(DistilabelException): | ||
    """Base exception for `LLM` generation errors.""" | ||
|
||
    pass | ||
|
||
|
||
class DistilabelOfflineBatchGenerationNotFinishedException( | ||
    DistilabelGenerationException | ||
): | ||
    """Exception raised when a batch generation is not finished.""" | ||
|
||
    jobs_ids: Tuple[str, ...] | ||
|
||
    def __init__(self, jobs_ids: Tuple[str, ...]) -> None: | ||
        self.jobs_ids = jobs_ids | ||
        super().__init__(f"Batch generation with jobs_ids={jobs_ids} is not finished") |
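As a rough usage sketch of this hierarchy, a caller can catch the base `DistilabelGenerationException` and still read `jobs_ids` when the concrete exception is the offline batch one. The classes are redefined locally so the snippet runs standalone, and `check_jobs` with its job IDs is a hypothetical helper, not part of `distilabel`.

```python
from typing import Tuple


class DistilabelException(Exception):
    """Base exception (local copy for this example)."""


class DistilabelGenerationException(DistilabelException):
    """Base exception for generation errors (local copy for this example)."""


class DistilabelOfflineBatchGenerationNotFinishedException(
    DistilabelGenerationException
):
    """Raised when a batch generation is not finished (local copy)."""

    def __init__(self, jobs_ids: Tuple[str, ...]) -> None:
        self.jobs_ids = jobs_ids
        super().__init__(f"Batch generation with jobs_ids={jobs_ids} is not finished")


def check_jobs(jobs_ids: Tuple[str, ...]) -> None:
    # Hypothetical helper that always reports the jobs as still running.
    raise DistilabelOfflineBatchGenerationNotFinishedException(jobs_ids)


pending: Tuple[str, ...] = ()
message = ""
try:
    check_jobs(("batch_1", "batch_2"))
except DistilabelGenerationException as exc:
    # Catching the base class also catches the offline batch exception.
    if isinstance(exc, DistilabelOfflineBatchGenerationNotFinishedException):
        pending = exc.jobs_ids  # IDs to persist and check again later
    message = str(exc)

print(pending)
print(message)
```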