Image Language Models and ImageGeneration task #1060

Merged on Jan 15, 2025 (57 commits). Showing changes from 42 commits.

Commits
1f5e271
Add PIL for image processing
plaguss Nov 13, 2024
4fa1c10
Add module to store vision language models
plaguss Nov 13, 2024
b2d858a
First version of text-to-image with inference endpoints
plaguss Nov 13, 2024
4733c7d
Add text-to-image with OpenAI
plaguss Nov 13, 2024
6164b3c
Add image generation task
plaguss Nov 14, 2024
b201baf
Redirect imports
plaguss Nov 14, 2024
88b8d51
Redirect imports
plaguss Nov 14, 2024
06db4e2
Add image-generation icon
plaguss Nov 14, 2024
6fe997a
Add vision language models
plaguss Nov 14, 2024
18cd75b
Move vlms to ilms to make it the name more explicit
plaguss Nov 14, 2024
e8cfac5
Update vlms to ilms
plaguss Nov 14, 2024
727d5aa
Add image language models to the components gallery
plaguss Nov 14, 2024
0c97eeb
Refactor ILM and fix image saves when save_images=False
plaguss Nov 14, 2024
0aaec8c
Add example
plaguss Nov 15, 2024
00af941
Update task to work saving images as JPEG artifact and raw base64 string
plaguss Nov 15, 2024
5629464
Add short tutorial example for image generation
plaguss Nov 15, 2024
ffed25e
Update examples with correct output format
plaguss Nov 15, 2024
1745fd8
Refactor ilm to image_generation
plaguss Nov 18, 2024
943d922
Add tests for openai image generation
plaguss Nov 18, 2024
40446f2
Add base ImageGenerationModel classes to improve maintainability
plaguss Nov 18, 2024
43964f7
Add tests for inference endpoints
plaguss Nov 18, 2024
0761894
Fix class names and types
plaguss Nov 18, 2024
b833e38
Update docs with image generation models
plaguss Nov 18, 2024
b9dd6d8
Update the distiset docs to include the new method
plaguss Nov 18, 2024
6f4846d
Update examples with the new behaviour
plaguss Nov 18, 2024
665a156
Create module for common operations on images
plaguss Nov 18, 2024
e737d04
Update image generation task and distiset to transform the images bef…
plaguss Nov 18, 2024
e08a90a
Add tests for the distiset and image generation task
plaguss Nov 18, 2024
a84a2a3
Define image generation models from zero
plaguss Nov 18, 2024
7a601e1
Fixed openai tests mocking call to requests.get
plaguss Nov 19, 2024
66658d8
Merge with develop
plaguss Nov 19, 2024
26fe6e2
Merge and fix conflict
plaguss Nov 19, 2024
6a3d279
Make image_to_str more general
plaguss Nov 19, 2024
9c1ffc1
Create base ImageTask to deal with ImageGenerationModels
plaguss Nov 19, 2024
aa6f9f5
Replace with image_to_str function
plaguss Nov 19, 2024
94360f6
Move import
plaguss Nov 19, 2024
10538ff
Fix MRO in class inheritance
plaguss Nov 19, 2024
e98b307
Fix example script
plaguss Nov 19, 2024
8974ff0
Fix optional key not found in runtime parameters
plaguss Nov 20, 2024
341622d
Update examples and simplify process method
plaguss Nov 20, 2024
c46fdc5
Add ImageTask to the API reference
plaguss Nov 20, 2024
e9e6790
Update image task docs
plaguss Nov 20, 2024
ffb3c3b
Some types ignores
gabrielmbmb Jan 13, 2025
99467be
Merge and fix conflict
plaguss Jan 13, 2025
e2173e4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 13, 2025
f50aca1
Update docs/api/models/image_generation/index.md
plaguss Jan 13, 2025
8033275
Update docs/api/task/image_task.md
plaguss Jan 13, 2025
62e2c85
Update docs/sections/how_to_guides/advanced/distiset.md
plaguss Jan 13, 2025
e5442e3
Create a new runtime parameters mixin specific to model related funct…
plaguss Jan 13, 2025
29b03bc
typing module refactor to keep all type related info on its own module
plaguss Jan 14, 2025
00e9d83
Create new base client module to store common functionality across di…
plaguss Jan 14, 2025
a04c2f2
BIG refactor due to typing
plaguss Jan 14, 2025
70ce54b
Refactor tests
plaguss Jan 14, 2025
4d7945d
Missing refactor
plaguss Jan 14, 2025
d15e02b
Refactor typing in docs
plaguss Jan 15, 2025
ec2d765
Update src/distilabel/models/image_generation/base.py
plaguss Jan 15, 2025
7debafd
Remove unused check for optional parameters in runtime parameters
plaguss Jan 15, 2025
10 changes: 10 additions & 0 deletions docs/api/models/image_generation/image_generation_gallery.md
@@ -0,0 +1,10 @@
# ImageGenerationModel Gallery

This section contains the existing [`ImageGenerationModel`][distilabel.models.image_generation] subclasses implemented in `distilabel`.

::: distilabel.models.image_generation
    options:
      filters:
      - "!^ImageGenerationModel$"
      - "!^AsyncImageGenerationModel$"
      - "!typing"
7 changes: 7 additions & 0 deletions docs/api/models/image_generation/index.md
@@ -0,0 +1,7 @@
# ImageGenerationModel

This section contains the API reference for the `distilabel` image generation models, both for the [`ImageGenerationModel`][distilabel.models.image_generation.ImageGenerationModel] synchronous implementation and for the [`AsyncImageGenerationModel`][distilabel.models.image_generation.AsyncImageGenerationModel] asynchronous one.

For more information and examples on how to use existing image generation models or create custom ones, please refer to [Tutorial - ImageGenerationModel](../../../sections/how_to_guides/basic/task/image_task.md).

::: distilabel.models.image_generation.base
7 changes: 7 additions & 0 deletions docs/api/task/image_task.md
@@ -0,0 +1,7 @@
# ImageTask

This section contains the API reference for the `distilabel` image tasks.

For more information on how the [`ImageTask`][distilabel.steps.tasks.ImageTask] works, along with some examples, check the [Tutorial - Task - ImageTask](../../sections/how_to_guides/basic/task/image_task.md) page.

::: distilabel.steps.tasks.base.ImageTask
1 change: 1 addition & 0 deletions docs/api/task/task_gallery.md
@@ -8,5 +8,6 @@ This section contains the existing [`Task`][distilabel.steps.tasks.Task] subclas
- "!Task"
- "!_Task"
- "!GeneratorTask"
- "!ImageTask"
- "!ChatType"
- "!typing"
27 changes: 27 additions & 0 deletions docs/sections/how_to_guides/advanced/distiset.md
@@ -119,6 +119,33 @@ class MagpieGenerator(GeneratorTask, MagpieBase):

The `Citations` section can include any number of bibtex references. To define them, add as many elements as needed, as in the example: each citation is a block of the form ` ```@misc{...}``` `. This information is used automatically in the README of your `Distiset` if you call `distiset.push_to_hub`. Alternatively, if the `Citations` section is not found but the `References` section contains URLs pointing to `https://arxiv.org/`, we will try to obtain the `Bibtex` equivalent automatically. This way, Hugging Face can automatically track the paper for you, and it is easier to find other datasets citing the same paper or to visit the paper page directly.

#### Image Datasets

!!! info "Keep reading if you are interested in Image datasets"

    The `Distiset` object has a new method, `transform_columns_to_image`, that transforms image columns to `PIL.Image.Image` before pushing the dataset to the Hugging Face Hub.

Since version `1.5.0` we have the [`ImageGeneration`](https://distilabel.argilla.io/dev/components-gallery/task/imagegeneration/) task, which generates images from text. By default, the whole process works internally with a string representation of the images, which keeps processing simple. However, if the generated dataset is going to be stored on the Hugging Face Hub, a proper image object may be preferable, for example so that the images show up in the dataset viewer. Let's take a look at the following pipeline, extracted from `examples/image_generation.py` at the root of the repository, to see how we can do it:

```diff
# Assume all the imports are already done
with Pipeline(name="image_generation_pipeline") as pipeline:
    img_generation = ImageGeneration(
        name="flux_schnell",
        image_generation_model=InferenceEndpointsImageGeneration(
            model_id="black-forest-labs/FLUX.1-schnell"
        ),
    )
    ...

if __name__ == "__main__":
    distiset = pipeline.run(use_cache=False, dataset=ds)
    # Save the images as `PIL.Image.Image`
+    distiset = distiset.transform_columns_to_image("image")
    distiset.push_to_hub(...)
```

Calling [`transform_columns_to_image`][distilabel.distiset.Distiset.transform_columns_to_image] converts the image columns we may have generated into `PIL.Image.Image` objects (in this case we only transform the `image` column, but a list of columns can be passed). The transformation applies to every leaf node in the pipeline, so if there are several subsets, the `image` column is looked up in all of them.
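
To make the column handling concrete, here is a minimal, library-free sketch of the idea. It models a distiset as a plain dict of subsets (each a list of row dicts) and uses a hypothetical `transform_columns` helper that accepts either one column name or a list of names. None of these names come from `distilabel`, and the sketch only decodes base64 to raw bytes, whereas the real method produces `PIL.Image.Image` objects.

```python
import base64
from typing import Callable, Dict, List, Union

Row = Dict[str, object]
Subsets = Dict[str, List[Row]]


def transform_columns(
    subsets: Subsets,
    columns: Union[str, List[str]],
    decode: Callable[[str], object],
) -> Subsets:
    """Hypothetical helper: apply `decode` to `columns` in every subset."""
    if isinstance(columns, str):
        columns = [columns]  # a single name is treated as a one-item list
    transformed: Subsets = {}
    for name, rows in subsets.items():
        transformed[name] = [
            {
                key: decode(value)
                if key in columns and isinstance(value, str)
                else value
                for key, value in row.items()
            }
            for row in rows
        ]
    return transformed


# Toy data: the "image" column holds a base64 string in both subsets.
encoded = base64.b64encode(b"fake-image-bytes").decode("utf-8")
toy_distiset = {
    "default": [{"persona": "an art historian", "image": encoded}],
    "other_leaf": [{"persona": "a chef", "image": encoded}],
}

result = transform_columns(toy_distiset, "image", base64.b64decode)
```

The helper builds new rows instead of mutating the input, so the string representation stays available if it is still needed afterwards.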

### Save and load from disk

Take into account that these methods work as `datasets.load_from_disk` and `datasets.Dataset.save_to_disk` so the arguments are directly passed to those methods. This means you can also make use of `storage_options` argument to save your [`Distiset`][distilabel.distiset.Distiset] in your cloud provider, including the distilabel artifacts (`pipeline.yaml`, `pipeline.log` and the `README.md` with the dataset card). You can read more in `datasets` documentation [here](https://huggingface.co/docs/datasets/filesystems#saving-serialized-datasets).
104 changes: 104 additions & 0 deletions docs/sections/how_to_guides/basic/task/image_task.md
@@ -0,0 +1,104 @@
# ImageTask to work with Image Generation Models

## Working with ImageTasks

The [`ImageTask`][distilabel.steps.tasks.ImageTask] is a custom implementation of a [`Task`][distilabel.steps.tasks.Task] designed to deal with images. These tasks behave exactly like any other [`Task`][distilabel.steps.tasks.Task], but instead of relying on an [`LLM`][distilabel.models.llms.LLM], they work with an [`ImageGenerationModel`][distilabel.models.image_generation.ImageGenerationModel].

!!! info "New in version 1.5.0"
    This task is new and is expected to work with Image Generation Models.

These tasks take an `image_generation_model` attribute instead of the `llm` attribute of a standard `Task`, but everything else remains the same. Let's see an example with [`ImageGeneration`](https://distilabel.argilla.io/dev/components-gallery/tasks/imagegeneration/):

```python
from distilabel.steps.tasks import ImageGeneration
from distilabel.models.image_generation import InferenceEndpointsImageGeneration

task = ImageGeneration(
    name="image-generation",
    image_generation_model=InferenceEndpointsImageGeneration(model_id="black-forest-labs/FLUX.1-schnell"),
)
task.load()

next(task.process([{"prompt": "a white siamese cat"}]))
# [{"image": "iVBORw0KGgoAAAANSUhEUgA...", "model_name": "black-forest-labs/FLUX.1-schnell"}]
```

!!! info "Visualize the image in a notebook"
    If you are testing the `ImageGeneration` task in a notebook, you can do the following
    to see the rendered image:

    ```python
    from distilabel.models.image_generation.utils import image_from_str

    result = next(task.process([{"prompt": "a white siamese cat"}]))
    image_from_str(result[0]["image"])  # Returns a `PIL.Image.Image` that renders directly
    ```

!!! tip "Running ImageGeneration in a Pipeline"
    This transformation between the image as a string and as a PIL object can be done for the whole dataset when running a pipeline, by calling the method `transform_columns_to_image` on the final distiset and passing the name (or list of names) of the image column.
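
The `iVBORw0KGgo...` prefix in the example output above is what the base64 encoding of the PNG file signature looks like, which suggests the string is the base64 of the encoded image bytes. Under that assumption, the sketch below decodes such a string with the standard library and guesses its format from the leading bytes. `sniff_image_format` is a hypothetical helper for illustration, not part of `distilabel`.

```python
import base64

# Leading "magic" bytes of two common image formats.
PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"
JPEG_SIGNATURE = b"\xff\xd8\xff"


def sniff_image_format(image_str: str) -> str:
    """Hypothetical helper: guess the format of a base64-encoded image."""
    raw = base64.b64decode(image_str)
    if raw.startswith(PNG_SIGNATURE):
        return "png"
    if raw.startswith(JPEG_SIGNATURE):
        return "jpeg"
    return "unknown"


# Encoding the PNG signature alone reproduces the prefix seen in the output.
png_prefix = base64.b64encode(PNG_SIGNATURE).decode("utf-8")
print(png_prefix)  # iVBORw0KGgo=
print(sniff_image_format(png_prefix))  # png
```

A check like this can be handy when debugging a column whose contents are not rendering as expected.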

## Defining custom ImageTasks

We can define a custom image task by creating a new subclass of the [`ImageTask`][distilabel.steps.tasks.ImageTask] and defining the following:

- `process`: is a method that generates the data based on the [`ImageGenerationModel`][distilabel.models.image_generation.ImageGenerationModel] and the `prompt` provided within the class instance, and returns a dictionary with the output data formatted as needed i.e. with the values for the columns in `outputs`.

- `inputs`: is a property that returns a list of strings with the names of the required input fields or a dictionary in which the keys are the names of the columns and the values are boolean indicating whether the column is required or not.

- `outputs`: is a property that returns a list of strings with the names of the output fields or a dictionary in which the keys are the names of the columns and the values are boolean indicating whether the column is required or not. This property should always include `model_name` as one of the outputs since that's automatically injected from the image generation model.

- `format_input`: is a method that receives a dictionary with the input data and returns a *prompt* to be passed to the model.

- `format_output`: is a method that receives the output from the [`ImageGenerationModel`][distilabel.models.image_generation.ImageGenerationModel] and optionally also the input data (which may be useful to build the output in some scenarios), and returns a dictionary with the output data formatted as needed i.e. with the values for the columns in `outputs`.

```python
from typing import TYPE_CHECKING, Any, Dict, Union

from typing_extensions import override

# Helpers to convert between `PIL.Image.Image` and its string representation
from distilabel.models.image_generation.utils import image_from_str, image_to_str
from distilabel.steps.base import StepInput
from distilabel.steps.tasks.base import ImageTask

if TYPE_CHECKING:
    from distilabel.steps.typing import StepColumns, StepOutput


class MyCustomImageTask(ImageTask):
    @override
    def process(self, inputs: StepInput) -> "StepOutput":
        # Build one prompt per input row using `format_input`
        formatted_inputs = self._format_inputs(inputs)

        # `self.llm` holds the image generation model (see the warning below)
        outputs = self.llm.generate_outputs(
            inputs=formatted_inputs,
            num_generations=self.num_generations,
            **self.llm.get_generation_kwargs(),
        )

        task_outputs = []
        for input, input_outputs in zip(inputs, outputs):
            formatted_outputs = self._format_outputs(input_outputs, input)
            for formatted_output in formatted_outputs:
                task_outputs.append(
                    {**input, **formatted_output, "model_name": self.llm.model_name}
                )
        yield task_outputs

    @property
    def inputs(self) -> "StepColumns":
        return ["prompt"]

    @property
    def outputs(self) -> "StepColumns":
        return ["image", "model_name"]

    def format_input(self, input: Dict[str, Any]) -> str:
        return input["prompt"]

    def format_output(
        self, output: Union[str, None], input: Dict[str, Any]
    ) -> Dict[str, Any]:
        # Extract/generate/modify the image from the output
        return {"image": ..., "model_name": self.llm.model_name}
```

!!! Warning
    Note that in the `process` method we are not dealing with the `image_generation_model` attribute but with `llm`. This is intended, not a bug: internally the `image_generation_model` is renamed to `llm` to reuse the code.
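
The renaming described in the warning can be pictured with a small standalone sketch. This is an illustration under stated assumptions, not the `distilabel` implementation: `FakeImageModel` and `SketchImageTask` are made-up names, and raising an error when `llm` is passed directly is a choice made for this sketch.

```python
from dataclasses import dataclass
from typing import Any, Optional


class FakeImageModel:
    """Stand-in for an image generation model (not a distilabel class)."""

    model_name = "fake-image-model"


@dataclass
class SketchImageTask:
    """Toy task that accepts `image_generation_model` but works through `llm`."""

    image_generation_model: Optional[Any] = None
    llm: Optional[Any] = None

    def __post_init__(self) -> None:
        # Move the user-facing attribute to the internal name so that code
        # written against `self.llm` can be reused unchanged.
        if self.llm is not None:
            raise ValueError("Pass `image_generation_model`, not `llm`.")
        self.llm = self.image_generation_model
        self.image_generation_model = None

    def model_name(self) -> str:
        return self.llm.model_name


task = SketchImageTask(image_generation_model=FakeImageModel())
print(task.model_name())  # fake-image-model
```

After construction only `llm` is populated, which is why methods such as `process` read from `self.llm`.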
108 changes: 108 additions & 0 deletions docs/sections/pipeline_samples/examples/image_generation.md
@@ -0,0 +1,108 @@
---
hide: toc
---

# Image generation with `distilabel`

Create synthetic images using `distilabel`.

This example shows how distilabel can be used to generate image data, either using [`InferenceEndpointsImageGeneration`](https://distilabel.argilla.io/dev/components-gallery/image_generation/inferenceendpointsimagegeneration/) or [`OpenAIImageGeneration`](https://distilabel.argilla.io/dev/components-gallery/image_generation/openaiimagegeneration/), thanks to the [`ImageGeneration`](https://distilabel.argilla.io/dev/components-gallery/task/imagegeneration/) task.


=== "Inference Endpoints - black-forest-labs/FLUX.1-schnell"

    ```python
    from distilabel.pipeline import Pipeline
    from distilabel.steps import KeepColumns
    from distilabel.models.image_generation import InferenceEndpointsImageGeneration
    from distilabel.steps.tasks import ImageGeneration

    from datasets import load_dataset

    ds = load_dataset("dvilasuero/finepersonas-v0.1-tiny", split="train").select(range(3))

    with Pipeline(name="image_generation_pipeline") as pipeline:
        ilm = InferenceEndpointsImageGeneration(
            model_id="black-forest-labs/FLUX.1-schnell"
        )

        img_generation = ImageGeneration(
            name="flux_schnell",
            image_generation_model=ilm,
            input_mappings={"prompt": "persona"}
        )

        keep_columns = KeepColumns(columns=["persona", "model_name", "image"])

        img_generation >> keep_columns
    ```

    Sample image for the prompt:

    > A local art historian and museum professional interested in 19th-century American art and the local cultural heritage of Cincinnati.

    ![image_ie](https://huggingface.co/datasets/plaguss/test-finepersonas-v0.1-tiny-flux-schnell/resolve/main/artifacts/flux_schnell/images/3333f9870feda32a449994017eb72675.jpeg)

=== "OpenAI - dall-e-3"

    ```python
    from distilabel.pipeline import Pipeline
    from distilabel.steps import KeepColumns
    from distilabel.models.image_generation import OpenAIImageGeneration
    from distilabel.steps.tasks import ImageGeneration

    from datasets import load_dataset

    ds = load_dataset("dvilasuero/finepersonas-v0.1-tiny", split="train").select(range(3))

    with Pipeline(name="image_generation_pipeline") as pipeline:
        ilm = OpenAIImageGeneration(
            model="dall-e-3",
            generation_kwargs={
                "size": "1024x1024",
                "quality": "standard",
                "style": "natural"
            }
        )

        img_generation = ImageGeneration(
            name="dalle-3",
            image_generation_model=ilm,
            input_mappings={"prompt": "persona"}
        )

        keep_columns = KeepColumns(columns=["persona", "model_name", "image"])

        img_generation >> keep_columns
    ```

    Sample image for the prompt:

    > A local art historian and museum professional interested in 19th-century American art and the local cultural heritage of Cincinnati.

    ![image_oai](https://huggingface.co/datasets/plaguss/test-finepersonas-v0.1-tiny-dall-e-3/resolve/main/artifacts/dalle-3/images/3333f9870feda32a449994017eb72675.jpeg)

!!! success "Save the Distiset as an Image Dataset"

    Note the call to `Distiset.transform_columns_to_image`, to have the images uploaded directly as an [`Image dataset`](https://huggingface.co/docs/hub/en/datasets-image):

    ```python
    if __name__ == "__main__":
        distiset = pipeline.run(use_cache=False, dataset=ds)
        # Save the images as `PIL.Image.Image`
        distiset = distiset.transform_columns_to_image("image")
        distiset.push_to_hub("plaguss/test-finepersonas-v0.1-tiny-flux-schnell")
    ```

The full pipeline can be run with the following example. Keep in mind that you need to install `pillow` first: `pip install "distilabel[vision]"`.
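
If you want to fail early when the optional dependency is missing, a small standard-library check like the one below can be placed at the top of a script. `is_installed` is a hypothetical helper written for this sketch; the only assumption about Pillow is that it is imported under the module name `PIL`.

```python
import importlib.util


def is_installed(module_name: str) -> bool:
    """Return True if `module_name` can be imported in this environment."""
    return importlib.util.find_spec(module_name) is not None


if not is_installed("PIL"):  # Pillow is imported as `PIL`
    print('Pillow is missing; install it with: pip install "distilabel[vision]"')
```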

??? Run

    ```bash
    python examples/image_generation.py
    ```

```python title="image_generation.py"
--8<-- "examples/image_generation.py"
```
8 changes: 8 additions & 0 deletions docs/sections/pipeline_samples/index.md
@@ -153,6 +153,14 @@ hide: toc

    [:octicons-arrow-right-24: Example](examples/exam_questions.md)

-   __Image generation with distilabel__

    ---

    Generate synthetic images using distilabel.

    [:octicons-arrow-right-24: Example](examples/image_generation.md)


</div>

42 changes: 42 additions & 0 deletions examples/image_generation.py
@@ -0,0 +1,42 @@
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from datasets import load_dataset

from distilabel.models.image_generation import InferenceEndpointsImageGeneration
from distilabel.pipeline import Pipeline
from distilabel.steps import KeepColumns
from distilabel.steps.tasks import ImageGeneration

ds = load_dataset("dvilasuero/finepersonas-v0.1-tiny", split="train").select(range(3))

with Pipeline(name="image_generation_pipeline") as pipeline:
    igm = InferenceEndpointsImageGeneration(model_id="black-forest-labs/FLUX.1-schnell")

    img_generation = ImageGeneration(
        name="flux_schnell",
        image_generation_model=igm,
        input_mappings={"prompt": "persona"},
    )

    keep_columns = KeepColumns(columns=["persona", "model_name", "image"])

    img_generation >> keep_columns


if __name__ == "__main__":
    distiset = pipeline.run(use_cache=False, dataset=ds)
    # Save the images as `PIL.Image.Image`
    distiset = distiset.transform_columns_to_image("image")
    distiset.push_to_hub("plaguss/test-finepersonas-v0.1-tiny-flux-schnell")
5 changes: 5 additions & 0 deletions mkdocs.yml
@@ -182,6 +182,7 @@ nav:
- Tasks for generating and judging with LLMs:
- "sections/how_to_guides/basic/task/index.md"
- GeneratorTask: "sections/how_to_guides/basic/task/generator_task.md"
- ImageTask: "sections/how_to_guides/basic/task/image_task.md"
- Executing Tasks with LLMs: "sections/how_to_guides/basic/llm/index.md"
- Execute Steps and Tasks in a Pipeline: "sections/how_to_guides/basic/pipeline/index.md"
- Advanced:
@@ -218,6 +219,7 @@ nav:
- Structured generation with instructor: "sections/pipeline_samples/examples/mistralai_with_instructor.md"
- Create a social network with FinePersonas: "sections/pipeline_samples/examples/fine_personas_social_network.md"
- Create questions and answers for a exam: "sections/pipeline_samples/examples/exam_questions.md"
- Image generation with distilabel: "sections/pipeline_samples/examples/image_generation.md"
- API Reference:
- Step:
- "api/step/index.md"
@@ -242,6 +244,9 @@ nav:
- Embedding:
- "api/models/embedding/index.md"
- Embedding Gallery: "api/models/embedding/embedding_gallery.md"
- ImageGenerationModels:
- "api/models/image_generation/index.md"
- Image Generation Gallery: "api/models/image_generation/image_generation_gallery.md"
- Pipeline:
- "api/pipeline/index.md"
- Routing Batch Function: "api/pipeline/routing_batch_function.md"
1 change: 1 addition & 0 deletions pyproject.toml
@@ -102,6 +102,7 @@ text-clustering = [
"scikit-learn >= 1.4.1",
"matplotlib >= 3.8.3", # For the figure (even though it's optional)
]
vision = ["Pillow >= 10.3.0"] # To work with images.

# minhash
minhash = ["datasketch >= 1.6.5", "nltk>3.8.1"]