13 | 13 | # limitations under the License.
14 | 14 | """DistilBERT preprocessing layers."""
15 | 15 |
   | 16 | +import copy
   | 17 | +import os
   | 18 | +
16 | 19 | from tensorflow import keras
17 | 20 |
18 | 21 | from keras_nlp.layers.multi_segment_packer import MultiSegmentPacker
   | 22 | +from keras_nlp.models.distilbert.distilbert_presets import backbone_presets
19 | 23 | from keras_nlp.tokenizers.word_piece_tokenizer import WordPieceTokenizer
20 | 24 | from keras_nlp.utils.python_utils import classproperty
   | 25 | +from keras_nlp.utils.python_utils import format_docstring
21 | 26 |
22 | 27 |
23 | 28 | @keras.utils.register_keras_serializable(package="keras_nlp")
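The `presets` accessor added below is exposed through `classproperty`, imported above, so it can be read off the class itself (e.g. `DistilBertTokenizer.presets`) without an instance. A minimal sketch of the usual recipe; keras_nlp ships its own helper in `keras_nlp/utils/python_utils.py`, which may differ in detail:

```python
# Sketch of a minimal `classproperty`; an assumption about the helper
# imported from keras_nlp.utils.python_utils, not its verbatim source.
class classproperty(property):
    def __get__(self, instance, owner_cls):
        # Invoke the wrapped function with the class, not an instance.
        return self.fget(owner_cls)


class Demo:
    @classproperty
    def presets(cls):
        return {"demo_preset": {}}


print(Demo.presets)  # {'demo_preset': {}} -- no instance required
```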
@@ -104,15 +109,56 @@ def __init__(
104 | 109 |
105 | 110 |     @classproperty
106 | 111 |     def presets(cls):
107 |     | -        raise NotImplementedError
    | 112 | +        return copy.deepcopy(backbone_presets)
108 | 113 |
109 | 114 |     @classmethod
    | 115 | +    @format_docstring(names=", ".join(backbone_presets))
110 | 116 |     def from_preset(
111 | 117 |         cls,
112 | 118 |         preset,
113 | 119 |         **kwargs,
114 | 120 |     ):
115 |     | -        raise NotImplementedError
    | 121 | +        """Instantiate a DistilBERT tokenizer from a preset vocabulary.
    | 122 | +
    | 123 | +        Args:
    | 124 | +            preset: string. Must be one of {{names}}.
    | 125 | +
    | 126 | +        Examples:
    | 127 | +        ```python
    | 128 | +        # Load a preset tokenizer.
    | 129 | +        tokenizer = keras_nlp.models.DistilBertTokenizer.from_preset(
    | 130 | +            "distilbert_base_uncased_en",
    | 131 | +        )
    | 132 | +
    | 133 | +        # Tokenize some input.
    | 134 | +        tokenizer("The quick brown fox tripped.")
    | 135 | +
    | 136 | +        # Detokenize some input.
    | 137 | +        tokenizer.detokenize([5, 6, 7, 8, 9])
    | 138 | +        ```
    | 139 | +        """
    | 140 | +        if preset not in cls.presets:
    | 141 | +            raise ValueError(
    | 142 | +                "`preset` must be one of "
    | 143 | +                f"""{", ".join(cls.presets)}. Received: {preset}."""
    | 144 | +            )
    | 145 | +        metadata = cls.presets[preset]
    | 146 | +
    | 147 | +        vocabulary = keras.utils.get_file(
    | 148 | +            "vocab.txt",
    | 149 | +            metadata["vocabulary_url"],
    | 150 | +            cache_subdir=os.path.join("models", preset),
    | 151 | +            file_hash=metadata["vocabulary_hash"],
    | 152 | +        )
    | 153 | +
    | 154 | +        config = metadata["preprocessor_config"]
    | 155 | +        config.update(
    | 156 | +            {
    | 157 | +                "vocabulary": vocabulary,
    | 158 | +            },
    | 159 | +        )
    | 160 | +
    | 161 | +        return cls.from_config({**config, **kwargs})
116 | 162 |
117 | 163 |
118 | 164 | @keras.utils.register_keras_serializable(package="keras_nlp")
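`format_docstring` fills the `{{names}}` placeholders in the docstrings above and below with the preset names at import time. A hypothetical sketch of how such a decorator can behave, assuming `{{key}}` markers; the real implementation lives in `keras_nlp/utils/python_utils.py` and may differ:

```python
# Hypothetical sketch of a docstring-templating decorator, not the
# library's verbatim source. Each `{{key}}` marker is replaced with the
# supplied value, leaving other braces (e.g. in code examples) untouched.
def format_docstring(**replacements):
    def decorate(obj):
        doc = obj.__doc__ or ""
        for key, value in replacements.items():
            doc = doc.replace("{{" + key + "}}", str(value))
        obj.__doc__ = doc
        return obj
    return decorate


@format_docstring(names="distilbert_base_uncased_en")
def from_preset(preset):
    """Must be one of {{names}}."""


print(from_preset.__doc__)  # Must be one of distilbert_base_uncased_en.
```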
@@ -238,14 +284,78 @@ def call(self, inputs):
238 | 284 |
239 | 285 |     @classproperty
240 | 286 |     def presets(cls):
241 |     | -        raise NotImplementedError
    | 287 | +        return copy.deepcopy(backbone_presets)
242 | 288 |
243 | 289 |     @classmethod
    | 290 | +    @format_docstring(names=", ".join(backbone_presets))
244 | 291 |     def from_preset(
245 | 292 |         cls,
246 | 293 |         preset,
247 | 294 |         sequence_length=None,
248 | 295 |         truncate="round_robin",
249 | 296 |         **kwargs,
250 | 297 |     ):
251 |     | -        raise NotImplementedError
    | 298 | +        """Instantiate a DistilBERT preprocessor from a preset architecture.
    | 299 | +
    | 300 | +        Args:
    | 301 | +            preset: string. Must be one of {{names}}.
    | 302 | +            sequence_length: int, optional. The length of the packed inputs.
    | 303 | +                Must be equal to or smaller than the `max_sequence_length`
    | 304 | +                of the preset. If left as default, the `max_sequence_length`
    | 305 | +                of the preset will be used.
    | 306 | +            truncate: string. The algorithm to truncate a list of batched
    | 307 | +                segments to fit within `sequence_length`. The value can be
    | 308 | +                either `"round_robin"` or `"waterfall"`:
    | 309 | +                - `"round_robin"`: Available space is assigned one token at
    | 310 | +                    a time in a round-robin fashion to the inputs that still
    | 311 | +                    need some, until the limit is reached.
    | 312 | +                - `"waterfall"`: The allocation of the budget is done using
    | 313 | +                    a "waterfall" algorithm that allocates quota in a
    | 314 | +                    left-to-right manner and fills up the buckets until we
    | 315 | +                    run out of budget. It supports an arbitrary number of
    | 316 | +                    segments.
    | 317 | +
    | 318 | +        Examples:
    | 319 | +        ```python
    | 320 | +        # Load a preprocessor from a preset.
    | 321 | +        preprocessor = keras_nlp.models.DistilBertPreprocessor.from_preset(
    | 322 | +            "distilbert_base_uncased_en",
    | 323 | +        )
    | 324 | +        preprocessor("The quick brown fox jumped.")
    | 325 | +
    | 326 | +        # Override `sequence_length`.
    | 327 | +        preprocessor = keras_nlp.models.DistilBertPreprocessor.from_preset(
    | 328 | +            "distilbert_base_uncased_en",
    | 329 | +            sequence_length=64,
    | 330 | +        )
    | 331 | +        preprocessor("The quick brown fox jumped.")
    | 332 | +        ```
    | 333 | +        """
    | 334 | +        if preset not in cls.presets:
    | 335 | +            raise ValueError(
    | 336 | +                "`preset` must be one of "
    | 337 | +                f"""{", ".join(cls.presets)}. Received: {preset}."""
    | 338 | +            )
    | 339 | +
    | 340 | +        tokenizer = DistilBertTokenizer.from_preset(preset)
    | 341 | +
    | 342 | +        # Use the model's `max_sequence_length` if `sequence_length` is
    | 343 | +        # unspecified; otherwise check that it is not too long.
    | 344 | +        metadata = cls.presets[preset]
    | 345 | +        max_sequence_length = metadata["config"]["max_sequence_length"]
    | 346 | +        if sequence_length is not None:
    | 347 | +            if sequence_length > max_sequence_length:
    | 348 | +                raise ValueError(
    | 349 | +                    f"`sequence_length` cannot be longer than `{preset}` "
    | 350 | +                    f"preset's `max_sequence_length` of "
    | 351 | +                    f"{max_sequence_length}. Received: {sequence_length}."
    | 352 | +                )
    | 353 | +        else:
    | 354 | +            sequence_length = max_sequence_length
    | 355 | +
    | 356 | +        return cls(
    | 357 | +            tokenizer=tokenizer,
    | 358 | +            sequence_length=sequence_length,
    | 359 | +            truncate=truncate,
    | 360 | +            **kwargs,
    | 361 | +        )
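For intuition on the two `truncate` modes documented in the docstring above, here is a toy re-implementation of the budget allocation. This is illustrative only; the actual packing is performed by `MultiSegmentPacker`, not by this code:

```python
def truncate_round_robin(segments, budget):
    # Hand out tokens one at a time, cycling over segments that still
    # have tokens left, until `budget` tokens have been kept.
    kept = [[] for _ in segments]
    i = 0
    while budget > 0 and any(len(k) < len(s) for k, s in zip(kept, segments)):
        if len(kept[i]) < len(segments[i]):
            kept[i].append(segments[i][len(kept[i])])
            budget -= 1
        i = (i + 1) % len(segments)
    return kept


def truncate_waterfall(segments, budget):
    # Fill segments left to right: the first segment takes as much of
    # the budget as it needs, and the remainder flows to the next.
    kept = []
    for seg in segments:
        take = min(len(seg), budget)
        kept.append(seg[:take])
        budget -= take
    return kept


a, b = [1, 2, 3, 4], [5, 6, 7, 8]
print(truncate_round_robin([a, b], 5))  # [[1, 2, 3], [5, 6]]
print(truncate_waterfall([a, b], 5))    # [[1, 2, 3, 4], [5]]
```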
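The diff reads four keys off each preset entry: `config["max_sequence_length"]`, `preprocessor_config`, `vocabulary_url`, and `vocabulary_hash`. A hypothetical shape for one `backbone_presets` entry consistent with those accesses; the real values live in `keras_nlp/models/distilbert/distilbert_presets.py`:

```python
# Hypothetical entry shape inferred from the keys accessed in this diff;
# URLs and hashes are elided rather than invented.
backbone_presets = {
    "distilbert_base_uncased_en": {
        "config": {
            # Backbone hyperparameters; 512 is the standard DistilBERT
            # position limit, shown here as an assumption.
            "max_sequence_length": 512,
            # ...
        },
        "preprocessor_config": {
            # Tokenizer kwargs other than the vocabulary, e.g.
            # lowercasing for an uncased model (assumed).
            "lowercase": True,
        },
        "vocabulary_url": "...",   # hosted vocab.txt (elided)
        "vocabulary_hash": "...",  # hash checked by get_file (elided)
    },
}
```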