
Commit 84d0874

fix completion tokens tracking, prompt forming
1 parent 266abfc commit 84d0874

File tree

1 file changed: +34 −32 lines

llama_cpp/llama_chat_format.py

Lines changed: 34 additions & 32 deletions
@@ -1828,27 +1828,35 @@ def prepare_messages_for_inference(
         version: Literal["v1", "v2"],
         functions: Optional[List[llama_types.ChatCompletionFunctions]] = None,
         tools: Optional[List[llama_types.ChatCompletionTool]] = None,
+        tool_choice: Union[Dict, str] = "auto",
     ):
         all_messages: List[llama_types.ChatCompletionRequestMessage] = []
-        if functions is not None:
+        if tool_choice == "none":
             all_messages.append(
                 llama_types.ChatCompletionRequestSystemMessage(
-                    role="system", content=generate_schema_from_functions(functions)
+                    role="system", content=generate_schema_from_functions([])
                 )
             )
-        elif tools is not None:
-            all_messages.append(
-                llama_types.ChatCompletionRequestSystemMessage(
-                    role="system",
-                    content=generate_schema_from_functions(
-                        [
-                            tool["function"]
-                            for tool in tools
-                            if tool["type"] == "function"
-                        ]
-                    ),
+        else:
+            if functions is not None:
+                all_messages.append(
+                    llama_types.ChatCompletionRequestSystemMessage(
+                        role="system", content=generate_schema_from_functions(functions)
+                    )
+                )
+            elif tools is not None and tool_choice != "none":
+                all_messages.append(
+                    llama_types.ChatCompletionRequestSystemMessage(
+                        role="system",
+                        content=generate_schema_from_functions(
+                            [
+                                tool["function"]
+                                for tool in tools
+                                if tool["type"] == "function"
+                            ]
+                        ),
+                    )
                 )
-            )
 
         all_messages.append(
             llama_types.ChatCompletionRequestSystemMessage(
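
The new tool_choice parameter lets prompt forming itself decide which function schema the system message advertises. A minimal sketch of the selection logic, assuming generate_schema_from_functions is the module-level helper this file already defines; the wrapper name schema_for_prompt is hypothetical:

from typing import Dict, List, Optional, Union

from llama_cpp.llama_chat_format import generate_schema_from_functions


def schema_for_prompt(
    functions: Optional[List[dict]] = None,
    tools: Optional[List[dict]] = None,
    tool_choice: Union[Dict, str] = "auto",
) -> str:
    # "none": emit an empty schema, so the prompt keeps its shape but the
    # model is shown no callable functions.
    if tool_choice == "none":
        return generate_schema_from_functions([])
    if functions is not None:
        return generate_schema_from_functions(functions)
    if tools is not None:
        return generate_schema_from_functions(
            [tool["function"] for tool in tools if tool["type"] == "function"]
        )
    # The real code appends no schema message at all in this case.
    return ""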
@@ -1888,7 +1896,7 @@ def prepare_messages_for_inference(
         function_call = "auto"
 
     prompt = prepare_messages_for_inference(
-        messages, tokenizer, version, functions, tools
+        messages, tokenizer, version, functions, tools, function_call
     )
 
     # If no tools/functions are provided
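
At the call site, the handler's normalized function_call value ("auto", "none", or a {"name": ...} dict) is now threaded through to prompt forming. A hedged usage sketch of the public API this serves; the model path and tool definition are placeholders, and functionary models may need additional tokenizer setup:

from llama_cpp import Llama

llm = Llama(
    model_path="./functionary-7b-v2.q4_0.gguf",  # placeholder path
    chat_format="functionary-v2",
)
response = llm.create_chat_completion(
    messages=[{"role": "user", "content": "Just say hello."}],
    tools=[
        {
            "type": "function",
            "function": {
                "name": "get_weather",  # illustrative tool
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                },
            },
        }
    ],
    tool_choice="none",  # now reaches prepare_messages_for_inference
)
print(response["usage"]["completion_tokens"])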
@@ -1985,17 +1993,12 @@ def create_completion(stop):
 
     content = ""
     function_calls, function_bodies = [], []
+    completion_tokens = 0
 
     if version == "v1":
         # If no or "auto" tool_choice/function_call
         if isinstance(function_call, str) and function_call == "auto":
             stops = ["\n", END_ASSISTANT_TOKEN]
-        # If tool_choice/function_call is "none"
-        elif isinstance(function_call, str) and function_call == "none":
-            prompt = prepare_messages_for_inference(
-                messages, tokenizer, version, [], []
-            )
-            stops = END_ASSISTANT_TOKEN
         # If tool_choice/function_call is provided
         elif isinstance(function_call, dict):
             prompt += f"{START_FUNCTION_CALL_TOKEN}{function_call['name']}:\n"
@@ -2009,12 +2012,15 @@ def create_completion(stop):
 
         completion = create_completion(stop=stops)
         completion_text = completion["choices"][0]["text"]
+        completion_tokens += completion["usage"]["completion_tokens"]
+
 
         # If the generation does not involve a function call
         if (
             START_FUNCTION_CALL_TOKEN not in prompt
             and START_FUNCTION_CALL_TOKEN not in completion_text
         ):
+            completion["usage"]["completion_tokens"] = completion_tokens
             return _convert_completion_to_chat(completion, stream=stream)  # type: ignore
         # If the generation involves a function call in completion, generate the parameters
         elif (
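
The completion_tokens counter addresses under-reporting: a single chat turn may call create_completion several times (function name, arguments, follow-up content), but previously only the last call's usage survived into the response. A self-contained toy of the accumulation pattern; the stub stands in for the real closure:

def create_completion(stop):
    # Stub returning an OpenAI-style completion dict, as the real closure does.
    return {"choices": [{"text": "..."}], "usage": {"completion_tokens": 7}}


completion_tokens = 0
for stop in (["\n"], ["<|stop|>"]):  # two chained generations
    completion = create_completion(stop=stop)
    completion_tokens += completion["usage"]["completion_tokens"]

# The last completion's usage is overwritten with the running total before the
# result is converted to a chat response.
completion["usage"]["completion_tokens"] = completion_tokens
print(completion["usage"]["completion_tokens"])  # 14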
@@ -2032,30 +2038,22 @@ def create_completion(stop):
             )
             grammar = get_grammar(function_calls[-1])
             completion = create_completion(stop=END_FUNCTION_CALL_TOKEN)
+            completion_tokens += completion["usage"]["completion_tokens"]
             function_bodies.append(completion["choices"][0]["text"].strip())
         # If the prompt involves a function call, just append generated parameters to function_bodies
         else:
             function_bodies.append(completion_text.strip())
     else:
-        # If tool_choice/function_call is "none"
-        if isinstance(function_call, str) and function_call == "none":
-            prompt = (
-                prepare_messages_for_inference(messages, tokenizer, version, [], [])
-                + "all\n<|content|>"
-            )
-            stops = [STOP_TOKEN, FROM_TOKEN]
-            completion = create_completion(stop=stops)
-            completion["choices"][0]["text"] = completion["choices"][0]["text"].strip()
-            return _convert_completion_to_chat(completion, stream=stream)  # type: ignore
         # If tool_choice/function_call is provided
-        elif isinstance(function_call, dict):
+        if isinstance(function_call, dict):
             prompt += f"{function_call['name']}\n{CONTENT_TOKEN}"
             function_call = function_call["name"]
             function_calls.append(function_call)
             grammar = get_grammar(function_call)
             stops = [STOP_TOKEN, FROM_TOKEN]
             completion = create_completion(stop=stops)
             completion_text = completion["choices"][0]["text"]
+            completion_tokens += completion["usage"]["completion_tokens"]
             function_bodies.append(completion_text.strip())
         # If "auto" or no tool_choice/function_call
         elif isinstance(function_call, str) and function_call == "auto":
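
Both version branches drop their ad-hoc "none" handling, which had rebuilt the prompt with empty function lists in mid-generation; with tool_choice applied during prompt forming, only the dict ("forced function") and "auto" cases remain to branch on. A condensed, runnable sketch of the v2 dispatch; the token constants are illustrative stand-ins, not the module's real values:

STOP_TOKEN = "<|stop|>"        # illustrative stand-in
FROM_TOKEN = "<|from|>"        # illustrative stand-in
CONTENT_TOKEN = "<|content|>"  # illustrative stand-in


def plan_v2_generation(function_call, prompt):
    # Forced call: the recipient is fixed up front, only arguments are generated.
    if isinstance(function_call, dict):
        prompt += f"{function_call['name']}\n{CONTENT_TOKEN}"
        stops = [STOP_TOKEN, FROM_TOKEN]
    # "auto" ("none" was already folded into the prompt): let the model pick
    # the recipient first, stopping at the content marker.
    else:
        stops = CONTENT_TOKEN
    return prompt, stops


print(plan_v2_generation({"name": "get_weather"}, ""))
print(plan_v2_generation("auto", ""))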
@@ -2065,6 +2063,7 @@ def create_completion(stop):
                 stops = CONTENT_TOKEN
                 completion = create_completion(stop=stops)
                 completion_text = completion["choices"][0]["text"]
+                completion_tokens += completion["usage"]["completion_tokens"]
                 function_name = completion_text.strip()
                 if function_name == "all":
                     prompt += "all\n<|content|>"
@@ -2077,6 +2076,7 @@ def create_completion(stop):
                 stops = [RECIPIENT_TOKEN, STOP_TOKEN]
                 completion = create_completion(stop=stops)
                 completion_text = completion["choices"][0]["text"]
+                completion_tokens += completion["usage"]["completion_tokens"]
                 if function_name == "all":
                     content += completion_text.removesuffix("\n<|from|>assistant\n").removesuffix("\n<|from|> assistant\n")
                     content = content.lstrip()
@@ -2092,6 +2092,7 @@ def create_completion(stop):
                     prompt += completion_text.strip()
                     grammar = None
                     completion = create_completion(stop=stops)
+                    completion_tokens += completion["usage"]["completion_tokens"]
                     if "<|from|> assistant" in completion["choices"][0]["text"] or "<|from|>assistant" in completion["choices"][0]["text"]:
                         prompt += "\n<|from|>assistant\n<|recipient|>"
                     else:
@@ -2126,6 +2127,7 @@ def create_completion(stop):
             "arguments": tool_calls[0]["function"]["arguments"],
         }
     } if len(tool_calls) == 1 else {}
+    completion["usage"]["completion_tokens"] = completion_tokens
     return llama_types.CreateChatCompletionResponse(
         id="chat" + completion["id"],
         object="chat.completion",
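
One observation on this final patch: only completion_tokens is overwritten with the running sum, while prompt_tokens and total_tokens keep the values reported by the last create_completion call, as this small illustration with invented numbers shows:

completion = {"usage": {"prompt_tokens": 210, "completion_tokens": 9, "total_tokens": 219}}
completion_tokens = 14  # running sum across all chained generations

completion["usage"]["completion_tokens"] = completion_tokens
print(completion["usage"])
# {'prompt_tokens': 210, 'completion_tokens': 14, 'total_tokens': 219}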
