Skip to content

Commit a5b3407

Browse files
committed
Loosen class requirements in from_preset
1 parent 4bed9ae commit a5b3407

File tree

9 files changed

+194
-188
lines changed

9 files changed

+194
-188
lines changed

keras_hub/src/layers/preprocessing/audio_converter.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
PreprocessingLayer,
44
)
55
from keras_hub.src.utils.preset_utils import builtin_presets
6-
from keras_hub.src.utils.preset_utils import find_subclass
76
from keras_hub.src.utils.preset_utils import get_preset_loader
87
from keras_hub.src.utils.preset_utils import get_preset_saver
98
from keras_hub.src.utils.python_utils import classproperty
@@ -89,10 +88,7 @@ class like `keras_hub.models.AudioConverter.from_preset()`, or from a
8988
```
9089
"""
9190
loader = get_preset_loader(preset)
92-
backbone_cls = loader.check_backbone_class()
93-
if cls.backbone_cls != backbone_cls:
94-
cls = find_subclass(preset, cls, backbone_cls)
95-
return loader.load_audio_converter(cls, **kwargs)
91+
return loader.load_audio_converter(cls=cls, kwargs=kwargs)
9692

9793
def save_to_preset(self, preset_dir):
9894
"""Save audio converter to a preset directory.

keras_hub/src/layers/preprocessing/image_converter.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
)
1212
from keras_hub.src.utils.keras_utils import standardize_data_format
1313
from keras_hub.src.utils.preset_utils import builtin_presets
14-
from keras_hub.src.utils.preset_utils import find_subclass
1514
from keras_hub.src.utils.preset_utils import get_preset_loader
1615
from keras_hub.src.utils.preset_utils import get_preset_saver
1716
from keras_hub.src.utils.python_utils import classproperty
@@ -380,10 +379,7 @@ def from_preset(
380379
```
381380
"""
382381
loader = get_preset_loader(preset)
383-
backbone_cls = loader.check_backbone_class()
384-
if cls.backbone_cls != backbone_cls:
385-
cls = find_subclass(preset, cls, backbone_cls)
386-
return loader.load_image_converter(cls, **kwargs)
382+
return loader.load_image_converter(cls=cls, kwargs=kwargs)
387383

388384
def save_to_preset(self, preset_dir):
389385
"""Save image converter to a preset directory.

keras_hub/src/models/backbone.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -168,14 +168,9 @@ class like `keras_hub.models.Backbone.from_preset()`, or from
168168
```
169169
"""
170170
loader = get_preset_loader(preset)
171-
backbone_cls = loader.check_backbone_class()
172-
if not issubclass(backbone_cls, cls):
173-
raise ValueError(
174-
f"Saved preset has type `{backbone_cls.__name__}` which is not "
175-
f"a subclass of calling class `{cls.__name__}`. Call "
176-
f"`from_preset` directly on `{backbone_cls.__name__}` instead."
177-
)
178-
return loader.load_backbone(backbone_cls, load_weights, **kwargs)
171+
return loader.load_backbone(
172+
cls=cls, load_weights=load_weights, kwargs=kwargs
173+
)
179174

180175
def save_to_preset(self, preset_dir, max_shard_size=10):
181176
"""Save backbone to a preset directory.

keras_hub/src/models/preprocessor.py

Lines changed: 19 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
)
77
from keras_hub.src.utils.preset_utils import PREPROCESSOR_CONFIG_FILE
88
from keras_hub.src.utils.preset_utils import builtin_presets
9-
from keras_hub.src.utils.preset_utils import find_subclass
109
from keras_hub.src.utils.preset_utils import get_preset_loader
1110
from keras_hub.src.utils.preset_utils import get_preset_saver
1211
from keras_hub.src.utils.python_utils import classproperty
@@ -171,43 +170,38 @@ def from_preset(
171170
)
172171
```
173172
"""
174-
if cls == Preprocessor:
173+
if cls is Preprocessor:
175174
raise ValueError(
176175
"Do not call `Preprocessor.from_preset()` directly. Instead "
177176
"choose a particular task preprocessing class, e.g. "
178177
"`keras_hub.models.TextClassifierPreprocessor.from_preset()`."
179178
)
180179

181180
loader = get_preset_loader(preset)
182-
backbone_cls = loader.check_backbone_class()
183-
# Detect the correct subclass if we need to.
184-
if cls.backbone_cls != backbone_cls:
185-
cls = find_subclass(preset, cls, backbone_cls)
186-
return loader.load_preprocessor(cls, config_file, **kwargs)
181+
return loader.load_preprocessor(
182+
cls=cls, config_file=config_file, kwargs=kwargs
183+
)
187184

188185
@classmethod
189-
def _add_missing_kwargs(cls, loader, kwargs):
190-
"""Fill in required kwargs when loading from preset.
191-
192-
This is a private method hit when loading a preprocessing layer that
193-
was not directly saved in the preset. This method should fill in
194-
all required kwargs required to call the class constructor. For almost,
195-
all preprocessors, the only required args are `tokenizer`,
196-
`image_converter`, and `audio_converter`, but this can be overridden,
197-
e.g. for a preprocessor with multiple tokenizers for different
198-
encoders.
186+
def _from_defaults(cls, loader, kwargs):
187+
"""Load a preprocessor from default values.
188+
189+
This is a private method hit for loading a preprocessing layer that was
190+
not directly saved in the preset. Usually this means loading a
191+
tokenizer, image_converter and/or audio_converter and calling the
192+
constructor. But this can be overridden by subclasses as needed.
199193
"""
194+
defaults = {}
195+
# Allow loading any tokenizer, image_converter or audio_converter config
196+
# we find on disk. We allow mixing a matching tokenizers and
197+
# preprocessing layers (though this is usually not a good idea).
200198
if "tokenizer" not in kwargs and cls.tokenizer_cls:
201-
kwargs["tokenizer"] = loader.load_tokenizer(cls.tokenizer_cls)
199+
defaults["tokenizer"] = loader.load_tokenizer()
202200
if "audio_converter" not in kwargs and cls.audio_converter_cls:
203-
kwargs["audio_converter"] = loader.load_audio_converter(
204-
cls.audio_converter_cls
205-
)
201+
defaults["audio_converter"] = loader.load_audio_converter()
206202
if "image_converter" not in kwargs and cls.image_converter_cls:
207-
kwargs["image_converter"] = loader.load_image_converter(
208-
cls.image_converter_cls
209-
)
210-
return kwargs
203+
defaults["image_converter"] = loader.load_image_converter()
204+
return cls(**{**defaults, **kwargs})
211205

