5
5
6
6
import partial_json_parser
7
7
from partial_json_parser .core .options import Allow
8
+ from transformers import PreTrainedTokenizerBase
8
9
9
10
from vllm .entrypoints .openai .protocol import (DeltaFunctionCall , DeltaMessage ,
10
11
DeltaToolCall ,
14
15
ToolParser )
15
16
from vllm .entrypoints .openai .tool_parsers .utils import find_common_prefix
16
17
from vllm .logger import init_logger
17
- from vllm .transformers_utils .tokenizer import AnyTokenizer
18
18
from vllm .utils import random_uuid
19
19
20
20
logger = init_logger (__name__ )
@@ -49,7 +49,7 @@ class Llama3JsonToolParser(ToolParser):
49
49
Used when --enable-auto-tool-choice --tool-call-parser mistral are all set
50
50
"""
51
51
52
- def __init__ (self , tokenizer : AnyTokenizer ):
52
+ def __init__ (self , tokenizer : PreTrainedTokenizerBase ):
53
53
super ().__init__ (tokenizer )
54
54
55
55
# initialize properties used for state when parsing tool calls in
@@ -60,8 +60,8 @@ def __init__(self, tokenizer: AnyTokenizer):
60
60
self .streamed_args_for_tool : List [str ] = [
61
61
] # map what has been streamed for each tool so far to a list
62
62
self .bot_token = "<|python_tag|>"
63
- self .bot_token_id = self . model_tokenizer . encode (
64
- self . bot_token , add_special_tokens = False )[0 ]
63
+ self .bot_token_id = tokenizer . encode (self . bot_token ,
64
+ add_special_tokens = False )[0 ]
65
65
self .tool_call_regex = re .compile (r"\[{.*?}\]" , re .DOTALL )
66
66
67
67
def extract_tool_calls (self ,
0 commit comments