Skip to content

Commit e6c9d9e

Browse files
Fix Image import handling and update MlxLLM initialisation (#1102)
Co-authored-by: Gabriel Martín Blázquez <gmartinbdev@gmail.com>
1 parent d04f069 commit e6c9d9e

File tree

8 files changed, with 24 additions (+24) and 15 deletions (-15).

Diff for: pyproject.toml

+2-3
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ dependencies = [
3838
"orjson >= 3.10.0",
3939
"universal_pathlib >= 0.2.2",
4040
"portalocker >= 2.8.2",
41+
"setuptools",
4142
]
4243
dynamic = ["version"]
4344

@@ -90,9 +91,7 @@ ray = ["ray[default] >= 2.31.0"]
9091
vertexai = ["google-cloud-aiplatform >= 1.38.0"]
9192
vllm = [
9293
"vllm >= 0.5.3",
93-
"filelock >= 3.13.4",
94-
# `setuptools` is needed to be installed if installed with `uv pip install distilabel[vllm]`
95-
"setuptools",
94+
"filelock >= 3.13.4"
9695
]
9796
sentence-transformers = ["sentence-transformers >= 3.0.0"]
9897
faiss-cpu = ["faiss-cpu >= 1.8.0"]

Diff for: src/distilabel/models/image_generation/huggingface/inference_endpoints.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
InferenceEndpointsBaseClient,
2121
)
2222
from distilabel.models.image_generation.base import AsyncImageGenerationModel
23-
from distilabel.models.image_generation.utils import image_to_str
2423

2524
if TYPE_CHECKING:
2625
from PIL.Image import Image
@@ -60,10 +59,14 @@ class InferenceEndpointsImageGeneration( # type: ignore
6059
"""
6160

6261
def load(self) -> None:
62+
from distilabel.models.image_generation.utils import image_to_str
63+
6364
# Sets the logger and calls the load method of the BaseClient
6465
AsyncImageGenerationModel.load(self)
6566
InferenceEndpointsBaseClient.load(self)
6667

68+
self._image_to_str = image_to_str
69+
6770
@validate_call
6871
async def agenerate( # type: ignore
6972
self,
@@ -101,6 +104,6 @@ async def agenerate( # type: ignore
101104
num_inference_steps=num_inference_steps,
102105
guidance_scale=guidance_scale,
103106
)
104-
img_str = image_to_str(image, image_format="JPEG")
107+
img_str = self._image_to_str(image, image_format="JPEG")
105108

106109
return [{"images": [img_str]}]

Diff for: src/distilabel/models/image_generation/utils.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,14 @@
1818
from PIL import Image
1919

2020

21-
def image_to_str(image: Image.Image, image_format: str = "JPEG") -> str:
21+
def image_to_str(image: "Image.Image", image_format: str = "JPEG") -> str:
2222
"""Converts a PIL Image to a base64 encoded string."""
2323
buffered = io.BytesIO()
2424
image.save(buffered, format=image_format)
2525
return base64.b64encode(buffered.getvalue()).decode("utf-8")
2626

2727

28-
def image_from_str(image_str: str) -> Image.Image:
28+
def image_from_str(image_str: str) -> "Image.Image":
2929
"""Converts a base64 encoded string to a PIL Image."""
3030
image_bytes = base64.b64decode(image_str)
3131
return Image.open(io.BytesIO(image_bytes))

Diff for: src/distilabel/models/llms/mlx.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ class MlxLLM(LLM, MagpieChatTemplateMixin):
6060
```python
6161
from distilabel.models.llms import MlxLLM
6262
63-
llm = MlxLLM(model="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit")
63+
llm = MlxLLM(path_or_hf_repo="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit")
6464
6565
llm.load()
6666

Diff for: src/distilabel/steps/tasks/image_generation.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
import hashlib
1616
from typing import TYPE_CHECKING
1717

18-
from distilabel.models.image_generation.utils import image_from_str
1918
from distilabel.steps.base import StepInput
2019
from distilabel.steps.tasks.base import ImageTask
2120

@@ -117,6 +116,13 @@ class ImageGeneration(ImageTask):
117116
save_artifacts: bool = False
118117
image_format: str = "JPEG"
119118

119+
def load(self) -> None:
120+
from distilabel.models.image_generation.utils import image_from_str
121+
122+
super().load()
123+
124+
self._image_from_str = image_from_str
125+
120126
@property
121127
def inputs(self) -> "StepColumns":
122128
return ["prompt"]
@@ -166,7 +172,7 @@ def process(self, inputs: StepInput) -> "StepOutput":
166172
# use prompt as filename
167173
prompt_hash = hashlib.md5(input["prompt"].encode()).hexdigest()
168174
# Build PIL image to save it
169-
image = image_from_str(image)
175+
image = self._image_from_str(image)
170176

171177
self.save_artifact(
172178
name="images",

Diff for: src/distilabel/steps/tasks/structured_outputs/outlines.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
get_args,
2929
)
3030

31-
import pkg_resources
3231
from pydantic import BaseModel
3332

3433
from distilabel.errors import DistilabelUserError
@@ -50,6 +49,8 @@ def _is_outlines_version_below_0_1_0() -> bool:
5049
Returns:
5150
bool: True if outlines is not installed or version is below 0.1.0
5251
"""
52+
import pkg_resources
53+
5354
if not importlib.util.find_spec("outlines"):
5455
raise ImportError(
5556
"Outlines is not installed. Please install it using `pip install outlines`."

Diff for: src/distilabel/steps/tasks/text_generation_with_image.py

-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
from typing import TYPE_CHECKING, Any, Literal, Union
1616

1717
from jinja2 import Template
18-
from PIL import Image
1918
from pydantic import Field
2019

2120
from distilabel.steps.tasks.base import Task

Diff for: src/distilabel/utils/image.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,13 @@
1414

1515
import base64
1616
import io
17+
from typing import TYPE_CHECKING
1718

18-
from PIL import Image
19+
if TYPE_CHECKING:
20+
from PIL import Image
1921

2022

21-
# TODO: Once we merge the image generation, this function can be reused
22-
def image_to_str(image: Image.Image, image_format: str = "JPEG") -> str:
23+
def image_to_str(image: "Image.Image", image_format: str = "JPEG") -> str:
2324
"""Converts a PIL Image to a base64 encoded string."""
2425
buffered = io.BytesIO()
2526
image.save(buffered, format=image_format)

0 commit comments

Comments
 (0)