Skip to content

Commit

Permalink
Fix Image import handling and update MlxLLM initialisation (#1102)
Browse files Browse the repository at this point in the history
Co-authored-by: Gabriel Martín Blázquez <gmartinbdev@gmail.com>
  • Loading branch information
davidberenstein1957 and gabrielmbmb authored Jan 17, 2025
1 parent d04f069 commit e6c9d9e
Show file tree
Hide file tree
Showing 8 changed files with 24 additions and 15 deletions.
5 changes: 2 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ dependencies = [
"orjson >= 3.10.0",
"universal_pathlib >= 0.2.2",
"portalocker >= 2.8.2",
"setuptools",
]
dynamic = ["version"]

Expand Down Expand Up @@ -90,9 +91,7 @@ ray = ["ray[default] >= 2.31.0"]
vertexai = ["google-cloud-aiplatform >= 1.38.0"]
vllm = [
"vllm >= 0.5.3",
"filelock >= 3.13.4",
# `setuptools` is needed to be installed if installed with `uv pip install distilabel[vllm]`
"setuptools",
"filelock >= 3.13.4"
]
sentence-transformers = ["sentence-transformers >= 3.0.0"]
faiss-cpu = ["faiss-cpu >= 1.8.0"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
InferenceEndpointsBaseClient,
)
from distilabel.models.image_generation.base import AsyncImageGenerationModel
from distilabel.models.image_generation.utils import image_to_str

if TYPE_CHECKING:
from PIL.Image import Image
Expand Down Expand Up @@ -60,10 +59,14 @@ class InferenceEndpointsImageGeneration( # type: ignore
"""

def load(self) -> None:
from distilabel.models.image_generation.utils import image_to_str

# Sets the logger and calls the load method of the BaseClient
AsyncImageGenerationModel.load(self)
InferenceEndpointsBaseClient.load(self)

self._image_to_str = image_to_str

@validate_call
async def agenerate( # type: ignore
self,
Expand Down Expand Up @@ -101,6 +104,6 @@ async def agenerate( # type: ignore
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
)
img_str = image_to_str(image, image_format="JPEG")
img_str = self._image_to_str(image, image_format="JPEG")

return [{"images": [img_str]}]
4 changes: 2 additions & 2 deletions src/distilabel/models/image_generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@
from PIL import Image


def image_to_str(image: Image.Image, image_format: str = "JPEG") -> str:
def image_to_str(image: "Image.Image", image_format: str = "JPEG") -> str:
"""Converts a PIL Image to a base64 encoded string."""
buffered = io.BytesIO()
image.save(buffered, format=image_format)
return base64.b64encode(buffered.getvalue()).decode("utf-8")


def image_from_str(image_str: str) -> Image.Image:
def image_from_str(image_str: str) -> "Image.Image":
"""Converts a base64 encoded string to a PIL Image."""
image_bytes = base64.b64decode(image_str)
return Image.open(io.BytesIO(image_bytes))
2 changes: 1 addition & 1 deletion src/distilabel/models/llms/mlx.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class MlxLLM(LLM, MagpieChatTemplateMixin):
```python
from distilabel.models.llms import MlxLLM
llm = MlxLLM(model="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit")
llm = MlxLLM(path_or_hf_repo="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit")
llm.load()
Expand Down
10 changes: 8 additions & 2 deletions src/distilabel/steps/tasks/image_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import hashlib
from typing import TYPE_CHECKING

from distilabel.models.image_generation.utils import image_from_str
from distilabel.steps.base import StepInput
from distilabel.steps.tasks.base import ImageTask

Expand Down Expand Up @@ -117,6 +116,13 @@ class ImageGeneration(ImageTask):
save_artifacts: bool = False
image_format: str = "JPEG"

def load(self) -> None:
from distilabel.models.image_generation.utils import image_from_str

super().load()

self._image_from_str = image_from_str

@property
def inputs(self) -> "StepColumns":
return ["prompt"]
Expand Down Expand Up @@ -166,7 +172,7 @@ def process(self, inputs: StepInput) -> "StepOutput":
# use prompt as filename
prompt_hash = hashlib.md5(input["prompt"].encode()).hexdigest()
# Build PIL image to save it
image = image_from_str(image)
image = self._image_from_str(image)

self.save_artifact(
name="images",
Expand Down
3 changes: 2 additions & 1 deletion src/distilabel/steps/tasks/structured_outputs/outlines.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
get_args,
)

import pkg_resources
from pydantic import BaseModel

from distilabel.errors import DistilabelUserError
Expand All @@ -50,6 +49,8 @@ def _is_outlines_version_below_0_1_0() -> bool:
Returns:
bool: True if outlines is not installed or version is below 0.1.0
"""
import pkg_resources

if not importlib.util.find_spec("outlines"):
raise ImportError(
"Outlines is not installed. Please install it using `pip install outlines`."
Expand Down
1 change: 0 additions & 1 deletion src/distilabel/steps/tasks/text_generation_with_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from typing import TYPE_CHECKING, Any, Literal, Union

from jinja2 import Template
from PIL import Image
from pydantic import Field

from distilabel.steps.tasks.base import Task
Expand Down
7 changes: 4 additions & 3 deletions src/distilabel/utils/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,13 @@

import base64
import io
from typing import TYPE_CHECKING

from PIL import Image
if TYPE_CHECKING:
from PIL import Image


# TODO: Once we merge the image generation, this function can be reused
def image_to_str(image: Image.Image, image_format: str = "JPEG") -> str:
def image_to_str(image: "Image.Image", image_format: str = "JPEG") -> str:
"""Converts a PIL Image to a base64 encoded string."""
buffered = io.BytesIO()
image.save(buffered, format=image_format)
Expand Down

0 comments on commit e6c9d9e

Please sign in to comment.