@@ -2286,9 +2286,9 @@ def __call__(
        stream=stream,
    )

-
-@register_chat_completion_handler("chatml-function-calling")
-def chatml_function_calling(
+def base_function_calling(
+    function_calling_template,
+    end_token,
    llama: llama.Llama,
    messages: List[llama_types.ChatCompletionRequestMessage],
    functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
@@ -2320,65 +2320,13 @@ def chatml_function_calling(
) -> Union[
    llama_types.CreateChatCompletionResponse,
    Iterator[llama_types.CreateChatCompletionStreamResponse],
-]:
-    print(logprobs)
-    function_calling_template = (
-        "{% for message in messages %}"
-        "<|im_start|>{{ message.role }}\n"
-        # System message
-        "{% if message.role == 'system' %}"
-        "{{ message.content }}"
-        "{% if tool_calls %}"
-        "\n\nYou have access to the following functions:\n"
-        "{% for tool in tools %}"
-        "\nfunctions.{{ tool.function.name }}:\n"
-        "{{ tool.function.parameters | tojson }}"
-        "\n{% endfor %}"
-        "\n\nYou can respond to users messages with either a single message or one or more function calls."
-        "\n\nTo respond with a message begin the message with 'message:', use the following format:"
-        "\n\nmessage:"
-        "\n<message>"
-        "\n\nTo respond with one or more function calls begin the message with 'functions.<function_name>:', use the following format:"
-        "\n\nfunctions.<function_name>:"
-        '\n{ "arg1": "value1", "arg2": "value2" }'
-        "\nfunctions.<function_name>:"
-        '\n{ "arg1": "value1", "arg2": "value2" }'
-        "{% endif %}"
-        "<|im_end|>\n"
-        "{% endif %}"
-        # User message
-        "{% if message.role == 'user' %}"
-        "{{ message.content }}"
-        "<|im_end|>\n"
-        "{% endif %}"
-        # Assistant message
-        "{% if message.role == 'assistant' %}"
-        ## Reglar message
-        "{% if message.content and message.content | length > 0 %}"
-        "{% if tool_calls %}"
-        "message:\n"
-        "{% endif %}"
-        "{{ message.content }}"
-        "<|im_end|>\n"
-        "{% endif %}"
-        ## Function calls
-        "{% if 'tool_calls' in message %}"
-        "{% for tool_call in message.tool_calls %}"
-        "functions.{{ tool_call.function.name }}:\n"
-        "{{ tool_call.function.arguments }}"
-        "{% endfor %}"
-        "<|im_end|>\n"
-        "{% endif %}"
-        "{% endif %}"
-        "{% endfor %}"
-        "{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
-    )
+]:
+
    template_renderer = jinja2.Environment(
        loader=jinja2.BaseLoader(),
        autoescape=jinja2.select_autoescape(["html", "xml"]),
-        undefined=jinja2.StrictUndefined,
+        undefined=jinja2.DebugUndefined,
    ).from_string(function_calling_template)
-
    # Convert legacy functions to tools
    if functions is not None:
        tools = [
@@ -2403,8 +2351,7 @@ def chatml_function_calling(
                },
            }

-    stop = [stop, "<|im_end|>"] if isinstance(stop, str) else stop + ["<|im_end|>"] if stop else ["<|im_end|>"]
-
+    stop = [stop, end_token] if isinstance(stop, str) else stop + [end_token] if stop else [end_token]
    # Case 1: No tool choice by user
    if (
        tool_choice is None
@@ -2418,7 +2365,6 @@ def chatml_function_calling(
            tool_calls=None,
            add_generation_prompt=True,
        )
-
        if response_format is not None and response_format["type"] == "json_object":
            grammar = _grammar_for_response_format(response_format)
@@ -2506,14 +2452,18 @@ def chatml_function_calling(
    function_names = " | ".join(
        [f'''"functions.{tool['function']['name']}:"''' for tool in tools]
    )
+
    initial_gbnf_tool_grammar = (
        """root ::= functions | "message:"\n"""
        f"""functions ::= {function_names}\n"""
    )
+
    follow_up_gbnf_tool_grammar = (
-        """root ::= functions | "<|im_end|>"\n"""
+        f"""root ::= functions | "</done>"\n"""
        f"""functions ::= {function_names}\n"""
    )
+
+
    prompt = template_renderer.render(
        messages=messages,
        tools=tools,
@@ -2522,14 +2472,14 @@ def chatml_function_calling(
    )
    completion_or_chunks = llama.create_completion(
        prompt=prompt,
-        temperature=0,
+        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        min_p=min_p,
        typical_p=typical_p,
        stream=False,
-        stop=[":"],
-        max_tokens=None,
+        stop=stop,
+        max_tokens=max_tokens,
        presence_penalty=presence_penalty,
        frequency_penalty=frequency_penalty,
        repeat_penalty=repeat_penalty,
@@ -2555,7 +2505,7 @@ def chatml_function_calling(
                min_p=min_p,
                typical_p=typical_p,
                stream=stream,
-                stop=["<|im_end|>"],
+                stop=["</s>"],
                logprobs=top_logprobs if logprobs else None,
                max_tokens=None,
                presence_penalty=presence_penalty,
@@ -2567,15 +2517,12 @@ def chatml_function_calling(
                mirostat_eta=mirostat_eta,
                model=model,
                logits_processor=logits_processor,
-                grammar=llama_grammar.LlamaGrammar.from_string(
-                    follow_up_gbnf_tool_grammar, verbose=llama.verbose
-                ),
            ),
            stream=stream,
        )

    # One or more function calls
-    tool_name = text[len("functions.") :]
+    tool_name = text[len("functions.") :].replace(":", "")
    tool = next((tool for tool in tools if tool["function"]["name"] == tool_name), None)
    if not stream:
        completions: List[llama_types.CreateCompletionResponse] = []
@@ -2621,7 +2568,6 @@ def chatml_function_calling(
            completions_tool_name.append(tool_name)
            prompt += completion_or_chunks["choices"][0]["text"]
            prompt += "\n"
-
            response = llama.create_completion(
                prompt=prompt,
                temperature=temperature,
@@ -2644,10 +2590,14 @@ def chatml_function_calling(
                grammar=llama_grammar.LlamaGrammar.from_string(
                    follow_up_gbnf_tool_grammar, verbose=llama.verbose
                ),
+
            )
+
            response = cast(llama_types.CreateCompletionResponse, response)

-            tool_name = response["choices"][0]["text"][len("functions.") :]
+            if response["choices"][0]["text"] == "</done>":
+                break
+            tool_name = response["choices"][0]["text"][len("functions.") :].replace(":", "")
            tool = next(
                (tool for tool in tools if tool["function"]["name"] == tool_name), None
            )
@@ -2710,3 +2660,196 @@ def chatml_function_calling(
        }

    raise ValueError("Automatic streaming tool choice is not supported")
+
+
+@register_chat_completion_handler("chatml-function-calling")
+def chatml_function_calling(
+    llama: llama.Llama,
+    messages: List[llama_types.ChatCompletionRequestMessage],
+    functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
+    function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None,
+    tools: Optional[List[llama_types.ChatCompletionTool]] = None,
+    tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None,
+    temperature: float = 0.2,
+    top_p: float = 0.95,
+    top_k: int = 40,
+    min_p: float = 0.05,
+    typical_p: float = 1.0,
+    stream: bool = False,
+    stop: Optional[Union[str, List[str]]] = [],
+    response_format: Optional[llama_types.ChatCompletionRequestResponseFormat] = None,
+    max_tokens: Optional[int] = None,
+    presence_penalty: float = 0.0,
+    frequency_penalty: float = 0.0,
+    repeat_penalty: float = 1.1,
+    tfs_z: float = 1.0,
+    mirostat_mode: int = 0,
+    mirostat_tau: float = 5.0,
+    mirostat_eta: float = 0.1,
+    model: Optional[str] = None,
+    logits_processor: Optional[llama.LogitsProcessorList] = None,
+    grammar: Optional[llama.LlamaGrammar] = None,
+    logprobs: Optional[bool] = None,
+    top_logprobs: Optional[int] = None,
+    **kwargs,  # type: ignore
+) -> Union[
+    llama_types.CreateChatCompletionResponse,
+    Iterator[llama_types.CreateChatCompletionStreamResponse],
+]:
+    function_calling_template = (
+        "{% for message in messages %}"
+        "<|im_start|>{{ message.role }}\n"
+        # System message
+        "{% if message.role == 'system' %}"
+        "{{ message.content }}"
+        "{% if tool_calls %}"
+        "\n\nYou have access to the following functions:\n"
+        "{% for tool in tools %}"
+        "\nfunctions.{{ tool.function.name }}:\n"
+        "{{ tool.function.parameters | tojson }}"
+        "\n{% endfor %}"
+        "\n\nYou can respond to users messages with either a single message or one or more function calls."
+        "\n\nTo respond with a message begin the message with 'message:', use the following format:"
+        "\n\nmessage:"
+        "\n<message>"
+        "\n\nTo respond with one or more function calls begin the message with 'functions.<function_name>:', use the following format:"
+        "\n\nfunctions.<function_name>:"
+        '\n{ "arg1": "value1", "arg2": "value2" };'
+        "\nfunctions.<function_name>:"
+        '\n{ "arg1": "value1", "arg2": "value2" }'
+        "\n\nWhen you are done with the function calls, end the message with </done>."
+        "{% endif %}"
+        "<|im_end|>\n"
+        "{% endif %}"
+        # User message
+        "{% if message.role == 'user' %}"
+        "{{ message.content }}"
+        "<|im_end|>\n"
+        "{% endif %}"
+        # Assistant message
+        "{% if message.role == 'assistant' %}"
+        ## Regular message
+        "{% if message.content and message.content | length > 0 %}"
+        "{% if tool_calls %}"
+        "message:\n"
+        "{% endif %}"
+        "{{ message.content }}"
+        "<|im_end|>\n"
+        "{% endif %}"
+        ## Function calls
+        "{% if tool_calls %}"
+        "{% for tool_call in message.tool_calls %}"
+        "functions.{{ tool_call.function.name }}:\n"
+        "{{ (tool_call.arguments | default('{}') | tojson) }}"
+        "{% if not loop.last %};{% endif %}"  # Semicolon separator if not the last function call
+        "{% endfor %}"
+        "<|im_end|>\n"
+        "{% endif %}"
+        "{% endif %}"
+        # Tool message (treated as Assistant response)
+        "{% if message.role == 'tool' %}"
+        "ASSISTANT:\n"
+        "Function response: {{ message.content | default('No response available') }}"
+        "<|im_end|>\n"
+        "{% endif %}"
+        "{% endfor %}"
+        "{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
+    )
+    return base_function_calling(end_token="<|im_end|>",
+                                 **locals())
+
+@register_chat_completion_handler("vicuna-function-calling")
+def vicuna_function_calling(
+    llama: llama.Llama,
+    messages: List[llama_types.ChatCompletionRequestMessage],
+    functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
+    function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None,
+    tools: Optional[List[llama_types.ChatCompletionTool]] = None,
+    tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None,
+    temperature: float = 0.2,
+    top_p: float = 0.95,
+    top_k: int = 40,
+    min_p: float = 0.05,
+    typical_p: float = 1.0,
+    stream: bool = False,
+    stop: Optional[Union[str, List[str]]] = [],
+    response_format: Optional[llama_types.ChatCompletionRequestResponseFormat] = None,
+    max_tokens: Optional[int] = None,
+    presence_penalty: float = 0.0,
+    frequency_penalty: float = 0.0,
+    repeat_penalty: float = 1.1,
+    tfs_z: float = 1.0,
+    mirostat_mode: int = 0,
+    mirostat_tau: float = 5.0,
+    mirostat_eta: float = 0.1,
+    model: Optional[str] = None,
+    logits_processor: Optional[llama.LogitsProcessorList] = None,
+    grammar: Optional[llama.LlamaGrammar] = None,
+    logprobs: Optional[bool] = None,
+    top_logprobs: Optional[int] = None,
+    **kwargs,  # type: ignore
+) -> Union[
+    llama_types.CreateChatCompletionResponse,
+    Iterator[llama_types.CreateChatCompletionStreamResponse],
+]:
+    function_calling_template = (
+        "{% for message in messages %}"
+        "{{ message.role.upper() }}\n"  # Vicuna uses upper case for roles
+        # System message
+        "{% if message.role == 'system' %}"
+        "{{ message.content }}"
+        "{% if tool_calls %}"
+        "\n\nYou have access to the following functions:\n"
+        "{% for tool in tools %}"
+        "\nfunctions.{{ tool.function.name }}:\n"
+        "{{ tool.function.parameters | tojson }}"
+        "\n{% endfor %}"
+        "\n\nYou can respond to users messages with either a single message or multiple function calls."
+        "\n\nTo respond with a message begin the message with 'message:', use the following format:"
+        "\n\nmessage:"
+        "\n<message>"
+        "\n\nTo respond with one or more function calls begin the message with 'functions.<function_name>:', use the following format:"
+        "\n\nfunctions.<function_name>:"
+        '\n{ "arg1": "value1", "arg2": "value2" };'
+        "\nfunctions.<another_function_name>:"
+        '\n{ "arg1": "value3", "arg2": "value4" }'
+        "\n\nWhen you are done with the function calls, end the message with </done>."
+        "{% endif %}"
+        "</s>\n"
+        "{% endif %}"
+        # User message
+        "{% if message.role == 'user' %}"
+        "{{ message.content }}"
+        "</s>\n"
+        "{% endif %}"
+        # Assistant message
+        "{% if message.role == 'assistant' %}"
+        ## Regular message
+        "{% if message.content and message.content | length > 0 %}"
+        "{% if tool_calls %}"
+        "message:\n"
+        "{% endif %}"
+        "{{ message.content }}"
+        "</s>\n"
+        "{% endif %}"
+        ## Function calls
+        "{% if tool_calls %}"
+        "{% for tool_call in message.tool_calls %}"
+        "functions.{{ tool_call.function.name }}:\n"
+        "{{ (tool_call.arguments | default('{}') | tojson) }}"
+        "{% if not loop.last %};{% endif %}"  # Semicolon separator if not the last function call
+        "{% endfor %}"
+        "</s>\n"
+        "{% endif %}"
+        "{% endif %}"
+        # Tool message (treated as Assistant response)
+        "{% if message.role == 'tool' %}"
+        "ASSISTANT:\n"
+        "Function response: {{ message.content | default('No response available') }}"
+        "</s>\n"
+        "{% endif %}"
+        "{% endfor %}"
+        "{% if add_generation_prompt %}</s>ASSISTANT\n{% endif %}"  # Vicuna adds the role for prompt continuation
+    )
+    return base_function_calling(end_token="</s>",
+                                 **locals())
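
Usage note (not part of the diff): a minimal sketch of how the handlers registered above could be exercised through llama-cpp-python's high-level API. The model path and the weather tool are illustrative assumptions; the Llama constructor's chat_format parameter and create_chat_completion with tools/tool_choice are existing llama-cpp-python features.

    # Sketch only: assumes llama-cpp-python is installed and the handlers in this
    # diff are registered; the model file and tool schema below are hypothetical.
    from llama_cpp import Llama

    llm = Llama(
        model_path="./vicuna-7b-v1.5.Q4_K_M.gguf",   # hypothetical local GGUF file
        chat_format="vicuna-function-calling",       # handler added in this diff
    )

    response = llm.create_chat_completion(
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "What is the weather in Berlin?"},
        ],
        tools=[
            {
                "type": "function",
                "function": {
                    "name": "get_weather",           # hypothetical tool
                    "parameters": {
                        "type": "object",
                        "properties": {"city": {"type": "string"}},
                        "required": ["city"],
                    },
                },
            }
        ],
        tool_choice="auto",
    )
    print(response["choices"][0]["message"])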