@@ -60,6 +60,7 @@ def tokenizer(self):
    max_cached_strings: int
    actual_vocab_size: int

+   tokenizer_config_dict: dict | None

    def __init__(self, config, lazy_init = False, force_json = False):
        """
@@ -120,6 +121,15 @@ def __init__(self, config, lazy_init = False, force_json = False):
            else:
                self.unspecial_piece_to_id[v["content"]] = v["id"]

+       # Attempt to load tokenizer_config.json
+
+       tokenizer_config_json_path = os.path.join(self.config.model_dir, "tokenizer_config.json")
+       if os.path.exists(tokenizer_config_json_path):
+           with open(tokenizer_config_json_path, encoding = "utf8") as f:
+               self.tokenizer_config_dict = json.load(f)
+       else:
+           self.tokenizer_config_dict = None
+
        # Add tokens from added_tokens.json if present, assume they're all special

        added_tokens_path = os.path.join(self.config.model_dir, "added_tokens.json")
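For context, a minimal sketch of the load-or-`None` pattern this hunk adds. The file name and the pattern come from the diff; the directory and the token strings in the example config are made up for illustration, and real `tokenizer_config.json` files carry many more fields.

```python
import json, os, tempfile

# Illustrative config contents; the token strings are assumptions, not from any particular model.
example_config = {"bos_token": "<s>", "eos_token": "</s>", "pad_token": None}

model_dir = tempfile.mkdtemp()  # hypothetical model directory for the example
with open(os.path.join(model_dir, "tokenizer_config.json"), "w", encoding = "utf8") as f:
    json.dump(example_config, f)

# Same load-or-None pattern as in the diff above
tokenizer_config_json_path = os.path.join(model_dir, "tokenizer_config.json")
if os.path.exists(tokenizer_config_json_path):
    with open(tokenizer_config_json_path, encoding = "utf8") as f:
        tokenizer_config_dict = json.load(f)
else:
    tokenizer_config_dict = None

print(tokenizer_config_dict)  # {'bos_token': '<s>', 'eos_token': '</s>', 'pad_token': None}
```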
@@ -149,17 +159,34 @@ def __init__(self, config, lazy_init = False, force_json = False):
        self.unk_token_id = self.tokenizer_model.unk_id()
        self.eos_token_id = config.eos_token_id
        self.bos_token_id = config.bos_token_id
+       self.pad_token_id = config.pad_token_id
+
+       # If model config doesn't specify BOS and EOS tokens, try to load from tokenizer config
+
+       def get_default_token_id(config_key: str, current: int | None, default: int):
+           if current is not None: return current
+           if self.tokenizer_config_dict is not None and config_key in self.tokenizer_config_dict:
+               st = self.tokenizer_config_dict[config_key]
+               if st is None: return None
+               return self.tokenizer_model.piece_to_id(st)
+           else:
+               return default
+
+       self.pad_token_id = get_default_token_id("pad_token", self.pad_token_id, 0)
+       self.bos_token_id = get_default_token_id("bos_token", self.bos_token_id, 1)
+       self.eos_token_id = get_default_token_id("eos_token", self.eos_token_id, 2)

        # Get control token strings

        self.unk_token = (self.tokenizer_model.unk_token() or self.extended_id_to_piece.get(self.unk_token_id, None)) or self.tokenizer_model.id_to_piece(self.unk_token_id)
        self.bos_token = (self.tokenizer_model.bos_token() or self.extended_id_to_piece.get(self.bos_token_id, None)) or self.tokenizer_model.id_to_piece(self.bos_token_id)
        self.eos_token = (self.tokenizer_model.eos_token() or self.extended_id_to_piece.get(self.eos_token_id, None)) or self.tokenizer_model.id_to_piece(self.eos_token_id)

-       # Some tokenizers use token ID zero for text but don't explicitly define a padding token but provide one anyway
+       # Use "<pad>" or EOS token as fallback for padding token

-       pad_test = self.tokenizer_model.piece_to_id("<pad>")
-       self.pad_token_id = pad_test or self.eos_token_id
+       if self.pad_token_id is None:
+           pad_test = self.tokenizer_model.piece_to_id("<pad>")
+           self.pad_token_id = pad_test or self.eos_token_id

        # Special case if <unk> and <pad> have the same ID
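A standalone sketch of the fallback order that `get_default_token_id` implements: an explicit model-config value wins, otherwise the token string from `tokenizer_config.json` is converted to an ID, otherwise a hard-coded default is used, and an explicit `null` in the config maps to `None`. The `piece_to_id` dict and the IDs below are stand-ins for the tokenizer model's lookup, chosen only for illustration.

```python
piece_to_id = {"<s>": 1, "</s>": 2, "<pad>": 0}  # illustrative stand-in for the tokenizer model

def resolve_token_id(config_key, current, default, tokenizer_config_dict):
    if current is not None:
        return current                       # value from the model config wins
    if tokenizer_config_dict is not None and config_key in tokenizer_config_dict:
        st = tokenizer_config_dict[config_key]
        if st is None:
            return None                      # explicit null means "no such token"
        return piece_to_id[st]               # convert the token string to an ID
    return default                           # nothing specified anywhere

cfg = {"bos_token": "<s>", "eos_token": "</s>", "pad_token": None}
print(resolve_token_id("bos_token", None, 1, cfg))  # 1, resolved from the config string "<s>"
print(resolve_token_id("eos_token", 7, 2, cfg))     # 7, model config takes precedence
print(resolve_token_id("pad_token", None, 0, cfg))  # None, config explicitly disables it
```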
@@ -181,7 +208,7 @@ def __init__(self, config, lazy_init = False, force_json = False):
        self.actual_vocab_size = 1 + max(
            list(self.extended_id_to_piece.keys()) + \
            list(self.unspecial_id_to_piece.keys()) + \
-           [self.tokenizer_model.vocab_size()]  # max([]) is illegal
+           [self.tokenizer_model.vocab_size() - 1]
        )

        # Useful token IDs
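A small illustration of why the off-by-one change matters: token IDs are zero-based, so the highest base ID is `vocab_size() - 1`. Keeping that sentinel in the list still guards against `max([])`, but no longer inflates the result by one when no extended token sits above the base vocabulary. The numbers below are assumed for the example.

```python
base_vocab_size = 32000   # stand-in for tokenizer_model.vocab_size(), IDs 0..31999
extended_ids = []         # stand-in for extended_id_to_piece.keys()
unspecial_ids = []        # stand-in for unspecial_id_to_piece.keys()

# Old expression: 1 + max(ids + [base_vocab_size])     -> 32001 (one too many)
# New expression: 1 + max(ids + [base_vocab_size - 1]) -> 32000
actual_vocab_size = 1 + max(extended_ids + unspecial_ids + [base_vocab_size - 1])
print(actual_vocab_size)  # 32000
```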