
Commit 027f7bc

CISC and abetlen authored
fix: Avoid duplicate special tokens in chat formats (#1439)
* Templates sometimes have BOS in them, remove duplicate
* tokenize chat format prompts before completion

  This is to ensure that we don't duplicate any special tokens.

  Hopefully I amended the existing formats correctly?
* updated comment
* corrected a few
* add some missing internals
* proper bos/eos detection
* just let tokenizer do the job
* typo--
* align test with new response
* changed to a warning
* move to another PR
* Use python warnings module

---------

Co-authored-by: Andrei Betlen <abetlen@gmail.com>
1 parent 951e39c commit 027f7bc

File tree

4 files changed: +25 additions, -10 deletions

  llama_cpp/_internals.py
  llama_cpp/llama.py
  llama_cpp/llama_chat_format.py
  tests/test_llama_chat_format.py
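
For context, a minimal sketch of the problem this commit addresses (not taken from the repository's tests; the model path is a placeholder): many chat templates already embed the literal BOS token, so tokenizing the rendered prompt with add_bos=True yields two leading BOS tokens.

# Illustration only: a rendered template that already contains "<s>" gets a
# second BOS prepended when tokenized with add_bos=True.
from llama_cpp import Llama

llm = Llama(model_path="./models/example.gguf")  # hypothetical path

rendered = "<s>[INST] Hello! [/INST]"  # template output that already carries BOS
tokens = llm.tokenize(rendered.encode("utf-8"), add_bos=True, special=True)

# This is exactly the condition the new warning in llama.py checks for:
# prompt_tokens[:2] == [llm.token_bos()] * 2
print(tokens[:2], llm.token_bos())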

llama_cpp/_internals.py

Lines changed: 8 additions & 0 deletions

@@ -142,6 +142,14 @@ def token_eos(self) -> int:
         assert self.model is not None
         return llama_cpp.llama_token_eos(self.model)
 
+    def token_cls(self) -> int:
+        assert self.model is not None
+        return llama_cpp.llama_token_cls(self.model)
+
+    def token_sep(self) -> int:
+        assert self.model is not None
+        return llama_cpp.llama_token_sep(self.model)
+
     def token_nl(self) -> int:
         assert self.model is not None
         return llama_cpp.llama_token_nl(self.model)
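
A brief usage sketch for the two new accessors (hedged: the model path is a placeholder, and _model is an internal attribute rather than public API):

# Sketch: reading the CLS/SEP token ids exposed by the new wrappers, which
# simply forward to llama_cpp.llama_token_cls / llama_cpp.llama_token_sep.
from llama_cpp import Llama

llm = Llama(model_path="./models/example.gguf")  # hypothetical path
print("cls:", llm._model.token_cls(), "sep:", llm._model.token_sep())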

llama_cpp/llama.py

Lines changed: 7 additions & 0 deletions

@@ -8,6 +8,7 @@
 import ctypes
 import typing
 import fnmatch
+import warnings
 import multiprocessing
 
 from typing import (
@@ -1019,6 +1020,12 @@ def _create_completion(
         )
         model_name: str = model if model is not None else self.model_path
 
+        if prompt_tokens[:2] == [self.token_bos()] * 2:
+            warnings.warn(
+                f'Detected duplicate leading "{self._model.token_get_text(self.token_bos())}" in prompt, this will likely reduce response quality, consider removing it...',
+                RuntimeWarning,
+            )
+
         # NOTE: This likely doesn't work correctly for the first token in the prompt
         # because of the extra space added to the start of the prompt_tokens
         if logit_bias is not None:
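
Because the duplicate-BOS check uses the standard warnings module, callers can record or filter it like any other RuntimeWarning. A sketch under the same placeholder-model assumption; whether the warning actually fires depends on how the model tokenizes the leading "<s>" text:

import warnings

from llama_cpp import Llama

llm = Llama(model_path="./models/example.gguf")  # hypothetical path

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # If the tokenized prompt starts with two BOS tokens, _create_completion
    # now emits a RuntimeWarning instead of silently degrading output quality.
    llm.create_completion("<s>Hello", max_tokens=1)

for w in caught:
    if issubclass(w.category, RuntimeWarning):
        print("caught:", w.message)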

llama_cpp/llama_chat_format.py

Lines changed: 8 additions & 9 deletions

@@ -160,6 +160,7 @@ class ChatFormatterResponse:
     prompt: str
     stop: Optional[Union[str, List[str]]] = None
     stopping_criteria: Optional[llama.StoppingCriteriaList] = None
+    added_special: bool = False
 
 
 class ChatFormatter(Protocol):
@@ -232,7 +233,7 @@ def stop_on_last_token(
             return tokens[-1] in self.stop_token_ids
         stopping_criteria = llama.StoppingCriteriaList([stop_on_last_token])
 
-        return ChatFormatterResponse(prompt=prompt, stop=[self.eos_token], stopping_criteria=stopping_criteria)
+        return ChatFormatterResponse(prompt=prompt, stop=[self.eos_token], stopping_criteria=stopping_criteria, added_special=True)
 
     def to_chat_handler(self) -> LlamaChatCompletionHandler:
         return chat_formatter_to_chat_completion_handler(self)
@@ -548,7 +549,7 @@ def chat_completion_handler(
             tools=tools,
             tool_choice=tool_choice,
         )
-        prompt = result.prompt
+        prompt = llama.tokenize(result.prompt.encode("utf-8"), add_bos=not result.added_special, special=True)
         if result.stop is not None:
             stop = [] if stop is None else [stop] if isinstance(stop, str) else stop
             rstop = result.stop if isinstance(result.stop, list) else [result.stop]
@@ -655,7 +656,7 @@ def format_autotokenizer(
         prompt: str = tokenizer.apply_chat_template(messages, tokenize=False)  # type: ignore
         assert isinstance(prompt, str)
         # Return formatted prompt and eos token by default
-        return ChatFormatterResponse(prompt=prompt, stop=tokenizer.eos_token)
+        return ChatFormatterResponse(prompt=prompt, stop=tokenizer.eos_token, added_special=True)
 
     return format_autotokenizer
 
@@ -708,7 +709,7 @@ def format_tokenizer_config(
             bos_token=bos_token,
             eos_token=eos_token,
         )
-        return ChatFormatterResponse(prompt=prompt, stop=[eos_token, bos_token])
+        return ChatFormatterResponse(prompt=prompt, stop=[eos_token, bos_token], added_special=True)
 
     return format_tokenizer_config
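
The net effect of added_special plus the new tokenize call, shown as a standalone sketch (placeholder model path): when a formatter reports that it already emitted special tokens, the handler tokenizes with add_bos=False, and special=True makes the "<s>" in the rendered text parse as the BOS token rather than literal characters.

from llama_cpp import Llama

llm = Llama(model_path="./models/example.gguf")  # hypothetical path

rendered = "<s>[INST] Hello! [/INST]"  # formatter output that already includes BOS
added_special = True                   # what ChatFormatterResponse now reports

# Mirrors the handler's new behaviour: no second BOS is prepended.
tokens = llm.tokenize(rendered.encode("utf-8"), add_bos=not added_special, special=True)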

@@ -918,7 +919,7 @@ def format_llama2(
     messages: List[llama_types.ChatCompletionRequestMessage],
     **kwargs: Any,
 ) -> ChatFormatterResponse:
-    _system_template = "<s>[INST] <<SYS>>\n{system_message}\n<</SYS>>"
+    _system_template = "[INST] <<SYS>>\n{system_message}\n<</SYS>>"
     _roles = dict(user="<s>[INST]", assistant="[/INST]")
     _messages = _map_roles(messages, _roles)
     system_message = _get_system_message(messages)
@@ -940,11 +941,10 @@ def format_llama3(
         user="<|start_header_id|>user<|end_header_id|>\n\n",
         assistant="<|start_header_id|>assistant<|end_header_id|>\n\n",
     )
-    _begin_token = "<|begin_of_text|>"
     _sep = "<|eot_id|>"
     _messages = _map_roles(messages, _roles)
     _messages.append((_roles["assistant"], None))
-    _prompt = _format_no_colon_single(_begin_token, _messages, _sep)
+    _prompt = _format_no_colon_single("", _messages, _sep)
     return ChatFormatterResponse(prompt=_prompt, stop=_sep)
 
 
@@ -1229,10 +1229,9 @@ def format_mistral_instruct(
     messages: List[llama_types.ChatCompletionRequestMessage],
     **kwargs: Any,
 ) -> ChatFormatterResponse:
-    bos = "<s>"
     eos = "</s>"
     stop = eos
-    prompt = bos
+    prompt = ""
     for message in messages:
         if (
             message["role"] == "user"

tests/test_llama_chat_format.py

Lines changed: 2 additions & 1 deletion

@@ -21,12 +21,13 @@ def test_mistral_instruct():
     response = llama_chat_format.format_mistral_instruct(
         messages=messages,
     )
+    prompt = ("" if response.added_special else "<s>") + response.prompt
     reference = chat_formatter.render(
         messages=messages,
         bos_token="<s>",
         eos_token="</s>",
     )
-    assert response.prompt == reference
+    assert prompt == reference
 
 
 mistral_7b_tokenizer_config = """{

0 commit comments