
Commit 107a059

Add DistilBERT Presets (keras-team#479)
* Add DistilBERT presets
* Change preset unit tests
* Fix unit tests
* Add Python conversion script
* Convert argparse to absl
* Add multilingual DistilBERT
* Fix vocab size
* Add preprocessor arg
* Add correct hash for multilingual preset
* Fix PRESET_MAP
* Address comments
* Rootify relative import
* Fix URLs
1 parent cee7d97 commit 107a059
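
Taken together, the changes below give the DistilBERT backbone, tokenizer, and preprocessor a working preset API. A minimal sketch of the workflow this commit enables, assuming a keras-nlp build that includes these changes and the class names used in the docstring examples below (`DistilBert`, `DistilBertTokenizer`, `DistilBertPreprocessor`):

```python
import keras_nlp

# All three classes now resolve the same preset name.
tokenizer = keras_nlp.models.DistilBertTokenizer.from_preset(
    "distilbert_base_uncased_en"
)
preprocessor = keras_nlp.models.DistilBertPreprocessor.from_preset(
    "distilbert_base_uncased_en"
)
backbone = keras_nlp.models.DistilBert.from_preset(
    "distilbert_base_uncased_en"
)
```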

File tree

7 files changed: +800 −7 lines


keras_nlp/models/distilbert/distilbert_models.py

Lines changed: 62 additions & 1 deletion
@@ -14,6 +14,8 @@
 
 """DistilBERT backbone models."""
 
+import copy
+import os
 
 import tensorflow as tf
 from tensorflow import keras
@@ -22,6 +24,9 @@
     TokenAndPositionEmbedding,
 )
 from keras_nlp.layers.transformer_encoder import TransformerEncoder
+from keras_nlp.models.distilbert.distilbert_presets import backbone_presets
+from keras_nlp.utils.python_utils import classproperty
+from keras_nlp.utils.python_utils import format_docstring
 
 
 def distilbert_kernel_initializer(stddev=0.02):
@@ -178,11 +183,67 @@ def get_config(self):
     def from_config(cls, config):
         return cls(**config)
 
+    @classproperty
+    def presets(cls):
+        return copy.deepcopy(backbone_presets)
+
     @classmethod
+    @format_docstring(names=", ".join(backbone_presets))
     def from_preset(
         cls,
         preset,
         load_weights=True,
         **kwargs,
     ):
-        raise NotImplementedError
+        """Instantiate DistilBERT model from preset architecture and weights.
+
+        Args:
+            preset: string. Must be one of {{names}}.
+            load_weights: Whether to load pre-trained weights into model.
+                Defaults to `True`.
+
+        Examples:
+        ```python
+        input_data = {
+            "token_ids": tf.random.uniform(
+                shape=(1, 12), dtype=tf.int64, maxval=model.vocabulary_size
+            ),
+            "padding_mask": tf.constant(
+                [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0], shape=(1, 12)
+            ),
+        }
+
+        # Load architecture and weights from preset
+        model = keras_nlp.models.DistilBert.from_preset(
+            "distilbert_base_uncased_en"
+        )
+        output = model(input_data)
+
+        # Load randomly initialized model from preset architecture
+        model = keras_nlp.models.DistilBert.from_preset(
+            "distilbert_base_uncased_en", load_weights=False
+        )
+        output = model(input_data)
+        ```
+        """
+        if preset not in cls.presets:
+            raise ValueError(
+                "`preset` must be one of "
+                f"""{", ".join(cls.presets)}. Received: {preset}."""
+            )
+        metadata = cls.presets[preset]
+        config = metadata["config"]
+        model = cls.from_config({**config, **kwargs})
+
+        if not load_weights:
+            return model
+
+        weights = keras.utils.get_file(
+            "model.h5",
+            metadata["weights_url"],
+            cache_subdir=os.path.join("models", preset),
+            file_hash=metadata["weights_hash"],
+        )
+
+        model.load_weights(weights)
+        return model
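
The new `presets` classproperty exposes the preset metadata directly, so callers can discover available names and architectures without building a model. A small sketch of that lookup, assuming the backbone is exported as `keras_nlp.models.DistilBert` as in the docstring examples above:

```python
import keras_nlp

# `presets` returns a deep copy of `backbone_presets`, keyed by preset name.
for name, metadata in keras_nlp.models.DistilBert.presets.items():
    print(name, "-", metadata["description"])

# Each entry carries the backbone config that `from_preset` passes to
# `from_config`, e.g. layer count and hidden size of the uncased preset.
config = keras_nlp.models.DistilBert.presets["distilbert_base_uncased_en"]["config"]
print(config["num_layers"], config["hidden_dim"])  # 6 768
```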

keras_nlp/models/distilbert/distilbert_preprocessing.py

Lines changed: 114 additions & 4 deletions
@@ -13,11 +13,16 @@
 # limitations under the License.
 """DistilBERT preprocessing layers."""
 
+import copy
+import os
+
 from tensorflow import keras
 
 from keras_nlp.layers.multi_segment_packer import MultiSegmentPacker
+from keras_nlp.models.distilbert.distilbert_presets import backbone_presets
 from keras_nlp.tokenizers.word_piece_tokenizer import WordPieceTokenizer
 from keras_nlp.utils.python_utils import classproperty
+from keras_nlp.utils.python_utils import format_docstring
 
 
 @keras.utils.register_keras_serializable(package="keras_nlp")
@@ -104,15 +109,56 @@ def __init__(
 
     @classproperty
     def presets(cls):
-        raise NotImplementedError
+        return copy.deepcopy(backbone_presets)
 
     @classmethod
+    @format_docstring(names=", ".join(backbone_presets))
     def from_preset(
         cls,
         preset,
         **kwargs,
     ):
-        raise NotImplementedError
+        """Instantiate a DistilBERT tokenizer from preset vocabulary.
+
+        Args:
+            preset: string. Must be one of {{names}}.
+
+        Examples:
+        ```python
+        # Load a preset tokenizer.
+        tokenizer = keras_nlp.models.DistilBertTokenizer.from_preset(
+            "distilbert_base_uncased_en",
+        )
+
+        # Tokenize some input.
+        tokenizer("The quick brown fox tripped.")
+
+        # Detokenize some input.
+        tokenizer.detokenize([5, 6, 7, 8, 9])
+        ```
+        """
+        if preset not in cls.presets:
+            raise ValueError(
+                "`preset` must be one of "
+                f"""{", ".join(cls.presets)}. Received: {preset}."""
+            )
+        metadata = cls.presets[preset]
+
+        vocabulary = keras.utils.get_file(
+            "vocab.txt",
+            metadata["vocabulary_url"],
+            cache_subdir=os.path.join("models", preset),
+            file_hash=metadata["vocabulary_hash"],
+        )
+
+        config = metadata["preprocessor_config"]
+        config.update(
+            {
+                "vocabulary": vocabulary,
+            },
+        )
+
+        return cls.from_config({**config, **kwargs})
 
 
 @keras.utils.register_keras_serializable(package="keras_nlp")
@@ -238,14 +284,78 @@ def call(self, inputs):
 
     @classproperty
     def presets(cls):
-        raise NotImplementedError
+        return copy.deepcopy(backbone_presets)
 
     @classmethod
+    @format_docstring(names=", ".join(backbone_presets))
     def from_preset(
         cls,
         preset,
         sequence_length=None,
         truncate="round_robin",
         **kwargs,
     ):
-        raise NotImplementedError
+        """Instantiate DistilBERT preprocessor from preset architecture.
+
+        Args:
+            preset: string. Must be one of {{names}}.
+            sequence_length: int, optional. The length of the packed inputs.
+                Must be equal to or smaller than the `max_sequence_length` of
+                the preset. If left as default, the `max_sequence_length` of
+                the preset will be used.
+            truncate: string. The algorithm to truncate a list of batched
+                segments to fit within `sequence_length`. The value can be
+                either `round_robin` or `waterfall`:
+                - `"round_robin"`: Available space is assigned one token at
+                    a time in a round-robin fashion to the inputs that still
+                    need some, until the limit is reached.
+                - `"waterfall"`: The allocation of the budget is done using
+                    a "waterfall" algorithm that allocates quota in a
+                    left-to-right manner and fills up the buckets until we
+                    run out of budget. It supports an arbitrary number of
+                    segments.
+
+        Examples:
+        ```python
+        # Load preprocessor from preset
+        preprocessor = keras_nlp.models.DistilBertPreprocessor.from_preset(
+            "distilbert_base_uncased_en",
+        )
+        preprocessor("The quick brown fox jumped.")
+
+        # Override sequence_length
+        preprocessor = keras_nlp.models.DistilBertPreprocessor.from_preset(
+            "distilbert_base_uncased_en",
+            sequence_length=64
+        )
+        preprocessor("The quick brown fox jumped.")
+        ```
+        """
+        if preset not in cls.presets:
+            raise ValueError(
+                "`preset` must be one of "
+                f"""{", ".join(cls.presets)}. Received: {preset}."""
+            )
+
+        tokenizer = DistilBertTokenizer.from_preset(preset)
+
+        # Use model's `max_sequence_length` if `sequence_length` unspecified;
+        # otherwise check that `sequence_length` not too long.
+        metadata = cls.presets[preset]
+        max_sequence_length = metadata["config"]["max_sequence_length"]
+        if sequence_length is not None:
+            if sequence_length > max_sequence_length:
+                raise ValueError(
+                    f"`sequence_length` cannot be longer than `{preset}` "
+                    f"preset's `max_sequence_length` of {max_sequence_length}. "
+                    f"Received: {sequence_length}."
+                )
+        else:
+            sequence_length = max_sequence_length
+
        return cls(
            tokenizer=tokenizer,
            sequence_length=sequence_length,
            truncate=truncate,
            **kwargs,
        )
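
The preprocessor's `from_preset` reuses the tokenizer preset and validates `sequence_length` against the preset's `max_sequence_length` (512 for all three presets added here). A hedged usage sketch, assuming the exported class name from the docstring examples above:

```python
import keras_nlp

# A packing length shorter than the preset maximum is accepted.
preprocessor = keras_nlp.models.DistilBertPreprocessor.from_preset(
    "distilbert_base_uncased_en",
    sequence_length=128,
)
preprocessor("The quick brown fox jumped.")

# Anything above the preset's max_sequence_length (512) is rejected
# by the check in from_preset above.
try:
    keras_nlp.models.DistilBertPreprocessor.from_preset(
        "distilbert_base_uncased_en",
        sequence_length=1024,
    )
except ValueError as e:
    print(e)
```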
keras_nlp/models/distilbert/distilbert_presets.py

Lines changed: 83 additions & 0 deletions
@@ -0,0 +1,83 @@
+# Copyright 2022 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+backbone_presets = {
+    "distilbert_base_uncased_en": {
+        "config": {
+            "vocabulary_size": 30522,
+            "num_layers": 6,
+            "num_heads": 12,
+            "hidden_dim": 768,
+            "intermediate_dim": 3072,
+            "dropout": 0.1,
+            "max_sequence_length": 512,
+        },
+        "preprocessor_config": {
+            "lowercase": True,
+        },
+        "description": (
+            "Base size of DistilBERT where all input is lowercased. "
+            "Trained on English Wikipedia + BooksCorpus using BERT as the "
+            "teacher model."
+        ),
+        "weights_url": "https://storage.googleapis.com/keras-nlp/models/distilbert_base_uncased_en/model.h5",
+        "weights_hash": "6625a649572e74086d74c46b8d0b0da3",
+        "vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/distilbert_base_uncased_en/vocab.txt",
+        "vocabulary_hash": "64800d5d8528ce344256daf115d4965e",
+    },
+    "distilbert_base_cased_en": {
+        "config": {
+            "vocabulary_size": 28996,
+            "num_layers": 6,
+            "num_heads": 12,
+            "hidden_dim": 768,
+            "intermediate_dim": 3072,
+            "dropout": 0.1,
+            "max_sequence_length": 512,
+        },
+        "preprocessor_config": {
+            "lowercase": False,
+        },
+        "description": (
+            "Base size of DistilBERT where case is maintained. "
+            "Trained on English Wikipedia + BooksCorpus using BERT as the "
+            "teacher model."
+        ),
+        "weights_url": "https://storage.googleapis.com/keras-nlp/models/distilbert_base_cased_en/model.h5",
+        "weights_hash": "fa36aa6865978efbf85a5c8264e5eb57",
+        "vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/distilbert_base_cased_en/vocab.txt",
+        "vocabulary_hash": "bb6ca9b42e790e5cd986bbb16444d0e0",
+    },
+    "distilbert_base_multi_cased": {
+        "config": {
+            "vocabulary_size": 119547,
+            "num_layers": 6,
+            "num_heads": 12,
+            "hidden_dim": 768,
+            "intermediate_dim": 3072,
+            "dropout": 0.1,
+            "max_sequence_length": 512,
+        },
+        "preprocessor_config": {
+            "lowercase": False,
+        },
+        "description": (
+            "Base size of DistilBERT. Trained on Wikipedias of 104 languages "
+            "using BERT as the teacher model."
+        ),
+        "weights_url": "https://storage.googleapis.com/keras-nlp/models/distilbert_base_multi_cased/model.h5",
+        "weights_hash": "c0f11095e2a6455bd3b1a6d14800a7fa",
+        "vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/distilbert_base_multi_cased/vocab.txt",
+        "vocabulary_hash": "d9d865138d17f1958502ed060ecfeeb6",
+    },
+}
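
This dict is the single source of truth for the preset API: the `{{names}}` placeholder in the docstrings above is filled by the `format_docstring` decorator from the same comma-separated join over its keys, so docs and error messages stay in sync with the registered presets. A quick check of what that join produces, assuming the module path used in the imports above:

```python
from keras_nlp.models.distilbert.distilbert_presets import backbone_presets

# Same expression as in the @format_docstring decorators above.
print(", ".join(backbone_presets))
# distilbert_base_uncased_en, distilbert_base_cased_en, distilbert_base_multi_cased
```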
