Skip to content

Commit 05a4324

Browse files
maxdebayseriGmainC
andauthored
Initialize the delta tool call fields explicitly (#17340)
Signed-off-by: Max de Bayser <mbayser@br.ibm.com> Co-authored-by: igmainc <igmainc@icloud.com>
1 parent 7ea6cb2 commit 05a4324

12 files changed

+51
-34
lines changed

tests/entrypoints/openai/tool_parsers/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def append_delta(self, delta: DeltaMessage):
3232
assert len(delta.tool_calls) < 2, (
3333
"Streaming should include only one tool call per update.")
3434
for call_delta in delta.tool_calls:
35-
assert call_delta.type == "function", (
35+
assert call_delta.type is None or call_delta.type == "function", (
3636
"Streaming tool calls should only emit function calls. Got "
3737
f"{call_delta.type}")
3838
current_tool_call = self.tool_calls[

vllm/entrypoints/chat_utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
# yapf: enable
4545
from vllm.transformers_utils.processor import cached_get_processor
4646
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
47+
from vllm.utils import random_uuid
4748

4849
logger = init_logger(__name__)
4950

@@ -1272,3 +1273,6 @@ def apply_mistral_chat_template(
12721273
"An error occurred in `mistral_common` while applying chat "
12731274
"template")
12741275
raise ValueError from e
1276+
1277+
def random_tool_call_id() -> str:
1278+
return f"chatcmpl-tool-{random_uuid()}"

vllm/entrypoints/openai/protocol.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515
from typing_extensions import TypeAlias
1616

1717
from vllm import envs
18-
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
18+
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
19+
random_tool_call_id)
1920
from vllm.logger import init_logger
2021
from vllm.pooling_params import PoolingParams
2122
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
@@ -1339,7 +1340,7 @@ class FunctionCall(OpenAIBaseModel):
13391340

13401341

13411342
class ToolCall(OpenAIBaseModel):
1342-
id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}")
1343+
id: str = Field(default_factory=random_tool_call_id)
13431344
type: Literal["function"] = "function"
13441345
function: FunctionCall
13451346

@@ -1351,8 +1352,8 @@ class DeltaFunctionCall(BaseModel):
13511352

13521353
# a tool call delta where everything is optional
13531354
class DeltaToolCall(OpenAIBaseModel):
1354-
id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}")
1355-
type: Literal["function"] = "function"
1355+
id: Optional[str] = None
1356+
type: Optional[Literal["function"]] = None
13561357
index: int
13571358
function: Optional[DeltaFunctionCall] = None
13581359

vllm/entrypoints/openai/serving_chat.py

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
from vllm.config import ModelConfig
1717
from vllm.engine.protocol import EngineClient
1818
from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption,
19-
ConversationMessage)
19+
ConversationMessage,
20+
random_tool_call_id)
2021
from vllm.entrypoints.logger import RequestLogger
2122
from vllm.entrypoints.openai.protocol import (
2223
ChatCompletionLogProb, ChatCompletionLogProbs,
@@ -363,9 +364,10 @@ def extract_tool_call_required_streaming(
363364

364365
function_name_returned = True
365366
delta_message = DeltaMessage(tool_calls=[
366-
DeltaToolCall(function=DeltaFunctionCall(
367-
name=current_tool_call["name"],
368-
arguments=arguments),
367+
DeltaToolCall(id=random_tool_call_id(),
368+
function=DeltaFunctionCall(
369+
name=current_tool_call["name"],
370+
arguments=arguments),
369371
index=len(obj) - 1,
370372
type="function")
371373
])
@@ -382,8 +384,7 @@ def extract_tool_call_required_streaming(
382384
# instead of name every time
383385
name=None,
384386
arguments=delta_text),
385-
index=len(obj) - 1,
386-
type="function")
387+
index=len(obj) - 1)
387388
])
388389
else:
389390
delta_message = None
@@ -422,7 +423,7 @@ async def chat_completion_stream_generator(
422423
and self._should_stream_with_auto_tool_parsing(request))
423424

424425
all_previous_token_ids: Optional[list[list[int]]]
425-
function_name_returned: Optional[list[bool]] = None
426+
function_name_returned = [False] * num_choices
426427

