
Commit a203a0b

Read PAD, BOS, EOS from tokenizer_config.json if not defined in config.json

1 parent d3184ec

2 files changed: +34 −7 lines

exllamav2/config.py (+3 −3)

@@ -148,9 +148,9 @@ def prepare(self, no_tensors: bool = False):
 
         # Vocab params
 
-        self.bos_token_id = read(read_config, int, "bos_token_id", 1)
-        self.eos_token_id = read(read_config, int, "eos_token_id", 2)
-        self.pad_token_id = read(read_config, int, "pad_token_id", 0)
+        self.bos_token_id = read(read_config, int, "bos_token_id", None) # 1
+        self.eos_token_id = read(read_config, int, "eos_token_id", None) # 2
+        self.pad_token_id = read(read_config, int, "pad_token_id", None) # 0
         self.vocab_size = read(read_config, int, "vocab_size")
 
         # Standard params
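
The effect of this change is that a token ID missing from config.json now resolves to None instead of a silently assumed default, deferring the decision to the tokenizer. A minimal sketch of that behavior, assuming read() acts as a typed dict lookup (the real helper may do additional validation):

    # Hypothetical simplification of the read() helper used above
    def read(cfg: dict, expected_type: type, key: str, default=None):
        value = cfg.get(key, default)
        if value is not None and not isinstance(value, expected_type):
            raise TypeError(f"{key} should be of type {expected_type.__name__}")
        return value

    cfg = {"vocab_size": 32000}  # a config.json that omits the token IDs
    bos_token_id = read(cfg, int, "bos_token_id", None)
    print(bos_token_id)  # None -> resolved later from tokenizer_config.json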

exllamav2/tokenizer/tokenizer.py (+31 −4)
@@ -60,6 +60,7 @@ def tokenizer(self):
     max_cached_strings: int
     actual_vocab_size: int
 
+    tokenizer_config_dict: dict | None
 
     def __init__(self, config, lazy_init = False, force_json = False):
         """
@@ -120,6 +121,15 @@ def __init__(self, config, lazy_init = False, force_json = False):
             else:
                 self.unspecial_piece_to_id[v["content"]] = v["id"]
 
+        # Attempt to load tokenizer_config.json
+
+        tokenizer_config_json_path = os.path.join(self.config.model_dir, "tokenizer_config.json")
+        if os.path.exists(tokenizer_config_json_path):
+            with open(tokenizer_config_json_path, encoding = "utf8") as f:
+                self.tokenizer_config_dict = json.load(f)
+        else:
+            self.tokenizer_config_dict = None
+
         # Add tokens from added_tokens.json if present, assume they're all special
 
         added_tokens_path = os.path.join(self.config.model_dir, "added_tokens.json")
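
For context, tokenizer_config.json typically stores control tokens as piece strings rather than IDs, which is why the lookup in the next hunk has to translate them through the tokenizer model. A hypothetical file for a Llama-style model might parse into something like:

    # Hypothetical tokenizer_config.json contents, as loaded into
    # self.tokenizer_config_dict by the code above
    tokenizer_config_dict = {
        "bos_token": "<s>",
        "eos_token": "</s>",
        "pad_token": None,  # serialized as null in the JSON file
    }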
@@ -149,17 +159,34 @@ def __init__(self, config, lazy_init = False, force_json = False):
         self.unk_token_id = self.tokenizer_model.unk_id()
         self.eos_token_id = config.eos_token_id
         self.bos_token_id = config.bos_token_id
+        self.pad_token_id = config.pad_token_id
+
+        # If model config doesn't specify BOS and EOS tokens, try to load from tokenizer config
+
+        def get_default_token_id(config_key: str, current: int | None, default: int):
+            if current is not None: return current
+            if self.tokenizer_config_dict is not None and config_key in self.tokenizer_config_dict:
+                st = self.tokenizer_config_dict[config_key]
+                if st is None: return None
+                return self.tokenizer_model.piece_to_id(st)
+            else:
+                return default
+
+        self.pad_token_id = get_default_token_id("pad_token", self.pad_token_id, 0)
+        self.bos_token_id = get_default_token_id("bos_token", self.bos_token_id, 1)
+        self.eos_token_id = get_default_token_id("eos_token", self.eos_token_id, 2)
 
         # Get control token strings
 
         self.unk_token = (self.tokenizer_model.unk_token() or self.extended_id_to_piece.get(self.unk_token_id, None)) or self.tokenizer_model.id_to_piece(self.unk_token_id)
         self.bos_token = (self.tokenizer_model.bos_token() or self.extended_id_to_piece.get(self.bos_token_id, None)) or self.tokenizer_model.id_to_piece(self.bos_token_id)
         self.eos_token = (self.tokenizer_model.eos_token() or self.extended_id_to_piece.get(self.eos_token_id, None)) or self.tokenizer_model.id_to_piece(self.eos_token_id)
 
-        # Some tokenizers use token ID zero for text but don't explicitly define a padding token but provide one anyway
+        # Use "<pad>" or EOS token as fallback for padding token
 
-        pad_test = self.tokenizer_model.piece_to_id("<pad>")
-        self.pad_token_id = pad_test or self.eos_token_id
+        if self.pad_token_id is None:
+            pad_test = self.tokenizer_model.piece_to_id("<pad>")
+            self.pad_token_id = pad_test or self.eos_token_id
 
         # Special case if <unk> and <pad> have the same ID
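
The resolution order here is: an explicit ID from config.json wins; otherwise a piece string from tokenizer_config.json is translated to an ID; otherwise the legacy hard-coded default applies. A self-contained sketch of that logic, with a stub dictionary standing in for the real tokenizer model's piece_to_id():

    # Stub vocabulary standing in for tokenizer_model.piece_to_id()
    vocab = {"<s>": 1, "</s>": 2}
    piece_to_id = vocab.get

    tokenizer_config_dict = {"eos_token": "</s>", "pad_token": None}

    def get_default_token_id(config_key, current, default):
        if current is not None:
            return current                            # config.json wins
        if tokenizer_config_dict is not None and config_key in tokenizer_config_dict:
            st = tokenizer_config_dict[config_key]
            if st is None:
                return None                           # explicit null: no such token
            return piece_to_id(st)                    # translate piece string to ID
        return default                                # legacy hard-coded default

    print(get_default_token_id("eos_token", None, 2))  # 2, resolved from "</s>"
    print(get_default_token_id("bos_token", None, 1))  # 1, key absent -> default
    print(get_default_token_id("pad_token", None, 0))  # None, explicit null

Note that the "<pad>" fallback still uses `pad_test or self.eos_token_id`, so a genuine "<pad>" token at ID 0 is treated as missing and falls through to EOS; that mirrors the pre-existing behavior, now gated behind the `pad_token_id is None` check.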

@@ -181,7 +208,7 @@ def __init__(self, config, lazy_init = False, force_json = False):
         self.actual_vocab_size = 1 + max(
             list(self.extended_id_to_piece.keys()) + \
             list(self.unspecial_id_to_piece.keys()) + \
-            [self.tokenizer_model.vocab_size()] # max([]) is illegal
+            [self.tokenizer_model.vocab_size() - 1]
         )
 
         # Useful token IDs
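
The last hunk fixes an off-by-one: valid base-vocabulary IDs run from 0 to vocab_size() - 1, so the largest ID is vocab_size() - 1, not vocab_size(). Keeping it in a singleton list still protects max() from an empty sequence. A quick check with hypothetical sizes:

    extended_ids = []          # no extended special tokens
    unspecial_ids = []         # no unspecial added tokens
    base_vocab_size = 32000    # tokenizer_model.vocab_size()

    # max([]) raises ValueError, so the singleton list keeps max() safe
    actual_vocab_size = 1 + max(extended_ids + unspecial_ids + [base_vocab_size - 1])
    print(actual_vocab_size)   # 32000 (the old code would report 32001)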