212206
def load_preset_assets(self, preset):
213207
"""Load all static assets needed by the preprocessing layer.

keras_hub/src/models/task.py

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,11 @@
66
from keras_hub.src.api_export import keras_hub_export
77
from keras_hub.src.layers.preprocessing.audio_converter import AudioConverter
88
from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
9-
from keras_hub.src.models.backbone import Backbone
109
from keras_hub.src.models.preprocessor import Preprocessor
1110
from keras_hub.src.tokenizers.tokenizer import Tokenizer
1211
from keras_hub.src.utils.keras_utils import print_msg
1312
from keras_hub.src.utils.pipeline_model import PipelineModel
1413
from keras_hub.src.utils.preset_utils import builtin_presets
15-
from keras_hub.src.utils.preset_utils import find_subclass
1614
from keras_hub.src.utils.preset_utils import get_preset_loader
1715
from keras_hub.src.utils.preset_utils import get_preset_saver
1816
from keras_hub.src.utils.python_utils import classproperty
@@ -175,27 +173,38 @@ def from_preset(
175173
)
176174
```
177175
"""
178-
if cls == Task:
176+
if cls is Task:
179177
raise ValueError(
180178
"Do not call `Task.from_preset()` directly. Instead call a "
181179
"particular task class, e.g. "
182180
"`keras_hub.models.TextClassifier.from_preset()`."
183181
)
184182

185183
loader = get_preset_loader(preset)
186-
backbone_cls = loader.check_backbone_class()
187-
# Detect the correct subclass if we need to.
188-
if (
189-
issubclass(backbone_cls, Backbone)
190-
and cls.backbone_cls != backbone_cls
191-
):
192-
cls = find_subclass(preset, cls, backbone_cls)
193-
# Specifically for classifiers, we never load task weights if
194-
# num_classes is supplied. We handle this in the task base class because
195-
# it is the same logic for classifiers regardless of modality (text,
196-
# images, audio).
197-
load_task_weights = "num_classes" not in kwargs
198-
return loader.load_task(cls, load_weights, load_task_weights, **kwargs)
184+
return loader.load_task(
185+
cls=cls, load_weights=load_weights, kwargs=kwargs
186+
)
187+
188+
@classmethod
189+
def _from_defaults(cls, loader, load_weights, kwargs, backbone_kwargs):
190+
"""Load a task from default values.
191+
192+
This is a private method hit for loading a task layer that was
193+
not directly saved in the preset. Usually this means loading a backbone
194+
and preprocessor and calling the constructor. But this can be overridden
195+
by subclasses as needed.
196+
"""
197+
defaults = {}
198+
if "backbone" not in kwargs:
199+
defaults["backbone"] = loader.load_backbone(
200+
load_weights=load_weights, kwargs=backbone_kwargs
201+
)
202+
if "preprocessor" not in kwargs and cls.preprocessor_cls:
203+
# Only load the "matching" preprocessor class for a task class.
204+
defaults["preprocessor"] = loader.load_preprocessor(
205+
cls=cls.preprocessor_cls
206+
)
207+
return cls(**{**defaults, **kwargs})
199208

200209
def load_task_weights(self, filepath):
201210
"""Load only the tasks specific weights not in the backbone."""

keras_hub/src/tokenizers/tokenizer.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from keras_hub.src.utils.preset_utils import ASSET_DIR
88
from keras_hub.src.utils.preset_utils import TOKENIZER_CONFIG_FILE
99
from keras_hub.src.utils.preset_utils import builtin_presets
10-
from keras_hub.src.utils.preset_utils import find_subclass
1110
from keras_hub.src.utils.preset_utils import get_file
1211
from keras_hub.src.utils.preset_utils import get_preset_loader
1312
from keras_hub.src.utils.preset_utils import get_preset_saver
@@ -257,7 +256,6 @@ class like `keras_hub.models.Tokenizer.from_preset()`, or from
257256
```
258257
"""
259258
loader = get_preset_loader(preset)
260-
backbone_cls = loader.check_backbone_class()
261-
if cls.backbone_cls != backbone_cls:
262-
cls = find_subclass(preset, cls, backbone_cls)
263-
return loader.load_tokenizer(cls, config_file, **kwargs)
259+
return loader.load_tokenizer(
260+
cls=cls, config_file=config_file, kwargs=kwargs
261+
)

0 commit comments

Comments
 (0)