
Commit c817793

feat: improve function calling
1 parent a420f96 commit c817793

2 files changed: +728 −171 lines changed

llama_cpp/llama_chat_format.py

Lines changed: 215 additions & 72 deletions
@@ -2286,9 +2286,9 @@ def __call__(
             stream=stream,
         )
 
-
-@register_chat_completion_handler("chatml-function-calling")
-def chatml_function_calling(
+def base_function_calling(
+    function_calling_template,
+    end_token,
     llama: llama.Llama,
     messages: List[llama_types.ChatCompletionRequestMessage],
     functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
@@ -2320,65 +2320,13 @@ def chatml_function_calling
 ) -> Union[
     llama_types.CreateChatCompletionResponse,
     Iterator[llama_types.CreateChatCompletionStreamResponse],
-]:
-    print(logprobs)
-    function_calling_template = (
-        "{% for message in messages %}"
-        "<|im_start|>{{ message.role }}\n"
-        # System message
-        "{% if message.role == 'system' %}"
-        "{{ message.content }}"
-        "{% if tool_calls %}"
-        "\n\nYou have access to the following functions:\n"
-        "{% for tool in tools %}"
-        "\nfunctions.{{ tool.function.name }}:\n"
-        "{{ tool.function.parameters | tojson }}"
-        "\n{% endfor %}"
-        "\n\nYou can respond to users messages with either a single message or one or more function calls."
-        "\n\nTo respond with a message begin the message with 'message:', use the following format:"
-        "\n\nmessage:"
-        "\n<message>"
-        "\n\nTo respond with one or more function calls begin the message with 'functions.<function_name>:', use the following format:"
-        "\n\nfunctions.<function_name>:"
-        '\n{ "arg1": "value1", "arg2": "value2" }'
-        "\nfunctions.<function_name>:"
-        '\n{ "arg1": "value1", "arg2": "value2" }'
-        "{% endif %}"
-        "<|im_end|>\n"
-        "{% endif %}"
-        # User message
-        "{% if message.role == 'user' %}"
-        "{{ message.content }}"
-        "<|im_end|>\n"
-        "{% endif %}"
-        # Assistant message
-        "{% if message.role == 'assistant' %}"
-        ## Reglar message
-        "{% if message.content and message.content | length > 0 %}"
-        "{% if tool_calls %}"
-        "message:\n"
-        "{% endif %}"
-        "{{ message.content }}"
-        "<|im_end|>\n"
-        "{% endif %}"
-        ## Function calls
-        "{% if 'tool_calls' in message %}"
-        "{% for tool_call in message.tool_calls %}"
-        "functions.{{ tool_call.function.name }}:\n"
-        "{{ tool_call.function.arguments }}"
-        "{% endfor %}"
-        "<|im_end|>\n"
-        "{% endif %}"
-        "{% endif %}"
-        "{% endfor %}"
-        "{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
-    )
+]:
+
     template_renderer = jinja2.Environment(
         loader=jinja2.BaseLoader(),
         autoescape=jinja2.select_autoescape(["html", "xml"]),
-        undefined=jinja2.StrictUndefined,
+        undefined=jinja2.DebugUndefined,
     ).from_string(function_calling_template)
-
     # Convert legacy functions to tools
     if functions is not None:
         tools = [
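
A side note on the undefined-handling switch in the hunk above: this reflects Jinja2's own documented behaviour, not code in this diff. StrictUndefined raises as soon as a template touches a variable that was not supplied, while DebugUndefined leaves the unresolved placeholder in the rendered text, which is what lets one shared renderer accept per-format templates. A minimal sketch:

import jinja2

debug_env = jinja2.Environment(undefined=jinja2.DebugUndefined)
strict_env = jinja2.Environment(undefined=jinja2.StrictUndefined)

print(debug_env.from_string("{{ missing }}").render())  # prints "{{ missing }}" unchanged
strict_env.from_string("{{ missing }}").render()         # raises jinja2.exceptions.UndefinedError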
@@ -2403,8 +2351,7 @@ def chatml_function_calling(
                 },
             }
 
-    stop = [stop, "<|im_end|>"] if isinstance(stop, str) else stop + ["<|im_end|>"] if stop else ["<|im_end|>"]
-
+    stop = [stop, end_token] if isinstance(stop, str) else stop + [end_token] if stop else [end_token]
     # Case 1: No tool choice by user
     if (
         tool_choice is None
@@ -2418,7 +2365,6 @@ def chatml_function_calling(
             tool_calls=None,
             add_generation_prompt=True,
         )
-
         if response_format is not None and response_format["type"] == "json_object":
             grammar = _grammar_for_response_format(response_format)
 
@@ -2506,14 +2452,18 @@ def chatml_function_calling(
     function_names = " | ".join(
         [f'''"functions.{tool['function']['name']}:"''' for tool in tools]
     )
+
     initial_gbnf_tool_grammar = (
         """root ::= functions | "message:"\n"""
         f"""functions ::= {function_names}\n"""
     )
+
     follow_up_gbnf_tool_grammar = (
-        """root ::= functions | "<|im_end|>"\n"""
+        f"""root ::= functions | "</done>"\n"""
         f"""functions ::= {function_names}\n"""
     )
+
+
     prompt = template_renderer.render(
         messages=messages,
         tools=tools,
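
For illustration only (the tool name is hypothetical, not from this commit): with a single registered tool named get_weather, the two grammars assembled above would render to roughly the strings below. The initial grammar forces the first completion to start with either "message:" or a known function name; the follow-up grammar lets later completions either name another function or emit the "</done>" terminator.

initial_gbnf_tool_grammar = (
    'root ::= functions | "message:"\n'
    'functions ::= "functions.get_weather:"\n'
)
follow_up_gbnf_tool_grammar = (
    'root ::= functions | "</done>"\n'
    'functions ::= "functions.get_weather:"\n'
)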
@@ -2522,14 +2472,14 @@ def chatml_function_calling(
     )
     completion_or_chunks = llama.create_completion(
         prompt=prompt,
-        temperature=0,
+        temperature=temperature,
         top_p=top_p,
         top_k=top_k,
         min_p=min_p,
         typical_p=typical_p,
         stream=False,
-        stop=[":"],
-        max_tokens=None,
+        stop=stop,
+        max_tokens=max_tokens,
         presence_penalty=presence_penalty,
         frequency_penalty=frequency_penalty,
         repeat_penalty=repeat_penalty,
@@ -2555,7 +2505,7 @@ def chatml_function_calling(
                 min_p=min_p,
                 typical_p=typical_p,
                 stream=stream,
-                stop=["<|im_end|>"],
+                stop=["</s>"],
                 logprobs=top_logprobs if logprobs else None,
                 max_tokens=None,
                 presence_penalty=presence_penalty,
@@ -2567,15 +2517,12 @@ def chatml_function_calling(
                 mirostat_eta=mirostat_eta,
                 model=model,
                 logits_processor=logits_processor,
-                grammar=llama_grammar.LlamaGrammar.from_string(
-                    follow_up_gbnf_tool_grammar, verbose=llama.verbose
-                ),
             ),
             stream=stream,
         )
 
     # One or more function calls
-    tool_name = text[len("functions.") :]
+    tool_name = text[len("functions.") :].replace(":", "")
     tool = next((tool for tool in tools if tool["function"]["name"] == tool_name), None)
     if not stream:
         completions: List[llama_types.CreateCompletionResponse] = []
@@ -2621,7 +2568,6 @@ def chatml_function_calling(
             completions_tool_name.append(tool_name)
             prompt += completion_or_chunks["choices"][0]["text"]
             prompt += "\n"
-
             response = llama.create_completion(
                 prompt=prompt,
                 temperature=temperature,
@@ -2644,10 +2590,14 @@ def chatml_function_calling(
                 grammar=llama_grammar.LlamaGrammar.from_string(
                     follow_up_gbnf_tool_grammar, verbose=llama.verbose
                 ),
+
             )
+
             response = cast(llama_types.CreateCompletionResponse, response)
 
-            tool_name = response["choices"][0]["text"][len("functions.") :]
+            if response["choices"][0]["text"] == "</done>":
+                break
+            tool_name = response["choices"][0]["text"][len("functions.") :].replace(":", "")
             tool = next(
                 (tool for tool in tools if tool["function"]["name"] == tool_name), None
             )
@@ -2710,3 +2660,196 @@ def chatml_function_calling(
         }
 
     raise ValueError("Automatic streaming tool choice is not supported")
+
+
+@register_chat_completion_handler("chatml-function-calling")
+def chatml_function_calling(
+    llama: llama.Llama,
+    messages: List[llama_types.ChatCompletionRequestMessage],
+    functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
+    function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None,
+    tools: Optional[List[llama_types.ChatCompletionTool]] = None,
+    tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None,
+    temperature: float = 0.2,
+    top_p: float = 0.95,
+    top_k: int = 40,
+    min_p: float = 0.05,
+    typical_p: float = 1.0,
+    stream: bool = False,
+    stop: Optional[Union[str, List[str]]] = [],
+    response_format: Optional[llama_types.ChatCompletionRequestResponseFormat] = None,
+    max_tokens: Optional[int] = None,
+    presence_penalty: float = 0.0,
+    frequency_penalty: float = 0.0,
+    repeat_penalty: float = 1.1,
+    tfs_z: float = 1.0,
+    mirostat_mode: int = 0,
+    mirostat_tau: float = 5.0,
+    mirostat_eta: float = 0.1,
+    model: Optional[str] = None,
+    logits_processor: Optional[llama.LogitsProcessorList] = None,
+    grammar: Optional[llama.LlamaGrammar] = None,
+    logprobs: Optional[bool] = None,
+    top_logprobs: Optional[int] = None,
+    **kwargs,  # type: ignore
+) -> Union[
+    llama_types.CreateChatCompletionResponse,
+    Iterator[llama_types.CreateChatCompletionStreamResponse],
+]:
+    function_calling_template = (
+        "{% for message in messages %}"
+        "<|im_start|>{{ message.role }}\n"
+        # System message
+        "{% if message.role == 'system' %}"
+        "{{ message.content }}"
+        "{% if tool_calls %}"
+        "\n\nYou have access to the following functions:\n"
+        "{% for tool in tools %}"
+        "\nfunctions.{{ tool.function.name }}:\n"
+        "{{ tool.function.parameters | tojson }}"
+        "\n{% endfor %}"
+        "\n\nYou can respond to users messages with either a single message or one or more function calls."
+        "\n\nTo respond with a message begin the message with 'message:', use the following format:"
+        "\n\nmessage:"
+        "\n<message>"
+        "\n\nTo respond with one or more function calls begin the message with 'functions.<function_name>:', use the following format:"
+        "\n\nfunctions.<function_name>:"
+        '\n{ "arg1": "value1", "arg2": "value2" };'
+        "\nfunctions.<function_name>:"
+        '\n{ "arg1": "value1", "arg2": "value2" }'
+        "\n\nWhen you are done with the function calls, end the message with </done>."
+        "{% endif %}"
+        "<|im_end|>\n"
+        "{% endif %}"
+        # User message
+        "{% if message.role == 'user' %}"
+        "{{ message.content }}"
+        "<|im_end|>\n"
+        "{% endif %}"
+        # Assistant message
+        "{% if message.role == 'assistant' %}"
+        ## Regular message
+        "{% if message.content and message.content | length > 0 %}"
+        "{% if tool_calls %}"
+        "message:\n"
+        "{% endif %}"
+        "{{ message.content }}"
+        "<|im_end|>\n"
+        "{% endif %}"
+        ## Function calls
+        "{% if tool_calls %}"
+        "{% for tool_call in message.tool_calls %}"
+        "functions.{{ tool_call.function.name }}:\n"
+        "{{ (tool_call.arguments | default('{}') | tojson) }}"
+        "{% if not loop.last %};{% endif %}"  # Semicolon separator if not the last function call
+        "{% endfor %}"
+        "<|im_end|>\n"
+        "{% endif %}"
+        "{% endif %}"
+        # Tool message (treated as Assistant response)
+        "{% if message.role == 'tool' %}"
+        "ASSISTANT:\n"
+        "Function response: {{ message.content | default('No response available') }}"
+        "<|im_end|>\n"
+        "{% endif %}"
+        "{% endfor %}"
+        "{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
+    )
+    return base_function_calling(end_token="<|im_end|>",
+                                 **locals())
+
+@register_chat_completion_handler("vicuna-function-calling")
+def vicuna_function_calling(
+    llama: llama.Llama,
+    messages: List[llama_types.ChatCompletionRequestMessage],
+    functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
+    function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None,
+    tools: Optional[List[llama_types.ChatCompletionTool]] = None,
+    tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None,
+    temperature: float = 0.2,
+    top_p: float = 0.95,
+    top_k: int = 40,
+    min_p: float = 0.05,
+    typical_p: float = 1.0,
+    stream: bool = False,
+    stop: Optional[Union[str, List[str]]] = [],
+    response_format: Optional[llama_types.ChatCompletionRequestResponseFormat] = None,
+    max_tokens: Optional[int] = None,
+    presence_penalty: float = 0.0,
+    frequency_penalty: float = 0.0,
+    repeat_penalty: float = 1.1,
+    tfs_z: float = 1.0,
+    mirostat_mode: int = 0,
+    mirostat_tau: float = 5.0,
+    mirostat_eta: float = 0.1,
+    model: Optional[str] = None,
+    logits_processor: Optional[llama.LogitsProcessorList] = None,
+    grammar: Optional[llama.LlamaGrammar] = None,
+    logprobs: Optional[bool] = None,
+    top_logprobs: Optional[int] = None,
+    **kwargs,  # type: ignore
+) -> Union[
+    llama_types.CreateChatCompletionResponse,
+    Iterator[llama_types.CreateChatCompletionStreamResponse],
+]:
+    function_calling_template = (
+        "{% for message in messages %}"
+        "{{ message.role.upper() }}\n"  # Vicuna uses upper case for roles
+        # System message
+        "{% if message.role == 'system' %}"
+        "{{ message.content }}"
+        "{% if tool_calls %}"
+        "\n\nYou have access to the following functions:\n"
+        "{% for tool in tools %}"
+        "\nfunctions.{{ tool.function.name }}:\n"
+        "{{ tool.function.parameters | tojson }}"
+        "\n{% endfor %}"
+        "\n\nYou can respond to users messages with either a single message or multiple function calls."
+        "\n\nTo respond with a message begin the message with 'message:', use the following format:"
+        "\n\nmessage:"
+        "\n<message>"
+        "\n\nTo respond with one or more function calls begin the message with 'functions.<function_name>:', use the following format:"
+        "\n\nfunctions.<function_name>:"
+        '\n{ "arg1": "value1", "arg2": "value2" };'
+        "\nfunctions.<another_function_name>:"
+        '\n{ "arg1": "value3", "arg2": "value4" }'
+        "\n\nWhen you are done with the function calls, end the message with </done>."
+        "{% endif %}"
+        "</s>\n"
+        "{% endif %}"
+        # User message
+        "{% if message.role == 'user' %}"
+        "{{ message.content }}"
+        "</s>\n"
+        "{% endif %}"
+        # Assistant message
+        "{% if message.role == 'assistant' %}"
+        ## Regular message
+        "{% if message.content and message.content | length > 0 %}"
+        "{% if tool_calls %}"
+        "message:\n"
+        "{% endif %}"
+        "{{ message.content }}"
+        "</s>\n"
+        "{% endif %}"
+        ## Function calls
+        "{% if tool_calls %}"
+        "{% for tool_call in message.tool_calls %}"
+        "functions.{{ tool_call.function.name }}:\n"
+        "{{ (tool_call.arguments | default('{}') | tojson) }}"
+        "{% if not loop.last %};{% endif %}"  # Semicolon separator if not the last function call
+        "{% endfor %}"
+        "</s>\n"
+        "{% endif %}"
+        "{% endif %}"
+        # Tool message (treated as Assistant response)
+        "{% if message.role == 'tool' %}"
+        "ASSISTANT:\n"
+        "Function response: {{ message.content | default('No response available') }}"
+        "</s>\n"
+        "{% endif %}"
+        "{% endfor %}"
+        "{% if add_generation_prompt %}</s>ASSISTANT\n{% endif %}"  # Vicuna adds the role for prompt continuation
+    )
+    return base_function_calling(end_token="</s>",
+                                 **locals())
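
Usage note (illustration, not part of the diff): the new handlers are still selected through the chat_format name at construction time, like other registered handlers in llama-cpp-python. A minimal sketch, assuming a placeholder GGUF model path and a hypothetical get_weather tool:

from llama_cpp import Llama

llm = Llama(
    model_path="./vicuna-7b.Q4_K_M.gguf",   # placeholder path, not from this commit
    chat_format="vicuna-function-calling",  # handler registered in this commit
)

response = llm.create_chat_completion(
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is the weather in Berlin?"},
    ],
    tools=[
        {
            "type": "function",
            "function": {
                "name": "get_weather",  # hypothetical tool, for illustration only
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                },
            },
        }
    ],
    tool_choice="auto",
)
print(response["choices"][0]["message"])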
