Commit d51c1d4

sumitd2 authored and patrickvonplaten committed
[Bugfix][Core] Fix tekken edge case for mistral tokenizer (vllm-project#8640)
Signed-off-by: Sumit Dubey <sumit.dubey2@ibm.com>
1 parent 0a7ad11 commit d51c1d4

File tree

tests/models/decoder_only/language/test_mistral.py
vllm/transformers_utils/tokenizers/mistral.py

2 files changed: +53 additions, −4 deletions

tests/models/decoder_only/language/test_mistral.py

Lines changed: 24 additions & 1 deletion

@@ -4,7 +4,7 @@
 """
 import pytest
 
-from vllm import SamplingParams
+from vllm import LLM, SamplingParams
 
 from ...utils import check_logprobs_close
 
@@ -16,6 +16,10 @@
 ]
 
 SAMPLING_PARAMS = SamplingParams(max_tokens=512, temperature=0.0, logprobs=5)
+SYMBOLIC_LANG_PROMPTS = [
+    "勇敢な船乗りについての詩を書く",  # Japanese: "Write a poem about a brave sailor"
+    "寫一首關於勇敢的水手的詩",  # Chinese: "Write a poem about a brave sailor"
+]
 
 # for function calling
 TOOLS = [{
@@ -131,6 +135,25 @@ def test_mistral_format(
     )
 
 
+@pytest.mark.parametrize("model", MODELS[1:])
+@pytest.mark.parametrize("dtype", ["bfloat16"])
+@pytest.mark.parametrize("prompt", SYMBOLIC_LANG_PROMPTS)
+def test_mistral_symbolic_languages(
+    model: str,
+    dtype: str,
+    prompt: str,
+) -> None:
+    msg = {"role": "user", "content": prompt}
+    llm = LLM(model=model,
+              dtype=dtype,
+              max_model_len=8192,
+              tokenizer_mode="mistral",
+              config_format="mistral",
+              load_format="mistral")
+    outputs = llm.chat([msg], sampling_params=SAMPLING_PARAMS)
+    assert "�" not in outputs[0].outputs[0].text.strip()
+
+
 @pytest.mark.parametrize("dtype", ["bfloat16"])
 @pytest.mark.parametrize("model", MODELS[1:])  # v1 can't do func calling
 def test_mistral_function_calling(
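The assertion at the end of the new test guards against the Unicode replacement character U+FFFD ("�") leaking into generated text. The edge case it exercises: Tekken-style tokenizers can emit vocabulary pieces that are raw bytes, and a multi-byte UTF-8 character split across pieces turns into "�" if each piece is decoded on its own. A minimal sketch of that failure mode in plain Python, with no vLLM dependency (the split point is illustrative, not taken from the real tokenizer):

# "勇" is three bytes in UTF-8 (e5 8b 87). If a tokenizer splits those
# bytes across two pieces and each piece is decoded independently, the
# output contains replacement characters instead of the original glyph.
raw = "勇".encode("utf-8")
piece_a, piece_b = raw[:2], raw[2:]

broken = (piece_a.decode("utf-8", errors="replace") +
          piece_b.decode("utf-8", errors="replace"))
assert "�" in broken

# Decoding the reassembled byte sequence recovers the character.
assert (piece_a + piece_b).decode("utf-8") == "勇"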

vllm/transformers_utils/tokenizers/mistral.py

Lines changed: 29 additions & 3 deletions

@@ -175,10 +175,29 @@ def apply_chat_template(self,
 
     def convert_tokens_to_string(self, tokens: List[str]) -> str:
         if isinstance(self.tokenizer, Tekkenizer):
-            return "".join(t for t in tokens
-                           if t not in self.tokenizer._all_special_tokens)
+            tokens = [
+                t for t in tokens
+                if t not in self.tokenizer._all_special_tokens
+            ]
+
+            if any(isinstance(t, bytes) for t in tokens):
+                # we need to encode and decode all tokens again
+                shift = self.tokenizer.num_special_tokens
+                byte_tokens = [
+                    t.encode("utf-8") if not isinstance(t, bytes) else t
+                    for t in tokens
+                ]
+                ids = [
+                    self.tokenizer._tekken_token2id_nospecial[t] + shift
+                    for t in byte_tokens
+                ]
+                decoded = self.tokenizer.decode(ids)
+            else:
+                decoded = "".join(tokens)
         else:
-            return self.tokenizer.decode(tokens)  # type: ignore[arg-type]
+            decoded = self.tokenizer.decode(tokens)  # type: ignore[arg-type]
+
+        return decoded
 
     def decode(self, ids: Union[List[int], int]) -> str:
         if isinstance(ids, int):
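The patched convert_tokens_to_string path first drops special tokens and then, if any surviving piece is raw bytes, normalizes every piece to bytes, maps the pieces back to ids through _tekken_token2id_nospecial (shifted past the special-token id range), and lets the tokenizer's own decode reassemble the full sequence. A simplified stand-in of the same idea, using a hypothetical join_pieces helper that decodes the concatenated bytes directly instead of round-tripping through token ids:

from typing import List, Union

def join_pieces(pieces: List[Union[str, bytes]]) -> str:
    # Normalize every piece to bytes, concatenate, and decode once, so a
    # multi-byte character split across pieces survives. The actual patch
    # achieves the same effect by mapping pieces back to ids and calling
    # Tekkenizer.decode on the whole sequence.
    buf = b"".join(
        p.encode("utf-8") if isinstance(p, str) else p for p in pieces
    )
    return buf.decode("utf-8", errors="replace")

# "船" (e8 88 b9) arrives as one str piece plus two raw byte pieces:
assert join_pieces(["詩", b"\xe8\x88", b"\xb9"]) == "詩船"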
@@ -200,4 +219,11 @@ def convert_ids_to_tokens(
                           self.tokenizer)
 
         tokens = [self.tokenizer.id_to_piece(id) for id in ids]
+
+        if any(t.strip() == "�" for t in tokens):
+            # if any stripped decoded token is undefined
+            # because it's invalid unicode then pass bytes
+            # See: https://github.com/vllm-project/vllm/pull/8640
+            tokens = [self.tokenizer.id_to_byte_piece(id) for id in ids]
+
         return tokens
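convert_ids_to_tokens now detects ids whose standalone string piece is undecodable, by checking for a stripped piece equal to "�", and in that case re-fetches the whole sequence as byte pieces so convert_tokens_to_string can reassemble them. A hedged sketch of that control flow: id_to_piece and id_to_byte_piece are the real Tekkenizer method names from the diff, but here they are stubbed out as plain callables rather than the vLLM objects:

from typing import Callable, List, Union

def ids_to_pieces(
    ids: List[int],
    id_to_piece: Callable[[int], str],
    id_to_byte_piece: Callable[[int], bytes],
) -> List[Union[str, bytes]]:
    # Try string pieces first; a stripped piece equal to U+FFFD means that
    # id does not decode to valid unicode on its own, so fall back to byte
    # pieces for the entire sequence rather than mixing representations.
    tokens: List[Union[str, bytes]] = [id_to_piece(i) for i in ids]
    if any(t.strip() == "\ufffd" for t in tokens):
        tokens = [id_to_byte_piece(i) for i in ids]
    return tokens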
