@@ -1828,27 +1828,35 @@ def prepare_messages_for_inference(
         version: Literal["v1", "v2"],
         functions: Optional[List[llama_types.ChatCompletionFunctions]] = None,
         tools: Optional[List[llama_types.ChatCompletionTool]] = None,
+        tool_choice: Union[Dict, str] = "auto",
     ):
         all_messages: List[llama_types.ChatCompletionRequestMessage] = []
-        if functions is not None:
+        if tool_choice == "none":
             all_messages.append(
                 llama_types.ChatCompletionRequestSystemMessage(
-                    role="system", content=generate_schema_from_functions(functions)
+                    role="system", content=generate_schema_from_functions([])
                 )
             )
-        elif tools is not None:
-            all_messages.append(
-                llama_types.ChatCompletionRequestSystemMessage(
-                    role="system",
-                    content=generate_schema_from_functions(
-                        [
-                            tool["function"]
-                            for tool in tools
-                            if tool["type"] == "function"
-                        ]
-                    ),
+        else:
+            if functions is not None:
+                all_messages.append(
+                    llama_types.ChatCompletionRequestSystemMessage(
+                        role="system", content=generate_schema_from_functions(functions)
+                    )
+                )
+            elif tools is not None and tool_choice != "none":
+                all_messages.append(
+                    llama_types.ChatCompletionRequestSystemMessage(
+                        role="system",
+                        content=generate_schema_from_functions(
+                            [
+                                tool["function"]
+                                for tool in tools
+                                if tool["type"] == "function"
+                            ]
+                        ),
+                    )
                 )
-            )

         all_messages.append(
             llama_types.ChatCompletionRequestSystemMessage(
@@ -1888,7 +1896,7 @@ def prepare_messages_for_inference
         function_call = "auto"

     prompt = prepare_messages_for_inference(
-        messages, tokenizer, version, functions, tools
+        messages, tokenizer, version, functions, tools, function_call
     )

     # If no tools/functions are provided
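Taken together, the two hunks above route the caller's `tool_choice` into `prepare_messages_for_inference`, so a `"none"` choice is handled once, up front, by emitting an empty function schema instead of re-preparing the prompt mid-generation. A minimal sketch of how the patched handler might be exercised, assuming a functionary GGUF build of llama-cpp-python with this change applied; the repo id, filename, and `get_weather` tool are illustrative and not part of the diff:

from llama_cpp import Llama
from llama_cpp.llama_tokenizer import LlamaHFTokenizer

llm = Llama.from_pretrained(
    repo_id="meetkai/functionary-small-v2.2-GGUF",  # illustrative model choice
    filename="functionary-small-v2.2.q4_0.gguf",
    chat_format="functionary-v2",
    tokenizer=LlamaHFTokenizer.from_pretrained("meetkai/functionary-small-v2.2-GGUF"),
)
response = llm.create_chat_completion(
    messages=[{"role": "user", "content": "What is the weather in Paris?"}],
    tools=[
        {
            "type": "function",
            "function": {
                "name": "get_weather",  # hypothetical tool for tool_choice to suppress
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                },
            },
        }
    ],
    tool_choice="none",  # with this patch, yields plain text and no tool_calls
)
print(response["choices"][0]["message"]["content"])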
@@ -1985,17 +1993,12 @@ def create_completion(stop):

     content = ""
     function_calls, function_bodies = [], []
+    completion_tokens = 0

     if version == "v1":
         # If no or "auto" tool_choice/function_call
         if isinstance(function_call, str) and function_call == "auto":
             stops = ["\n", END_ASSISTANT_TOKEN]
-        # If tool_choice/function_call is "none"
-        elif isinstance(function_call, str) and function_call == "none":
-            prompt = prepare_messages_for_inference(
-                messages, tokenizer, version, [], []
-            )
-            stops = END_ASSISTANT_TOKEN
         # If tool_choice/function_call is provided
         elif isinstance(function_call, dict):
             prompt += f"{START_FUNCTION_CALL_TOKEN}{function_call['name']}:\n"
@@ -2009,12 +2012,15 @@ def create_completion(stop):

         completion = create_completion(stop=stops)
         completion_text = completion["choices"][0]["text"]
+        completion_tokens += completion["usage"]["completion_tokens"]
+

         # If the generation does not involve a function call
         if (
             START_FUNCTION_CALL_TOKEN not in prompt
             and START_FUNCTION_CALL_TOKEN not in completion_text
         ):
+            completion["usage"]["completion_tokens"] = completion_tokens
             return _convert_completion_to_chat(completion, stream=stream)  # type: ignore
         # If the generation involves a function call in completion, generate the parameters
         elif (
@@ -2032,30 +2038,22 @@ def create_completion(stop):
             )
             grammar = get_grammar(function_calls[-1])
             completion = create_completion(stop=END_FUNCTION_CALL_TOKEN)
+            completion_tokens += completion["usage"]["completion_tokens"]
             function_bodies.append(completion["choices"][0]["text"].strip())
         # If the prompt involves a function call, just append generated parameters to function_bodies
         else:
             function_bodies.append(completion_text.strip())
     else:
-        # If tool_choice/function_call is "none"
-        if isinstance(function_call, str) and function_call == "none":
-            prompt = (
-                prepare_messages_for_inference(messages, tokenizer, version, [], [])
-                + "all\n<|content|>"
-            )
-            stops = [STOP_TOKEN, FROM_TOKEN]
-            completion = create_completion(stop=stops)
-            completion["choices"][0]["text"] = completion["choices"][0]["text"].strip()
-            return _convert_completion_to_chat(completion, stream=stream)  # type: ignore
         # If tool_choice/function_call is provided
-        elif isinstance(function_call, dict):
+        if isinstance(function_call, dict):
             prompt += f"{function_call['name']}\n{CONTENT_TOKEN}"
             function_call = function_call["name"]
             function_calls.append(function_call)
             grammar = get_grammar(function_call)
             stops = [STOP_TOKEN, FROM_TOKEN]
             completion = create_completion(stop=stops)
             completion_text = completion["choices"][0]["text"]
+            completion_tokens += completion["usage"]["completion_tokens"]
             function_bodies.append(completion_text.strip())
         # If "auto" or no tool_choice/function_call
         elif isinstance(function_call, str) and function_call == "auto":
@@ -2065,6 +2063,7 @@ def create_completion(stop):
                 stops = CONTENT_TOKEN
                 completion = create_completion(stop=stops)
                 completion_text = completion["choices"][0]["text"]
+                completion_tokens += completion["usage"]["completion_tokens"]
                 function_name = completion_text.strip()
                 if function_name == "all":
                     prompt += "all\n<|content|>"
@@ -2077,6 +2076,7 @@ def create_completion(stop):
                 stops = [RECIPIENT_TOKEN, STOP_TOKEN]
                 completion = create_completion(stop=stops)
                 completion_text = completion["choices"][0]["text"]
+                completion_tokens += completion["usage"]["completion_tokens"]
                 if function_name == "all":
                     content += completion_text.removesuffix("\n<|from|>assistant\n").removesuffix("\n<|from|> assistant\n")
                     content = content.lstrip()
@@ -2092,6 +2092,7 @@ def create_completion(stop):
                 prompt += completion_text.strip()
                 grammar = None
                 completion = create_completion(stop=stops)
+                completion_tokens += completion["usage"]["completion_tokens"]
                 if "<|from|> assistant" in completion["choices"][0]["text"] or "<|from|>assistant" in completion["choices"][0]["text"]:
                     prompt += "\n<|from|>assistant\n<|recipient|>"
                 else:
@@ -2126,6 +2127,7 @@ def create_completion(stop):
                 "arguments": tool_calls[0]["function"]["arguments"],
             }
         } if len(tool_calls) == 1 else {}
+        completion["usage"]["completion_tokens"] = completion_tokens
         return llama_types.CreateChatCompletionResponse(
             id="chat" + completion["id"],
             object="chat.completion",
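The usage-tracking hunks above all apply one accounting pattern: the handler may call `create_completion` several times in a single chat turn (recipient name, arguments, content, and the continue-or-stop probe), and previously only the final call's `completion_tokens` survived into the reported usage. A minimal sketch of that pattern in isolation, assuming OpenAI-style `usage` dicts; `run_turns` and `fake_create_completion` are illustrative names, not code from this patch:

def run_turns(create_completion, stops_per_turn):
    # Accumulate completion_tokens across every intermediate generation,
    # then write the total back onto the final completion before returning it.
    completion_tokens = 0
    completion = None
    for stops in stops_per_turn:
        completion = create_completion(stop=stops)
        completion_tokens += completion["usage"]["completion_tokens"]
    completion["usage"]["completion_tokens"] = completion_tokens
    return completion

def fake_create_completion(stop):
    return {"usage": {"completion_tokens": 7}, "choices": [{"text": ""}]}

# Two generations of 7 tokens each now report 14, not 7.
assert run_turns(fake_create_completion, ["\n", "<|stop|>"])["usage"]["completion_tokens"] == 14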