427428
# Only one of these will be used, thus previous_texts and
428429
# all_previous_token_ids will not be used twice in the same iteration.
@@ -435,7 +436,6 @@ async def chat_completion_stream_generator(
435436
reasoning_end_arr = [False] * num_choices
436437
elif request.tool_choice == "required":
437438
previous_texts = [""] * num_choices
438-
function_name_returned = [False] * num_choices
439439
all_previous_token_ids = None
440440
else:
441441
previous_texts, all_previous_token_ids = None, None
@@ -623,16 +623,27 @@ async def chat_completion_stream_generator(
623623
delta_text = previous_text + delta_text
624624
current_text = ""
625625

626+
if function_name_returned[i]:
627+
delta_tool_call = DeltaToolCall(
628+
function=DeltaFunctionCall(
629+
arguments=delta_text),
630+
index=i)
631+
else:
632+
delta_tool_call = DeltaToolCall(
633+
id=random_tool_call_id(),
634+
type="function",
635+
function=DeltaFunctionCall(
636+
name=tool_choice_function_name,
637+
arguments=delta_text),
638+
index=i)
639+
function_name_returned[i] = True
640+
626641
delta_message = DeltaMessage(tool_calls=[
627-
DeltaToolCall(function=DeltaFunctionCall(
628-
name=tool_choice_function_name,
629-
arguments=delta_text),
630-
index=i)
642+
delta_tool_call,
631643
])
632644

633645
elif request.tool_choice == "required":
634646
assert previous_texts is not None
635-
assert function_name_returned is not None
636647
previous_text = previous_texts[i]
637648
current_text = previous_text + delta_text
638649
fn_name_returned = function_name_returned[i]
@@ -835,7 +846,7 @@ async def chat_completion_stream_generator(
835846
total_tokens=num_prompt_tokens + completion_tokens,
836847
)
837848

838-
data = chunk.model_dump_json(exclude_unset=True)
849+
data = chunk.model_dump_json(exclude_none=True)
839850
yield f"data: {data}\n\n"
840851

841852
# once the final token is handled, if stream_options.include_usage

vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import partial_json_parser
1010
from partial_json_parser.core.options import Allow
1111

12+
from vllm.entrypoints.chat_utils import random_tool_call_id
1213
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
1314
DeltaFunctionCall, DeltaMessage,
1415
DeltaToolCall,
@@ -22,7 +23,6 @@
2223
partial_json_loads)
2324
from vllm.logger import init_logger
2425
from vllm.transformers_utils.tokenizer import AnyTokenizer
25-
from vllm.utils import random_uuid
2626

2727
logger = init_logger(__name__)
2828

@@ -200,7 +200,7 @@ def extract_tool_calls_streaming(
200200
delta = DeltaMessage(tool_calls=[
201201
DeltaToolCall(index=self.current_tool_id,
202202
type="function",
203-
id=f"chatcmpl-tool-{random_uuid()}",
203+
id=random_tool_call_id(),
204204
function=DeltaFunctionCall(
205205
name=function_name).model_dump(
206206
exclude_none=True))

vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import partial_json_parser
88
from partial_json_parser.core.options import Allow
99

10+
from vllm.entrypoints.chat_utils import random_tool_call_id
1011
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
1112
DeltaFunctionCall, DeltaMessage,
1213
DeltaToolCall,
@@ -20,7 +21,6 @@
2021
partial_json_loads)
2122
from vllm.logger import init_logger
2223
from vllm.transformers_utils.tokenizer import AnyTokenizer
23-
from vllm.utils import random_uuid
2424

2525
logger = init_logger(__name__)
2626

@@ -182,7 +182,7 @@ def extract_tool_calls_streaming(
182182
delta = DeltaMessage(tool_calls=[
183183
DeltaToolCall(index=self.current_tool_id,
184184
type="function",
185-
id=f"chatcmpl-tool-{random_uuid()}",
185+
id=random_tool_call_id(),
186186
function=DeltaFunctionCall(
187187
name=function_name).model_dump(
188188
exclude_none=True))

vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import partial_json_parser
99
from partial_json_parser.core.options import Allow
1010

11+
from vllm.entrypoints.chat_utils import random_tool_call_id
1112
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
1213
DeltaFunctionCall, DeltaMessage,
1314
DeltaToolCall,
@@ -17,7 +18,6 @@
1718
ToolParser, ToolParserManager)
1819
from vllm.logger import init_logger
1920
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
20-
from vllm.utils import random_uuid
2121

2222
logger = init_logger(__name__)
2323

@@ -259,7 +259,7 @@ def extract_tool_calls_streaming(
259259
return DeltaMessage(tool_calls=[
260260
DeltaToolCall(index=self.current_tool_id,
261261
type="function",
262-
id=f"chatcmpl-tool-{random_uuid()}",
262+
id=random_tool_call_id(),
263263
function=DeltaFunctionCall(
264264
name=function_name).model_dump(
265265
exclude_none=True))

vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import partial_json_parser
88
from partial_json_parser.core.options import Allow
99

10+
from vllm.entrypoints.chat_utils import random_tool_call_id
1011
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
1112
DeltaFunctionCall, DeltaMessage,
1213
DeltaToolCall,
@@ -18,7 +19,6 @@
1819
extract_intermediate_diff)
1920
from vllm.logger import init_logger
2021
from vllm.transformers_utils.tokenizer import AnyTokenizer
21-
from vllm.utils import random_uuid
2222

2323
logger = init_logger(__name__)
2424

@@ -106,7 +106,7 @@ def extract_tool_calls_streaming(
106106
delta = DeltaMessage(tool_calls=[
107107
DeltaToolCall(index=self.current_tool_id,
108108
type="function",
109-
id=f"chatcmpl-tool-{random_uuid()}",
109+
id=random_tool_call_id(),
110110
function=DeltaFunctionCall(
111111
name=function_name).model_dump(
112112
exclude_none=True))

vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import partial_json_parser
99
from partial_json_parser.core.options import Allow
1010

11+
from vllm.entrypoints.chat_utils import random_tool_call_id
1112
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
1213
DeltaFunctionCall, DeltaMessage,
1314
DeltaToolCall,
@@ -19,7 +20,6 @@
1920
from vllm.logger import init_logger
2021
from vllm.transformers_utils.tokenizer import AnyTokenizer
2122
from vllm.transformers_utils.tokenizers import MistralTokenizer
22-
from vllm.utils import random_uuid
2323

2424
logger = init_logger(__name__)
2525

@@ -220,7 +220,7 @@ def extract_tool_calls_streaming(
220220
delta = DeltaMessage(tool_calls=[
221221
DeltaToolCall(index=self.current_tool_id,
222222
type="function",
223-
id=f"chatcmpl-tool-{random_uuid()}",
223+
id=random_tool_call_id(),
224224
function=DeltaFunctionCall(
225225
name=function_name).model_dump(
226226
exclude_none=True))

vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from partial_json_parser.core.options import Allow
1111
from transformers import PreTrainedTokenizerBase
1212

13+
from vllm.entrypoints.chat_utils import random_tool_call_id
1314
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
1415
DeltaFunctionCall, DeltaMessage,
1516
DeltaToolCall,
@@ -21,7 +22,6 @@
2122
is_complete_json,
2223
partial_json_loads)
2324
from vllm.logger import init_logger
24-
from vllm.utils import random_uuid
2525

2626
logger = init_logger(__name__)
2727

@@ -208,7 +208,7 @@ def extract_tool_calls_streaming(
208208
delta = DeltaMessage(tool_calls=[
209209
DeltaToolCall(index=self.current_tool_id,
210210
type="function",
211-
id=f"chatcmpl-tool-{random_uuid()}",
211+
id=random_tool_call_id(),
212212
function=DeltaFunctionCall(
213213
name=function_name).model_dump(
214214
exclude_none=True))

vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,14 @@
77

88
from transformers import PreTrainedTokenizerBase
99

10+
from vllm.entrypoints.chat_utils import random_tool_call_id
1011
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
1112
DeltaMessage,
1213
ExtractedToolCallInformation,
1314
FunctionCall, ToolCall)
1415
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
1516
ToolParser, ToolParserManager)
1617
from vllm.logger import init_logger
17-
from vllm.utils import random_uuid
1818

1919
logger = init_logger(__name__)
2020

@@ -73,7 +73,7 @@ def extract_tool_calls(
7373

7474
tool_calls: list[ToolCall] = [
7575
ToolCall(
76-
id=f"chatcmpl-tool-{random_uuid()}",
76+
id=random_tool_call_id(),
7777
type="function",
7878
function=FunctionCall(
7979
name=raw_function_call["name"],

vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,7 @@ def _compute_tool_delta(previously_sent_args: str, new_call: ToolCall,
280280
new_call_args = new_call_args[:-len(withheld_suffix)]
281281
if not previously_sent_args:
282282
return DeltaToolCall(id=new_call.id,
283+
type="function",
283284
index=index,
284285
function=DeltaFunctionCall(
285286
name=new_call.function.name,
@@ -288,5 +289,5 @@ def _compute_tool_delta(previously_sent_args: str, new_call: ToolCall,
288289

289290
arg_diff = new_call_args[len(previously_sent_args):]
290291
return DeltaToolCall(
291-
id="", index=index, function=DeltaFunctionCall(
292+
id=None, index=index, function=DeltaFunctionCall(
292293
arguments=arg_diff)) if arg_diff else None

0 commit comments

Comments
 (0)