Skip to content

Commit 0c04abe

Browse files
authored
Class detection works for huggingface checkpoints (#1800)
* Class detection works for huggingface checkpoints. This was a tricky one to fix that involved some large refactoring to our preset loading routines. Originally the intent was that `from_preset()` was an easily readable bunch of lower-level Keras calls. With the arrival of transformers conversions, and soon timm conversions, I think that goal is no longer super realistic. Instead I added a loader interface, with default implementations of `load_task` and `load_preprocessor`. Every format we support directly converting from has to support, at a minimum: - Detecting the backbone class. - Loading the backbone class. One consequence of this work is that every class with a `from_preset` constructor needs to reference the `backbone_cls` it matches with. I think this will be a more stable way to handle our "auto class"-like functionality as we venture further towards multi-modal models. * Address comments
1 parent fbc1335 commit 0c04abe

File tree

75 files changed

+623
-585
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

75 files changed

+623
-585
lines changed

keras_nlp/src/models/albert/albert_preprocessor.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from keras_nlp.src.layers.preprocessing.multi_segment_packer import (
1919
MultiSegmentPacker,
2020
)
21+
from keras_nlp.src.models.albert.albert_backbone import AlbertBackbone
2122
from keras_nlp.src.models.albert.albert_tokenizer import AlbertTokenizer
2223
from keras_nlp.src.models.preprocessor import Preprocessor
2324
from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function
@@ -144,6 +145,7 @@ class AlbertPreprocessor(Preprocessor):
144145
```
145146
"""
146147

148+
backbone_cls = AlbertBackbone
147149
tokenizer_cls = AlbertTokenizer
148150

149151
def __init__(

keras_nlp/src/models/albert/albert_tokenizer.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
from keras_nlp.src.api_export import keras_nlp_export
16+
from keras_nlp.src.models.albert.albert_backbone import AlbertBackbone
1617
from keras_nlp.src.tokenizers.sentence_piece_tokenizer import (
1718
SentencePieceTokenizer,
1819
)
@@ -84,6 +85,8 @@ class AlbertTokenizer(SentencePieceTokenizer):
8485
```
8586
"""
8687

88+
backbone_cls = AlbertBackbone
89+
8790
def __init__(self, proto, **kwargs):
8891
self.cls_token = "[CLS]"
8992
self.sep_token = "[SEP]"

keras_nlp/src/models/backbone.py

Lines changed: 7 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -20,17 +20,12 @@
2020
from keras_nlp.src.utils.keras_utils import assert_quantization_support
2121
from keras_nlp.src.utils.preset_utils import CONFIG_FILE
2222
from keras_nlp.src.utils.preset_utils import MODEL_WEIGHTS_FILE
23-
from keras_nlp.src.utils.preset_utils import check_config_class
24-
from keras_nlp.src.utils.preset_utils import check_format
25-
from keras_nlp.src.utils.preset_utils import get_file
26-
from keras_nlp.src.utils.preset_utils import jax_memory_cleanup
23+
from keras_nlp.src.utils.preset_utils import get_preset_loader
2724
from keras_nlp.src.utils.preset_utils import list_presets
2825
from keras_nlp.src.utils.preset_utils import list_subclasses
29-
from keras_nlp.src.utils.preset_utils import load_serialized_object
3026
from keras_nlp.src.utils.preset_utils import save_metadata
3127
from keras_nlp.src.utils.preset_utils import save_serialized_object
3228
from keras_nlp.src.utils.python_utils import classproperty
33-
from keras_nlp.src.utils.transformers.convert import load_transformers_backbone
3429

3530

3631
@keras_nlp_export("keras_nlp.models.Backbone")
@@ -200,25 +195,15 @@ class like `keras_nlp.models.Backbone.from_preset()`, or from
200195
)
201196
```
202197
"""
203-
format = check_format(preset)
204-
205-
if format == "transformers":
206-
return load_transformers_backbone(cls, preset, load_weights)
207-
208-
preset_cls = check_config_class(preset)
209-
if not issubclass(preset_cls, cls):
198+
loader = get_preset_loader(preset)
199+
backbone_cls = loader.check_backbone_class()
200+
if not issubclass(backbone_cls, cls):
210201
raise ValueError(
211-
f"Preset has type `{preset_cls.__name__}` which is not a "
202+
f"Saved preset has type `{backbone_cls.__name__}` which is not "
212203
f"a subclass of calling class `{cls.__name__}`. Call "
213-
f"`from_preset` directly on `{preset_cls.__name__}` instead."
204+
f"`from_preset` directly on `{backbone_cls.__name__}` instead."
214205
)
215-
216-
backbone = load_serialized_object(preset, CONFIG_FILE, **kwargs)
217-
if load_weights:
218-
jax_memory_cleanup(backbone)
219-
backbone.load_weights(get_file(preset, MODEL_WEIGHTS_FILE))
220-
221-
return backbone
206+
return loader.load_backbone(backbone_cls, load_weights, **kwargs)
222207

223208
def save_to_preset(self, preset_dir):
224209
"""Save backbone to a preset directory.

keras_nlp/src/models/backbone_test.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from keras_nlp.src.utils.preset_utils import METADATA_FILE
2525
from keras_nlp.src.utils.preset_utils import MODEL_WEIGHTS_FILE
2626
from keras_nlp.src.utils.preset_utils import check_config_class
27-
from keras_nlp.src.utils.preset_utils import load_config
27+
from keras_nlp.src.utils.preset_utils import load_json
2828

2929

3030
class TestBackbone(TestCase):
@@ -68,7 +68,7 @@ def test_from_preset_errors(self):
6868
GPT2Backbone.from_preset("bert_tiny_en_uncased", load_weights=False)
6969
with self.assertRaises(ValueError):
7070
# No loading on a non-keras model.
71-
Backbone.from_preset("hf://google-bert/bert-base-uncased")
71+
Backbone.from_preset("hf://spacy/en_core_web_sm")
7272

7373
@pytest.mark.large
7474
def test_save_to_preset(self):
@@ -84,12 +84,12 @@ def test_save_to_preset(self):
8484
self.assertTrue(os.path.exists(os.path.join(save_dir, METADATA_FILE)))
8585

8686
# Check the backbone config (`config.json`).
87-
backbone_config = load_config(save_dir, CONFIG_FILE)
87+
backbone_config = load_json(save_dir, CONFIG_FILE)
8888
self.assertTrue("build_config" not in backbone_config)
8989
self.assertTrue("compile_config" not in backbone_config)
9090

9191
# Try config class.
92-
self.assertEqual(BertBackbone, check_config_class(save_dir))
92+
self.assertEqual(BertBackbone, check_config_class(backbone_config))
9393

9494
# Try loading the model from preset directory.
9595
restored_backbone = Backbone.from_preset(save_dir)

keras_nlp/src/models/bart/bart_preprocessor.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
from keras_nlp.src.api_export import keras_nlp_export
1919
from keras_nlp.src.layers.preprocessing.start_end_packer import StartEndPacker
20+
from keras_nlp.src.models.bart.bart_backbone import BartBackbone
2021
from keras_nlp.src.models.bart.bart_tokenizer import BartTokenizer
2122
from keras_nlp.src.models.preprocessor import Preprocessor
2223
from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function
@@ -127,6 +128,7 @@ class BartPreprocessor(Preprocessor):
127128
```
128129
"""
129130

131+
backbone_cls = BartBackbone
130132
tokenizer_cls = BartTokenizer
131133

132134
def __init__(

keras_nlp/src/models/bart/bart_tokenizer.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515

1616
from keras_nlp.src.api_export import keras_nlp_export
17+
from keras_nlp.src.models.bart.bart_backbone import BartBackbone
1718
from keras_nlp.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer
1819

1920

@@ -73,6 +74,8 @@ class BartTokenizer(BytePairTokenizer):
7374
```
7475
"""
7576

77+
backbone_cls = BartBackbone
78+
7679
def __init__(
7780
self,
7881
vocabulary=None,

keras_nlp/src/models/bert/bert_preprocessor.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from keras_nlp.src.layers.preprocessing.multi_segment_packer import (
1919
MultiSegmentPacker,
2020
)
21+
from keras_nlp.src.models.bert.bert_backbone import BertBackbone
2122
from keras_nlp.src.models.bert.bert_tokenizer import BertTokenizer
2223
from keras_nlp.src.models.preprocessor import Preprocessor
2324
from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function
@@ -122,6 +123,7 @@ class BertPreprocessor(Preprocessor):
122123
```
123124
"""
124125

126+
backbone_cls = BertBackbone
125127
tokenizer_cls = BertTokenizer
126128

127129
def __init__(

keras_nlp/src/models/bert/bert_tokenizer.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
from keras_nlp.src.api_export import keras_nlp_export
16+
from keras_nlp.src.models.bert.bert_backbone import BertBackbone
1617
from keras_nlp.src.tokenizers.word_piece_tokenizer import WordPieceTokenizer
1718

1819

@@ -68,6 +69,8 @@ class BertTokenizer(WordPieceTokenizer):
6869
```
6970
"""
7071

72+
backbone_cls = BertBackbone
73+
7174
def __init__(
7275
self,
7376
vocabulary=None,

keras_nlp/src/models/bloom/bloom_preprocessor.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
from keras_nlp.src.api_export import keras_nlp_export
1919
from keras_nlp.src.layers.preprocessing.start_end_packer import StartEndPacker
20+
from keras_nlp.src.models.bloom.bloom_backbone import BloomBackbone
2021
from keras_nlp.src.models.bloom.bloom_tokenizer import BloomTokenizer
2122
from keras_nlp.src.models.preprocessor import Preprocessor
2223
from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function
@@ -103,6 +104,7 @@ class BloomPreprocessor(Preprocessor):
103104
```
104105
"""
105106

107+
backbone_cls = BloomBackbone
106108
tokenizer_cls = BloomTokenizer
107109

108110
def __init__(

keras_nlp/src/models/bloom/bloom_tokenizer.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515

1616
from keras_nlp.src.api_export import keras_nlp_export
17+
from keras_nlp.src.models.bloom.bloom_backbone import BloomBackbone
1718
from keras_nlp.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer
1819

1920

@@ -65,6 +66,8 @@ class BloomTokenizer(BytePairTokenizer):
6566
```
6667
"""
6768

69+
backbone_cls = BloomBackbone
70+
6871
def __init__(
6972
self,
7073
vocabulary=None,

keras_nlp/src/models/deberta_v3/deberta_v3_preprocessor.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@
1919
from keras_nlp.src.layers.preprocessing.multi_segment_packer import (
2020
MultiSegmentPacker,
2121
)
22+
from keras_nlp.src.models.deberta_v3.deberta_v3_backbone import (
23+
DebertaV3Backbone,
24+
)
2225
from keras_nlp.src.models.deberta_v3.deberta_v3_tokenizer import (
2326
DebertaV3Tokenizer,
2427
)
@@ -145,6 +148,7 @@ class DebertaV3Preprocessor(Preprocessor):
145148
```
146149
"""
147150

151+
backbone_cls = DebertaV3Backbone
148152
tokenizer_cls = DebertaV3Tokenizer
149153

150154
def __init__(

keras_nlp/src/models/deberta_v3/deberta_v3_tokenizer.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414

1515

1616
from keras_nlp.src.api_export import keras_nlp_export
17+
from keras_nlp.src.models.deberta_v3.deberta_v3_backbone import (
18+
DebertaV3Backbone,
19+
)
1720
from keras_nlp.src.tokenizers.sentence_piece_tokenizer import (
1821
SentencePieceTokenizer,
1922
)
@@ -94,6 +97,8 @@ class DebertaV3Tokenizer(SentencePieceTokenizer):
9497
```
9598
"""
9699

100+
backbone_cls = DebertaV3Backbone
101+
97102
def __init__(self, proto, **kwargs):
98103
self.cls_token = "[CLS]"
99104
self.sep_token = "[SEP]"

keras_nlp/src/models/distil_bert/distil_bert_preprocessor.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@
1919
from keras_nlp.src.layers.preprocessing.multi_segment_packer import (
2020
MultiSegmentPacker,
2121
)
22+
from keras_nlp.src.models.distil_bert.distil_bert_backbone import (
23+
DistilBertBackbone,
24+
)
2225
from keras_nlp.src.models.distil_bert.distil_bert_tokenizer import (
2326
DistilBertTokenizer,
2427
)
@@ -114,6 +117,7 @@ class DistilBertPreprocessor(Preprocessor):
114117
```
115118
"""
116119

120+
backbone_cls = DistilBertBackbone
117121
tokenizer_cls = DistilBertTokenizer
118122

119123
def __init__(

keras_nlp/src/models/distil_bert/distil_bert_tokenizer.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414

1515

1616
from keras_nlp.src.api_export import keras_nlp_export
17+
from keras_nlp.src.models.distil_bert.distil_bert_backbone import (
18+
DistilBertBackbone,
19+
)
1720
from keras_nlp.src.tokenizers.word_piece_tokenizer import WordPieceTokenizer
1821

1922

@@ -70,6 +73,8 @@ class DistilBertTokenizer(WordPieceTokenizer):
7073
```
7174
"""
7275

76+
backbone_cls = DistilBertBackbone
77+
7378
def __init__(
7479
self,
7580
vocabulary,

keras_nlp/src/models/electra/electra_preprocessor.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from keras_nlp.src.layers.preprocessing.multi_segment_packer import (
1919
MultiSegmentPacker,
2020
)
21+
from keras_nlp.src.models.electra.electra_backbone import ElectraBackbone
2122
from keras_nlp.src.models.electra.electra_tokenizer import ElectraTokenizer
2223
from keras_nlp.src.models.preprocessor import Preprocessor
2324
from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function
@@ -111,6 +112,7 @@ class ElectraPreprocessor(Preprocessor):
111112
```
112113
"""
113114

115+
backbone_cls = ElectraBackbone
114116
tokenizer_cls = ElectraTokenizer
115117

116118
def __init__(

keras_nlp/src/models/electra/electra_tokenizer.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
from keras_nlp.src.api_export import keras_nlp_export
16+
from keras_nlp.src.models.electra.electra_backbone import ElectraBackbone
1617
from keras_nlp.src.tokenizers.word_piece_tokenizer import WordPieceTokenizer
1718

1819

@@ -60,6 +61,8 @@ class ElectraTokenizer(WordPieceTokenizer):
6061
```
6162
"""
6263

64+
backbone_cls = ElectraBackbone
65+
6366
def __init__(
6467
self,
6568
vocabulary,

keras_nlp/src/models/f_net/f_net_preprocessor.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from keras_nlp.src.layers.preprocessing.multi_segment_packer import (
2020
MultiSegmentPacker,
2121
)
22+
from keras_nlp.src.models.f_net.f_net_backbone import FNetBackbone
2223
from keras_nlp.src.models.f_net.f_net_tokenizer import FNetTokenizer
2324
from keras_nlp.src.models.preprocessor import Preprocessor
2425
from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function
@@ -116,6 +117,7 @@ class FNetPreprocessor(Preprocessor):
116117
```
117118
"""
118119

120+
backbone_cls = FNetBackbone
119121
tokenizer_cls = FNetTokenizer
120122

121123
def __init__(

keras_nlp/src/models/f_net/f_net_tokenizer.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515

1616
from keras_nlp.src.api_export import keras_nlp_export
17+
from keras_nlp.src.models.f_net.f_net_backbone import FNetBackbone
1718
from keras_nlp.src.tokenizers.sentence_piece_tokenizer import (
1819
SentencePieceTokenizer,
1920
)
@@ -61,6 +62,8 @@ class FNetTokenizer(SentencePieceTokenizer):
6162
```
6263
"""
6364

65+
backbone_cls = FNetBackbone
66+
6467
def __init__(self, proto, **kwargs):
6568
self.cls_token = "[CLS]"
6669
self.sep_token = "[SEP]"

keras_nlp/src/models/falcon/falcon_preprocessor.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
from keras_nlp.src.api_export import keras_nlp_export
1919
from keras_nlp.src.layers.preprocessing.start_end_packer import StartEndPacker
20+
from keras_nlp.src.models.falcon.falcon_backbone import FalconBackbone
2021
from keras_nlp.src.models.falcon.falcon_tokenizer import FalconTokenizer
2122
from keras_nlp.src.models.preprocessor import Preprocessor
2223
from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function
@@ -105,6 +106,7 @@ class FalconPreprocessor(Preprocessor):
105106
```
106107
"""
107108

109+
backbone_cls = FalconBackbone
108110
tokenizer_cls = FalconTokenizer
109111

110112
def __init__(

keras_nlp/src/models/falcon/falcon_tokenizer.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515

1616
from keras_nlp.src.api_export import keras_nlp_export
17+
from keras_nlp.src.models.falcon.falcon_backbone import FalconBackbone
1718
from keras_nlp.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer
1819

1920

@@ -65,6 +66,8 @@ class FalconTokenizer(BytePairTokenizer):
6566
```
6667
"""
6768

69+
backbone_cls = FalconBackbone
70+
6871
def __init__(
6972
self,
7073
vocabulary=None,

0 commit comments

Comments
 (0)