Skip to content

Commit 83caf35

Browse files
authored
[BugFix] Enforce Mistral ToolCall id constraint when using the Mistral tool call parser (#9020)
1 parent 01843c8 commit 83caf35

File tree

3 files changed

+22
-6
lines changed

3 files changed

+22
-6
lines changed

tests/tool_use/test_parallel_tool_calls.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ async def test_parallel_tool_calls(client: openai.AsyncOpenAI,
4545
assert tool_call.type == "function"
4646
assert tool_call.function is not None
4747
assert isinstance(tool_call.id, str)
48-
assert len(tool_call.id) > 16
48+
assert len(tool_call.id) >= 9
4949

5050
# make sure the weather tool was called correctly
5151
assert tool_call.function.name == WEATHER_TOOL["function"]["name"]
@@ -108,7 +108,7 @@ async def test_parallel_tool_calls(client: openai.AsyncOpenAI,
108108
if tool_call.id:
109109
tool_call_id_count += 1
110110
assert (isinstance(tool_call.id, str)
111-
and (len(tool_call.id) > 16))
111+
and (len(tool_call.id) >= 9))
112112

113113
# if parts of the function start being streamed
114114
if tool_call.function:

tests/tool_use/test_tool_calls.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ async def test_tool_call_and_choice(client: openai.AsyncOpenAI):
3333
assert tool_calls[0].type == 'function'
3434
assert tool_calls[0].function is not None
3535
assert isinstance(tool_calls[0].id, str)
36-
assert len(tool_calls[0].id) > 16
36+
assert len(tool_calls[0].id) >= 9
3737

3838
# make sure the weather tool was called (classic example) with arguments
3939
assert tool_calls[0].function.name == WEATHER_TOOL["function"]["name"]
@@ -106,7 +106,7 @@ async def test_tool_call_and_choice(client: openai.AsyncOpenAI):
106106

107107
assert finish_reason_count == 1
108108
assert role_name == 'assistant'
109-
assert isinstance(tool_call_id, str) and (len(tool_call_id) > 16)
109+
assert isinstance(tool_call_id, str) and (len(tool_call_id) >= 9)
110110

111111
# validate the name and arguments
112112
assert function_name == WEATHER_TOOL["function"]["name"]

vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
import json
22
import re
3+
from random import choices
4+
from string import ascii_letters, digits
35
from typing import Dict, List, Sequence, Union
46

57
import partial_json_parser
68
from partial_json_parser.core.options import Allow
9+
from pydantic import Field
710

811
from vllm.entrypoints.openai.protocol import (DeltaFunctionCall, DeltaMessage,
912
DeltaToolCall,
@@ -19,6 +22,19 @@
1922

2023
logger = init_logger(__name__)
2124

25+
ALPHANUMERIC = ascii_letters + digits
26+
27+
28+
class MistralToolCall(ToolCall):
29+
id: str = Field(
30+
default_factory=lambda: MistralToolCall.generate_random_id())
31+
32+
@staticmethod
33+
def generate_random_id():
34+
# Mistral Tool Call Ids must be alphanumeric with a maximum length of 9.
35+
# https://github.com/mistralai/mistral-common/blob/21ee9f6cee3441e9bb1e6ed2d10173f90bd9b94b/src/mistral_common/protocol/instruct/validator.py#L299
36+
return "".join(choices(ALPHANUMERIC, k=9))
37+
2238

2339
class MistralToolParser(ToolParser):
2440
"""
@@ -71,8 +87,8 @@ def extract_tool_calls(self,
7187
# load the JSON, and then use it to build the Function and
7288
# Tool Call
7389
function_call_arr = json.loads(raw_tool_call)
74-
tool_calls: List[ToolCall] = [
75-
ToolCall(
90+
tool_calls: List[MistralToolCall] = [
91+
MistralToolCall(
7692
type="function",
7793
function=FunctionCall(
7894
name=raw_function_call["name"],

0 commit comments

Comments
 (0)