
Commit 26a39bb

Add MiniCPM, Deepseek V2 chat template + clean up llama_chat_apply_template_internal (ggml-org#8172)
* tmp_contains
* minicpm chat template
* add DeepSeek Lite template
* change deepseek-lite to deepseek2
* correct code comment
* correct code from master branch
1 parent 38373cf commit 26a39bb

2 files changed (+56 / -18 lines)
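For context on how the two new templates are consumed, the sketch below formats a short conversation through the public llama_chat_apply_template API, the same entry point that tests/test-chat-template.cpp exercises. It is a minimal illustration and not part of this commit: the conversation contents and buffer size are made up, the MiniCPM template string is copied from the test added below, and the call assumes the llama.h signature current at the time of this commit (a null model plus an explicit template string).

#include <cstdint>
#include <cstdio>
#include <string>
#include <vector>

#include "llama.h"

int main() {
    // Illustrative conversation; roles and contents are made up for this sketch.
    std::vector<llama_chat_message> chat = {
        {"user",      "Hello"},
        {"assistant", "Hi there"},
        {"user",      "Who are you"},
    };

    // MiniCPM Jinja template copied from tests/test-chat-template.cpp in this commit;
    // llama_chat_apply_template_internal detects it via the "<用户>" marker.
    std::string tmpl =
        u8"{% for message in messages %}{% if message['role'] == 'user' %}"
        u8"{{'<用户>' + message['content'].strip() + '<AI>'}}"
        u8"{% else %}{{message['content'].strip()}}{% endif %}{% endfor %}";

    std::vector<char> buf(1024); // arbitrary initial size
    int32_t res = llama_chat_apply_template(nullptr, tmpl.c_str(), chat.data(), chat.size(),
                                            /*add_ass=*/true, buf.data(), (int32_t) buf.size());
    if (res > (int32_t) buf.size()) {
        // the return value is the length required; grow the buffer and format again
        buf.resize(res);
        res = llama_chat_apply_template(nullptr, tmpl.c_str(), chat.data(), chat.size(),
                                        /*add_ass=*/true, buf.data(), (int32_t) buf.size());
    }
    if (res < 0) {
        fprintf(stderr, "template not supported\n");
        return 1;
    }
    printf("%.*s\n", res, buf.data());
    return 0;
}

The expected strings added to tests/test-chat-template.cpp below show exactly what the two new branches produce for the shared test conversation.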

src/llama.cpp

Lines changed: 47 additions & 17 deletions
@@ -19613,24 +19613,27 @@ static int32_t llama_chat_apply_template_internal(
     std::string & dest, bool add_ass) {
     // Taken from the research: https://github.com/ggerganov/llama.cpp/issues/5527
     std::stringstream ss;
-    if (tmpl == "chatml" || tmpl.find("<|im_start|>") != std::string::npos) {
+    auto tmpl_contains = [&tmpl](std::string haystack) -> bool {
+        return tmpl.find(haystack) != std::string::npos;
+    };
+    if (tmpl == "chatml" || tmpl_contains("<|im_start|>")) {
         // chatml template
         for (auto message : chat) {
             ss << "<|im_start|>" << message->role << "\n" << message->content << "<|im_end|>\n";
         }
         if (add_ass) {
             ss << "<|im_start|>assistant\n";
         }
-    } else if (tmpl == "llama2" || tmpl == "mistral" || tmpl.find("[INST]") != std::string::npos) {
+    } else if (tmpl == "llama2" || tmpl == "mistral" || tmpl_contains("[INST]")) {
         // llama2 template and its variants
         // [variant] support system message
-        bool support_system_message = tmpl.find("<<SYS>>") != std::string::npos || tmpl == "mistral";
+        bool support_system_message = tmpl_contains("<<SYS>>") || tmpl == "mistral";
         // [variant] space before + after response
-        bool space_around_response = tmpl.find("' ' + eos_token") != std::string::npos;
+        bool space_around_response = tmpl_contains("' ' + eos_token");
         // [variant] add BOS inside history
-        bool add_bos_inside_history = tmpl.find("bos_token + '[INST]") != std::string::npos;
+        bool add_bos_inside_history = tmpl_contains("bos_token + '[INST]");
         // [variant] trim spaces from the input message
-        bool strip_message = tmpl.find("content.strip()") != std::string::npos;
+        bool strip_message = tmpl_contains("content.strip()");
         // construct the prompt
         bool is_inside_turn = true; // skip BOS at the beginning
         ss << "[INST] ";
@@ -19656,7 +19659,7 @@ static int32_t llama_chat_apply_template_internal(
             }
         }
         // llama2 templates seem to not care about "add_generation_prompt"
-    } else if (tmpl == "phi3" || (tmpl.find("<|assistant|>") != std::string::npos && tmpl.find("<|end|>") != std::string::npos)) {
+    } else if (tmpl == "phi3" || (tmpl_contains("<|assistant|>") && tmpl_contains("<|end|>"))) {
         // Phi 3
         for (auto message : chat) {
             std::string role(message->role);
@@ -19665,15 +19668,15 @@ static int32_t llama_chat_apply_template_internal(
         if (add_ass) {
             ss << "<|assistant|>\n";
         }
-    } else if (tmpl == "zephyr" || tmpl.find("<|user|>") != std::string::npos) {
+    } else if (tmpl == "zephyr" || tmpl_contains("<|user|>")) {
         // zephyr template
         for (auto message : chat) {
             ss << "<|" << message->role << "|>" << "\n" << message->content << "<|endoftext|>\n";
         }
         if (add_ass) {
             ss << "<|assistant|>\n";
         }
-    } else if (tmpl == "monarch" || tmpl.find("bos_token + message['role']") != std::string::npos) {
+    } else if (tmpl == "monarch" || tmpl_contains("bos_token + message['role']")) {
         // mlabonne/AlphaMonarch-7B template (the <s> is included inside history)
         for (auto message : chat) {
             std::string bos = (message == chat.front()) ? "" : "<s>"; // skip BOS for first message
@@ -19682,7 +19685,7 @@ static int32_t llama_chat_apply_template_internal(
         if (add_ass) {
             ss << "<s>assistant\n";
         }
-    } else if (tmpl == "gemma" || tmpl == "gemma2" || tmpl.find("<start_of_turn>") != std::string::npos) {
+    } else if (tmpl == "gemma" || tmpl == "gemma2" || tmpl_contains("<start_of_turn>")) {
         // google/gemma-7b-it
         std::string system_prompt = "";
         for (auto message : chat) {
@@ -19704,7 +19707,7 @@ static int32_t llama_chat_apply_template_internal(
         if (add_ass) {
             ss << "<start_of_turn>model\n";
         }
-    } else if (tmpl == "orion" || tmpl.find("'\\n\\nAssistant: ' + eos_token") != std::string::npos) {
+    } else if (tmpl == "orion" || tmpl_contains("'\\n\\nAssistant: ' + eos_token")) {
         // OrionStarAI/Orion-14B-Chat
         std::string system_prompt = "";
         for (auto message : chat) {
@@ -19724,7 +19727,7 @@ static int32_t llama_chat_apply_template_internal(
                 ss << message->content << "</s>";
             }
         }
-    } else if (tmpl == "openchat" || tmpl.find("GPT4 Correct ") != std::string::npos) {
+    } else if (tmpl == "openchat" || tmpl_contains("GPT4 Correct ")) {
         // openchat/openchat-3.5-0106,
         for (auto message : chat) {
             std::string role(message->role);
@@ -19738,13 +19741,13 @@ static int32_t llama_chat_apply_template_internal(
         if (add_ass) {
             ss << "GPT4 Correct Assistant:";
         }
-    } else if (tmpl == "vicuna" || tmpl == "vicuna-orca" || (tmpl.find("USER: ") != std::string::npos && tmpl.find("ASSISTANT: ") != std::string::npos)) {
+    } else if (tmpl == "vicuna" || tmpl == "vicuna-orca" || (tmpl_contains("USER: ") && tmpl_contains("ASSISTANT: "))) {
         // eachadea/vicuna-13b-1.1 (and Orca variant)
         for (auto message : chat) {
             std::string role(message->role);
             if (role == "system") {
                 // Orca-Vicuna variant uses a system prefix
-                if (tmpl == "vicuna-orca" || tmpl.find("SYSTEM: ") != std::string::npos) {
+                if (tmpl == "vicuna-orca" || tmpl_contains("SYSTEM: ")) {
                     ss << "SYSTEM: " << message->content << "\n";
                 } else {
                     ss << message->content << "\n\n";
@@ -19758,7 +19761,7 @@ static int32_t llama_chat_apply_template_internal(
         if (add_ass) {
             ss << "ASSISTANT:";
         }
-    } else if (tmpl == "deepseek" || (tmpl.find("### Instruction:") != std::string::npos && tmpl.find("<|EOT|>") != std::string::npos)) {
+    } else if (tmpl == "deepseek" || (tmpl_contains("### Instruction:") && tmpl_contains("<|EOT|>"))) {
         // deepseek-ai/deepseek-coder-33b-instruct
         for (auto message : chat) {
             std::string role(message->role);
@@ -19773,7 +19776,7 @@ static int32_t llama_chat_apply_template_internal(
         if (add_ass) {
             ss << "### Response:\n";
         }
-    } else if (tmpl == "command-r" || (tmpl.find("<|START_OF_TURN_TOKEN|>") != std::string::npos && tmpl.find("<|USER_TOKEN|>") != std::string::npos)) {
+    } else if (tmpl == "command-r" || (tmpl_contains("<|START_OF_TURN_TOKEN|>") && tmpl_contains("<|USER_TOKEN|>"))) {
         // CohereForAI/c4ai-command-r-plus
         for (auto message : chat) {
             std::string role(message->role);
@@ -19788,7 +19791,7 @@ static int32_t llama_chat_apply_template_internal(
         if (add_ass) {
             ss << "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>";
         }
-    } else if (tmpl == "llama3" || (tmpl.find("<|start_header_id|>") != std::string::npos && tmpl.find("<|end_header_id|>") != std::string::npos)) {
+    } else if (tmpl == "llama3" || (tmpl_contains("<|start_header_id|>") && tmpl_contains("<|end_header_id|>"))) {
         // Llama 3
         for (auto message : chat) {
             std::string role(message->role);
@@ -19797,6 +19800,33 @@ static int32_t llama_chat_apply_template_internal(
         if (add_ass) {
             ss << "<|start_header_id|>assistant<|end_header_id|>\n\n";
         }
+    } else if (tmpl == "minicpm" || tmpl_contains(u8"<用户>")) {
+        // MiniCPM-3B-OpenHermes-2.5-v2-GGUF
+        for (auto message : chat) {
+            std::string role(message->role);
+            if (role == "user") {
+                ss << u8"<用户>";
+                ss << trim(message->content);
+                ss << "<AI>";
+            } else {
+                ss << trim(message->content);
+            }
+        }
+    } else if (tmpl == "deepseek2" || tmpl_contains("'Assistant: ' + message['content'] + eos_token")) {
+        // DeepSeek-V2
+        for (auto message : chat) {
+            std::string role(message->role);
+            if (role == "system") {
+                ss << message->content << "\n\n";
+            } else if (role == "user") {
+                ss << "User: " << message->content << "\n\n";
+            } else if (role == "assistant") {
+                ss << "Assistant: " << message->content << u8"<|end▁of▁sentence|>";
+            }
+        }
+        if (add_ass) {
+            ss << "Assistant:";
+        }
     } else {
         // template not supported
         return -1;
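One detail worth noting: the new MiniCPM branch calls a trim() helper that is defined earlier in llama.cpp and does not appear in this diff. The sketch below only illustrates the behaviour the branch relies on, stripping leading and trailing whitespace from the message content; the exact implementation in llama.cpp may differ.

#include <cctype>
#include <string>

// Illustration of the whitespace-trimming behaviour assumed by the MiniCPM branch above.
// llama.cpp ships its own trim() helper; this standalone version is only a sketch.
static std::string trim(const std::string & str) {
    size_t start = 0;
    size_t end   = str.size();
    while (start < end && std::isspace((unsigned char) str[start])) {
        start++;
    }
    while (end > start && std::isspace((unsigned char) str[end - 1])) {
        end--;
    }
    return str.substr(start, end - start);
}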

tests/test-chat-template.cpp

Lines changed: 9 additions & 1 deletion
@@ -57,7 +57,11 @@ int main(void) {
         //Phi-3-medium
         "{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}",
         //Phi-3-vision
-        "{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{- '<|assistant|>\n' -}}{% endif %}"
+        "{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{- '<|assistant|>\n' -}}{% endif %}",
+        // MiniCPM-3B-OpenHermes-2.5-v2-GGUF
+        u8"{% for message in messages %}{% if message['role'] == 'user' %}{{'<用户>' + message['content'].strip() + '<AI>'}}{% else %}{{message['content'].strip()}}{% endif %}{% endfor %}",
+        // DeepSeek-V2
+        "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ 'User: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}",
     };
     std::vector<std::string> expected_output = {
         // teknium/OpenHermes-2.5-Mistral-7B
@@ -94,6 +98,10 @@ int main(void) {
         "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n",
         //Phi-3-vision
         "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n",
+        // MiniCPM-3B-OpenHermes-2.5-v2-GGUF
+        u8"You are a helpful assistant<用户>Hello<AI>Hi there<用户>Who are you<AI>I am an assistant<用户>Another question<AI>",
+        // DeepSeek-V2
+        u8"You are a helpful assistant\n\nUser: Hello\n\nAssistant: Hi there<|end▁of▁sentence|>User: Who are you\n\nAssistant: I am an assistant <|end▁of▁sentence|>User: Another question\n\nAssistant:",
     };
     std::vector<char> formatted_chat(1024);
     int32_t res;
