From e6c9d9e2cd9f2a853fbb2d8759a3fb86f4931656 Mon Sep 17 00:00:00 2001 From: David Berenstein Date: Fri, 17 Jan 2025 13:13:48 +0100 Subject: [PATCH] Fix `Image` import handling and update `MlxLLM` initialisation (#1102) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Gabriel Martín Blázquez --- pyproject.toml | 5 ++--- .../huggingface/inference_endpoints.py | 7 +++++-- src/distilabel/models/image_generation/utils.py | 4 ++-- src/distilabel/models/llms/mlx.py | 2 +- src/distilabel/steps/tasks/image_generation.py | 10 ++++++++-- .../steps/tasks/structured_outputs/outlines.py | 3 ++- .../steps/tasks/text_generation_with_image.py | 1 - src/distilabel/utils/image.py | 7 ++++--- 8 files changed, 24 insertions(+), 15 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 1c55ebb1c..a413f05e7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,7 @@ dependencies = [ "orjson >= 3.10.0", "universal_pathlib >= 0.2.2", "portalocker >= 2.8.2", + "setuptools", ] dynamic = ["version"] @@ -90,9 +91,7 @@ ray = ["ray[default] >= 2.31.0"] vertexai = ["google-cloud-aiplatform >= 1.38.0"] vllm = [ "vllm >= 0.5.3", - "filelock >= 3.13.4", - # `setuptools` is needed to be installed if installed with `uv pip install distilabel[vllm]` - "setuptools", + "filelock >= 3.13.4" ] sentence-transformers = ["sentence-transformers >= 3.0.0"] faiss-cpu = ["faiss-cpu >= 1.8.0"] diff --git a/src/distilabel/models/image_generation/huggingface/inference_endpoints.py b/src/distilabel/models/image_generation/huggingface/inference_endpoints.py index 2403fbf01..a5225815e 100644 --- a/src/distilabel/models/image_generation/huggingface/inference_endpoints.py +++ b/src/distilabel/models/image_generation/huggingface/inference_endpoints.py @@ -20,7 +20,6 @@ InferenceEndpointsBaseClient, ) from distilabel.models.image_generation.base import AsyncImageGenerationModel -from distilabel.models.image_generation.utils import image_to_str if TYPE_CHECKING: from PIL.Image import Image @@ -60,10 +59,14 @@ class InferenceEndpointsImageGeneration( # type: ignore """ def load(self) -> None: + from distilabel.models.image_generation.utils import image_to_str + # Sets the logger and calls the load method of the BaseClient AsyncImageGenerationModel.load(self) InferenceEndpointsBaseClient.load(self) + self._image_to_str = image_to_str + @validate_call async def agenerate( # type: ignore self, @@ -101,6 +104,6 @@ async def agenerate( # type: ignore num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, ) - img_str = image_to_str(image, image_format="JPEG") + img_str = self._image_to_str(image, image_format="JPEG") return [{"images": [img_str]}] diff --git a/src/distilabel/models/image_generation/utils.py b/src/distilabel/models/image_generation/utils.py index e5f08ca34..fe7e4e5d7 100644 --- a/src/distilabel/models/image_generation/utils.py +++ b/src/distilabel/models/image_generation/utils.py @@ -18,14 +18,14 @@ from PIL import Image -def image_to_str(image: Image.Image, image_format: str = "JPEG") -> str: +def image_to_str(image: "Image.Image", image_format: str = "JPEG") -> str: """Converts a PIL Image to a base64 encoded string.""" buffered = io.BytesIO() image.save(buffered, format=image_format) return base64.b64encode(buffered.getvalue()).decode("utf-8") -def image_from_str(image_str: str) -> Image.Image: +def image_from_str(image_str: str) -> "Image.Image": """Converts a base64 encoded string to a PIL Image.""" image_bytes = base64.b64decode(image_str) return Image.open(io.BytesIO(image_bytes)) diff --git a/src/distilabel/models/llms/mlx.py b/src/distilabel/models/llms/mlx.py index ffdcf3752..e23401b07 100644 --- a/src/distilabel/models/llms/mlx.py +++ b/src/distilabel/models/llms/mlx.py @@ -60,7 +60,7 @@ class MlxLLM(LLM, MagpieChatTemplateMixin): ```python from distilabel.models.llms import MlxLLM - llm = MlxLLM(model="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit") + llm = MlxLLM(path_or_hf_repo="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit") llm.load() diff --git a/src/distilabel/steps/tasks/image_generation.py b/src/distilabel/steps/tasks/image_generation.py index 3484b9005..ebc411b8c 100644 --- a/src/distilabel/steps/tasks/image_generation.py +++ b/src/distilabel/steps/tasks/image_generation.py @@ -15,7 +15,6 @@ import hashlib from typing import TYPE_CHECKING -from distilabel.models.image_generation.utils import image_from_str from distilabel.steps.base import StepInput from distilabel.steps.tasks.base import ImageTask @@ -117,6 +116,13 @@ class ImageGeneration(ImageTask): save_artifacts: bool = False image_format: str = "JPEG" + def load(self) -> None: + from distilabel.models.image_generation.utils import image_from_str + + super().load() + + self._image_from_str = image_from_str + @property def inputs(self) -> "StepColumns": return ["prompt"] @@ -166,7 +172,7 @@ def process(self, inputs: StepInput) -> "StepOutput": # use prompt as filename prompt_hash = hashlib.md5(input["prompt"].encode()).hexdigest() # Build PIL image to save it - image = image_from_str(image) + image = self._image_from_str(image) self.save_artifact( name="images", diff --git a/src/distilabel/steps/tasks/structured_outputs/outlines.py b/src/distilabel/steps/tasks/structured_outputs/outlines.py index 45b5fe749..a0b4ced55 100644 --- a/src/distilabel/steps/tasks/structured_outputs/outlines.py +++ b/src/distilabel/steps/tasks/structured_outputs/outlines.py @@ -28,7 +28,6 @@ get_args, ) -import pkg_resources from pydantic import BaseModel from distilabel.errors import DistilabelUserError @@ -50,6 +49,8 @@ def _is_outlines_version_below_0_1_0() -> bool: Returns: bool: True if outlines is not installed or version is below 0.1.0 """ + import pkg_resources + if not importlib.util.find_spec("outlines"): raise ImportError( "Outlines is not installed. Please install it using `pip install outlines`." diff --git a/src/distilabel/steps/tasks/text_generation_with_image.py b/src/distilabel/steps/tasks/text_generation_with_image.py index 8aee386f8..3e0bef56e 100644 --- a/src/distilabel/steps/tasks/text_generation_with_image.py +++ b/src/distilabel/steps/tasks/text_generation_with_image.py @@ -15,7 +15,6 @@ from typing import TYPE_CHECKING, Any, Literal, Union from jinja2 import Template -from PIL import Image from pydantic import Field from distilabel.steps.tasks.base import Task diff --git a/src/distilabel/utils/image.py b/src/distilabel/utils/image.py index aa9d09089..060eb71e0 100644 --- a/src/distilabel/utils/image.py +++ b/src/distilabel/utils/image.py @@ -14,12 +14,13 @@ import base64 import io +from typing import TYPE_CHECKING -from PIL import Image +if TYPE_CHECKING: + from PIL import Image -# TODO: Once we merge the image generation, this function can be reused -def image_to_str(image: Image.Image, image_format: str = "JPEG") -> str: +def image_to_str(image: "Image.Image", image_format: str = "JPEG") -> str: """Converts a PIL Image to a base64 encoded string.""" buffered = io.BytesIO() image.save(buffered, format=image_format)