
Avoid duplicate special tokens in chat formats #1439

Merged
13 commits merged on Jun 4, 2024
8 changes: 8 additions & 0 deletions llama_cpp/_internals.py
@@ -142,6 +142,14 @@ def token_eos(self) -> int:
         assert self.model is not None
         return llama_cpp.llama_token_eos(self.model)

+    def token_cls(self) -> int:
+        assert self.model is not None
+        return llama_cpp.llama_token_cls(self.model)
+
+    def token_sep(self) -> int:
+        assert self.model is not None
+        return llama_cpp.llama_token_sep(self.model)
+
     def token_nl(self) -> int:
         assert self.model is not None
         return llama_cpp.llama_token_nl(self.model)
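For context, `llama_token_cls` and `llama_token_sep` are existing low-level bindings that these new wrappers call into. A minimal usage sketch, assuming a loaded model (the path is a placeholder, and the printed ids depend on the vocabulary; many models define no CLS/SEP token, in which case llama.cpp reports -1):

```python
from llama_cpp import Llama

# Placeholder path; vocab_only avoids loading full weights just to read token ids.
llm = Llama(model_path="./model.gguf", vocab_only=True)

# The internal _LlamaModel wrapper now exposes CLS/SEP alongside BOS/EOS.
print(llm._model.token_bos())  # beginning-of-sequence id
print(llm._model.token_cls())  # CLS id (typically -1 when the vocab has none)
print(llm._model.token_sep())  # SEP id (typically -1 when the vocab has none)
```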
7 changes: 7 additions & 0 deletions llama_cpp/llama.py
@@ -8,6 +8,7 @@
 import ctypes
 import typing
 import fnmatch
+import warnings
 import multiprocessing

 from typing import (
@@ -1019,6 +1020,12 @@ def _create_completion(
         )
         model_name: str = model if model is not None else self.model_path

+        if prompt_tokens[:2] == [self.token_bos()] * 2:
+            warnings.warn(
+                f'Detected duplicate leading "{self._model.token_get_text(self.token_bos())}" in prompt, this will likely reduce response quality, consider removing it...',
+                RuntimeWarning,
+            )
+
         # NOTE: This likely doesn't work correctly for the first token in the prompt
         # because of the extra space added to the start of the prompt_tokens
         if logit_bias is not None:
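As a standalone illustration of the guard above (not library code; `BOS_ID` and the example token lists are hypothetical):

```python
import warnings

BOS_ID = 1  # hypothetical BOS token id

def warn_on_duplicate_bos(prompt_tokens: list) -> None:
    # Same shape as the check added in _create_completion: two BOS tokens in a
    # row at the start of the prompt usually mean the chat template already
    # contained a literal "<s>" and the tokenizer then added another one.
    if prompt_tokens[:2] == [BOS_ID] * 2:
        warnings.warn(
            "Detected duplicate leading BOS token in prompt; "
            "this will likely reduce response quality.",
            RuntimeWarning,
        )

warn_on_duplicate_bos([1, 1, 15043, 3186])  # emits a RuntimeWarning
warn_on_duplicate_bos([1, 15043, 3186])     # silent
```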
17 changes: 8 additions & 9 deletions llama_cpp/llama_chat_format.py
@@ -160,6 +160,7 @@ class ChatFormatterResponse:
     prompt: str
     stop: Optional[Union[str, List[str]]] = None
     stopping_criteria: Optional[llama.StoppingCriteriaList] = None
+    added_special: bool = False


 class ChatFormatter(Protocol):
@@ -232,7 +233,7 @@ def stop_on_last_token(
                 return tokens[-1] in self.stop_token_ids
             stopping_criteria = llama.StoppingCriteriaList([stop_on_last_token])

-        return ChatFormatterResponse(prompt=prompt, stop=[self.eos_token], stopping_criteria=stopping_criteria)
+        return ChatFormatterResponse(prompt=prompt, stop=[self.eos_token], stopping_criteria=stopping_criteria, added_special=True)

     def to_chat_handler(self) -> LlamaChatCompletionHandler:
         return chat_formatter_to_chat_completion_handler(self)
@@ -548,7 +549,7 @@ def chat_completion_handler(
             tools=tools,
             tool_choice=tool_choice,
         )
-        prompt = result.prompt
+        prompt = llama.tokenize(result.prompt.encode("utf-8"), add_bos=not result.added_special, special=True)
         if result.stop is not None:
             stop = [] if stop is None else [stop] if isinstance(stop, str) else stop
             rstop = result.stop if isinstance(result.stop, list) else [result.stop]
@@ -655,7 +656,7 @@ def format_autotokenizer(
         prompt: str = tokenizer.apply_chat_template(messages, tokenize=False)  # type: ignore
         assert isinstance(prompt, str)
         # Return formatted prompt and eos token by default
-        return ChatFormatterResponse(prompt=prompt, stop=tokenizer.eos_token)
+        return ChatFormatterResponse(prompt=prompt, stop=tokenizer.eos_token, added_special=True)

     return format_autotokenizer

@@ -708,7 +709,7 @@ def format_tokenizer_config(
             bos_token=bos_token,
             eos_token=eos_token,
         )
-        return ChatFormatterResponse(prompt=prompt, stop=[eos_token, bos_token])
+        return ChatFormatterResponse(prompt=prompt, stop=[eos_token, bos_token], added_special=True)

     return format_tokenizer_config

@@ -918,7 +919,7 @@ def format_llama2(
     messages: List[llama_types.ChatCompletionRequestMessage],
     **kwargs: Any,
 ) -> ChatFormatterResponse:
-    _system_template = "<s>[INST] <<SYS>>\n{system_message}\n<</SYS>>"
+    _system_template = "[INST] <<SYS>>\n{system_message}\n<</SYS>>"
     _roles = dict(user="<s>[INST]", assistant="[/INST]")
     _messages = _map_roles(messages, _roles)
     system_message = _get_system_message(messages)
@@ -940,11 +941,10 @@ def format_llama3(
         user="<|start_header_id|>user<|end_header_id|>\n\n",
         assistant="<|start_header_id|>assistant<|end_header_id|>\n\n",
     )
-    _begin_token = "<|begin_of_text|>"
     _sep = "<|eot_id|>"
     _messages = _map_roles(messages, _roles)
     _messages.append((_roles["assistant"], None))
-    _prompt = _format_no_colon_single(_begin_token, _messages, _sep)
+    _prompt = _format_no_colon_single("", _messages, _sep)
     return ChatFormatterResponse(prompt=_prompt, stop=_sep)


@@ -1229,10 +1229,9 @@ def format_mistral_instruct(
     messages: List[llama_types.ChatCompletionRequestMessage],
     **kwargs: Any,
 ) -> ChatFormatterResponse:
-    bos = "<s>"
     eos = "</s>"
     stop = eos
-    prompt = bos
+    prompt = ""
     for message in messages:
         if (
             message["role"] == "user"
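A rough sketch of how the new `added_special` flag is intended to work end to end (hypothetical helper, not part of the library API): a formatter that bakes `<s>` or `<|begin_of_text|>` into its template sets `added_special=True`, and the completion handler then tokenizes with `add_bos=False`, so the BOS token appears exactly once.

```python
def tokenize_formatted_prompt(llama, result):
    """Hypothetical helper mirroring the handler change above.

    `result` is a ChatFormatterResponse: `.prompt` is the rendered template text,
    `.added_special` says whether it already contains the special BOS text.
    """
    return llama.tokenize(
        result.prompt.encode("utf-8"),
        add_bos=not result.added_special,  # don't add a second BOS
        special=True,                      # parse special tokens in the template
    )
```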
3 changes: 2 additions & 1 deletion tests/test_llama_chat_format.py
@@ -21,12 +21,13 @@ def test_mistral_instruct():
     response = llama_chat_format.format_mistral_instruct(
         messages=messages,
     )
+    prompt = ("" if response.added_special else "<s>") + response.prompt
     reference = chat_formatter.render(
         messages=messages,
         bos_token="<s>",
         eos_token="</s>",
     )
-    assert response.prompt == reference
+    assert prompt == reference


 mistral_7b_tokenizer_config = """{
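For reference, a quick sketch of what the updated test exercises, assuming the post-PR behavior of `format_mistral_instruct` (the message content is arbitrary):

```python
from llama_cpp import llama_chat_format

messages = [{"role": "user", "content": "Hello"}]
response = llama_chat_format.format_mistral_instruct(messages=messages)

# The formatter no longer emits a literal "<s>"; the BOS is added once at
# tokenization time instead, so the test re-prepends it only when
# added_special is False before comparing against the HF reference render.
assert not response.prompt.startswith("<s>")
prompt = ("" if response.added_special else "<s>") + response.prompt
print(prompt)  # e.g. "<s>[INST] Hello [/INST]"
```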