Skip to content

Commit 0975982

Browse files
authored
Add a TextToImagePreprocessor base class (#2181)
Minor bit of bookeeping, adds a `TextToImagePreprocessor` base class. Not moving any common functionality here yet until we have at least two image generation models, but for now this will allow the auto class functionality to work.
1 parent a70b87d commit 0975982

File tree

4 files changed

+77
-2
lines changed

4 files changed

+77
-2
lines changed

keras_hub/api/models/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,9 @@
369369
TextClassifierPreprocessor,
370370
)
371371
from keras_hub.src.models.text_to_image import TextToImage
372+
from keras_hub.src.models.text_to_image_preprocessor import (
373+
TextToImagePreprocessor,
374+
)
372375
from keras_hub.src.models.vgg.vgg_backbone import VGGBackbone
373376
from keras_hub.src.models.vgg.vgg_image_classifier import VGGImageClassifier
374377
from keras_hub.src.models.vgg.vgg_image_classifier_preprocessor import (

keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,16 @@
22
from keras import layers
33

44
from keras_hub.src.api_export import keras_hub_export
5-
from keras_hub.src.models.preprocessor import Preprocessor
65
from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import ( # noqa: E501
76
StableDiffusion3Backbone,
87
)
8+
from keras_hub.src.models.text_to_image_preprocessor import (
9+
TextToImagePreprocessor,
10+
)
911

1012

1113
@keras_hub_export("keras_hub.models.StableDiffusion3TextToImagePreprocessor")
12-
class StableDiffusion3TextToImagePreprocessor(Preprocessor):
14+
class StableDiffusion3TextToImagePreprocessor(TextToImagePreprocessor):
1315
"""Stable Diffusion 3 text-to-image model preprocessor.
1416
1517
This preprocessing layer is meant for use with
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
from keras_hub.src.api_export import keras_hub_export
2+
from keras_hub.src.models.preprocessor import Preprocessor
3+
4+
5+
@keras_hub_export("keras_hub.models.TextToImagePreprocessor")
6+
class TextToImagePreprocessor(Preprocessor):
7+
"""Base class for text to image preprocessing layers.
8+
9+
`TextToImagePreprocessor` tasks wrap a `keras_hub.tokenizer.Tokenizer` to
10+
create a preprocessing layer for text to image tasks. It is intended to be
11+
paired with a `keras_hub.models.TextToImage` task.
12+
13+
The exact specifics of this layer will vary depending on the subclass
14+
implementation per model architecture. Generally, it will take text input,
15+
and tokenize, then pad/truncate so it is ready to be fed to a image
16+
generation model (e.g. a diffusion model).
17+
18+
Examples.
19+
```python
20+
preprocessor = keras_hub.models.TextToImagePreprocessor.from_preset(
21+
"stable_diffusion_3_medium",
22+
sequence_length=256, # Optional.
23+
)
24+
25+
# Tokenize and pad/truncate a single sentence.
26+
x = "The quick brown fox jumped."
27+
x = preprocessor(x)
28+
29+
# Tokenize and pad/truncate a batch of sentences.
30+
x = ["The quick brown fox jumped."]
31+
x = preprocessor(x)
32+
```
33+
"""
34+
35+
pass
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import pytest
2+
3+
from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image_preprocessor import ( # noqa: E501
4+
StableDiffusion3TextToImagePreprocessor,
5+
)
6+
from keras_hub.src.models.text_to_image_preprocessor import (
7+
TextToImagePreprocessor,
8+
)
9+
from keras_hub.src.tests.test_case import TestCase
10+
11+
12+
class TestTextToImagePreprocessor(TestCase):
13+
@pytest.mark.large
14+
def test_from_preset(self):
15+
self.assertIsInstance(
16+
TextToImagePreprocessor.from_preset("stable_diffusion_3_medium"),
17+
StableDiffusion3TextToImagePreprocessor,
18+
)
19+
self.assertIsInstance(
20+
StableDiffusion3TextToImagePreprocessor.from_preset(
21+
"stable_diffusion_3_medium"
22+
),
23+
StableDiffusion3TextToImagePreprocessor,
24+
)
25+
26+
@pytest.mark.large
27+
def test_from_preset_errors(self):
28+
with self.assertRaises(ValueError):
29+
# No loading on an incorrect class.
30+
StableDiffusion3TextToImagePreprocessor.from_preset("gpt2_base_en")
31+
with self.assertRaises(ValueError):
32+
# No loading on a non-keras model.
33+
StableDiffusion3TextToImagePreprocessor.from_preset(
34+
"hf://spacy/en_core_web_sm"
35+
)

0 commit comments

Comments
 (0)