From c817793e05dd7e6f0aa1d3f752e1d5f9af423287 Mon Sep 17 00:00:00 2001 From: lucca Date: Tue, 16 Apr 2024 22:55:44 -0300 Subject: [PATCH 01/11] feat: improve function calling --- llama_cpp/llama_chat_format.py | 287 ++++++++++++---- llama_cpp/llama_grammar.py | 612 +++++++++++++++++++++++++++------ 2 files changed, 728 insertions(+), 171 deletions(-) diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py index 519d2f50a..70efef866 100644 --- a/llama_cpp/llama_chat_format.py +++ b/llama_cpp/llama_chat_format.py @@ -2286,9 +2286,9 @@ def __call__( stream=stream, ) - -@register_chat_completion_handler("chatml-function-calling") -def chatml_function_calling( +def base_function_calling( + function_calling_template, + end_token, llama: llama.Llama, messages: List[llama_types.ChatCompletionRequestMessage], functions: Optional[List[llama_types.ChatCompletionFunction]] = None, @@ -2320,65 +2320,13 @@ def chatml_function_calling( ) -> Union[ llama_types.CreateChatCompletionResponse, Iterator[llama_types.CreateChatCompletionStreamResponse], -]: - print(logprobs) - function_calling_template = ( - "{% for message in messages %}" - "<|im_start|>{{ message.role }}\n" - # System message - "{% if message.role == 'system' %}" - "{{ message.content }}" - "{% if tool_calls %}" - "\n\nYou have access to the following functions:\n" - "{% for tool in tools %}" - "\nfunctions.{{ tool.function.name }}:\n" - "{{ tool.function.parameters | tojson }}" - "\n{% endfor %}" - "\n\nYou can respond to users messages with either a single message or one or more function calls." - "\n\nTo respond with a message begin the message with 'message:', use the following format:" - "\n\nmessage:" - "\n" - "\n\nTo respond with one or more function calls begin the message with 'functions.:', use the following format:" - "\n\nfunctions.:" - '\n{ "arg1": "value1", "arg2": "value2" }' - "\nfunctions.:" - '\n{ "arg1": "value1", "arg2": "value2" }' - "{% endif %}" - "<|im_end|>\n" - "{% endif %}" - # User message - "{% if message.role == 'user' %}" - "{{ message.content }}" - "<|im_end|>\n" - "{% endif %}" - # Assistant message - "{% if message.role == 'assistant' %}" - ## Reglar message - "{% if message.content and message.content | length > 0 %}" - "{% if tool_calls %}" - "message:\n" - "{% endif %}" - "{{ message.content }}" - "<|im_end|>\n" - "{% endif %}" - ## Function calls - "{% if 'tool_calls' in message %}" - "{% for tool_call in message.tool_calls %}" - "functions.{{ tool_call.function.name }}:\n" - "{{ tool_call.function.arguments }}" - "{% endfor %}" - "<|im_end|>\n" - "{% endif %}" - "{% endif %}" - "{% endfor %}" - "{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}" - ) +]: + template_renderer = jinja2.Environment( loader=jinja2.BaseLoader(), autoescape=jinja2.select_autoescape(["html", "xml"]), - undefined=jinja2.StrictUndefined, + undefined=jinja2.DebugUndefined, ).from_string(function_calling_template) - # Convert legacy functions to tools if functions is not None: tools = [ @@ -2403,8 +2351,7 @@ def chatml_function_calling( }, } - stop = [stop, "<|im_end|>"] if isinstance(stop, str) else stop + ["<|im_end|>"] if stop else ["<|im_end|>"] - + stop = [stop, end_token] if isinstance(stop, str) else stop + [end_token] if stop else [end_token] # Case 1: No tool choice by user if ( tool_choice is None @@ -2418,7 +2365,6 @@ def chatml_function_calling( tool_calls=None, add_generation_prompt=True, ) - if response_format is not None and response_format["type"] == "json_object": grammar = 
_grammar_for_response_format(response_format) @@ -2506,14 +2452,18 @@ def chatml_function_calling( function_names = " | ".join( [f'''"functions.{tool['function']['name']}:"''' for tool in tools] ) + initial_gbnf_tool_grammar = ( """root ::= functions | "message:"\n""" f"""functions ::= {function_names}\n""" ) + follow_up_gbnf_tool_grammar = ( - """root ::= functions | "<|im_end|>"\n""" + f"""root ::= functions | ""\n""" f"""functions ::= {function_names}\n""" ) + + prompt = template_renderer.render( messages=messages, tools=tools, @@ -2522,14 +2472,14 @@ def chatml_function_calling( ) completion_or_chunks = llama.create_completion( prompt=prompt, - temperature=0, + temperature=temperature, top_p=top_p, top_k=top_k, min_p=min_p, typical_p=typical_p, stream=False, - stop=[":"], - max_tokens=None, + stop=stop, + max_tokens=max_tokens, presence_penalty=presence_penalty, frequency_penalty=frequency_penalty, repeat_penalty=repeat_penalty, @@ -2555,7 +2505,7 @@ def chatml_function_calling( min_p=min_p, typical_p=typical_p, stream=stream, - stop=["<|im_end|>"], + stop=[""], logprobs=top_logprobs if logprobs else None, max_tokens=None, presence_penalty=presence_penalty, @@ -2567,15 +2517,12 @@ def chatml_function_calling( mirostat_eta=mirostat_eta, model=model, logits_processor=logits_processor, - grammar=llama_grammar.LlamaGrammar.from_string( - follow_up_gbnf_tool_grammar, verbose=llama.verbose - ), ), stream=stream, ) # One or more function calls - tool_name = text[len("functions.") :] + tool_name = text[len("functions.") :].replace(":", "") tool = next((tool for tool in tools if tool["function"]["name"] == tool_name), None) if not stream: completions: List[llama_types.CreateCompletionResponse] = [] @@ -2621,7 +2568,6 @@ def chatml_function_calling( completions_tool_name.append(tool_name) prompt += completion_or_chunks["choices"][0]["text"] prompt += "\n" - response = llama.create_completion( prompt=prompt, temperature=temperature, @@ -2644,10 +2590,14 @@ def chatml_function_calling( grammar=llama_grammar.LlamaGrammar.from_string( follow_up_gbnf_tool_grammar, verbose=llama.verbose ), + ) + response = cast(llama_types.CreateCompletionResponse, response) - tool_name = response["choices"][0]["text"][len("functions.") :] + if response["choices"][0]["text"] == "": + break + tool_name = response["choices"][0]["text"][len("functions.") :].replace(":", "") tool = next( (tool for tool in tools if tool["function"]["name"] == tool_name), None ) @@ -2710,3 +2660,196 @@ def chatml_function_calling( } raise ValueError("Automatic streaming tool choice is not supported") + + +@register_chat_completion_handler("chatml-function-calling") +def chatml_function_calling( + llama: llama.Llama, + messages: List[llama_types.ChatCompletionRequestMessage], + functions: Optional[List[llama_types.ChatCompletionFunction]] = None, + function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None, + tools: Optional[List[llama_types.ChatCompletionTool]] = None, + tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None, + temperature: float = 0.2, + top_p: float = 0.95, + top_k: int = 40, + min_p: float = 0.05, + typical_p: float = 1.0, + stream: bool = False, + stop: Optional[Union[str, List[str]]] = [], + response_format: Optional[llama_types.ChatCompletionRequestResponseFormat] = None, + max_tokens: Optional[int] = None, + presence_penalty: float = 0.0, + frequency_penalty: float = 0.0, + repeat_penalty: float = 1.1, + tfs_z: float = 1.0, + mirostat_mode: int = 0, + mirostat_tau: float = 5.0, + 
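The automatic tool-choice path above decodes in two grammar-constrained stages: an initial grammar forces the model to open with either "message:" or one of the "functions.<name>:" prefixes, and a follow-up grammar, after each completed call, permits either another call or the end token. A minimal sketch of what these grammars expand to, with two hypothetical tool names (the end token shown is ChatML's; the refactor parameterizes it per handler):

    tools = [
        {"type": "function", "function": {"name": "get_weather", "parameters": {}}},
        {"type": "function", "function": {"name": "search_docs", "parameters": {}}},
    ]

    function_names = " | ".join(
        f'"functions.{tool["function"]["name"]}:"' for tool in tools
    )

    # Stage 1: the first completion must open with a message or a call.
    initial_gbnf_tool_grammar = (
        'root ::= functions | "message:"\n'
        f'functions ::= {function_names}\n'
    )

    # Stage 2: after each call, permit another call or the end token.
    follow_up_gbnf_tool_grammar = (
        'root ::= functions | "<|im_end|>"\n'
        f'functions ::= {function_names}\n'
    )

    print(initial_gbnf_tool_grammar)
    # root ::= functions | "message:"
    # functions ::= "functions.get_weather:" | "functions.search_docs:"

Once a functions.<name>: prefix has been sampled, the argument body itself is constrained by a second grammar compiled from that tool's JSON schema via LlamaGrammar.from_json_schema, falling back to the generic JSON grammar when the schema fails to compile.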
mirostat_eta: float = 0.1, + model: Optional[str] = None, + logits_processor: Optional[llama.LogitsProcessorList] = None, + grammar: Optional[llama.LlamaGrammar] = None, + logprobs: Optional[bool] = None, + top_logprobs: Optional[int] = None, + **kwargs, # type: ignore +) -> Union[ + llama_types.CreateChatCompletionResponse, + Iterator[llama_types.CreateChatCompletionStreamResponse], +]: + function_calling_template = ( + "{% for message in messages %}" + "<|im_start|>{{ message.role }}\n" + # System message + "{% if message.role == 'system' %}" + "{{ message.content }}" + "{% if tool_calls %}" + "\n\nYou have access to the following functions:\n" + "{% for tool in tools %}" + "\nfunctions.{{ tool.function.name }}:\n" + "{{ tool.function.parameters | tojson }}" + "\n{% endfor %}" + "\n\nYou can respond to users messages with either a single message or one or more function calls." + "\n\nTo respond with a message begin the message with 'message:', use the following format:" + "\n\nmessage:" + "\n" + "\n\nTo respond with one or more function calls begin the message with 'functions.:', use the following format:" + "\n\nfunctions.:" + '\n{ "arg1": "value1", "arg2": "value2" };' + "\nfunctions.:" + '\n{ "arg1": "value1", "arg2": "value2" }' + "\n\nWhen you are done with the function calls, end the message with ." + "{% endif %}" + "<|im_end|>\n" + "{% endif %}" + # User message + "{% if message.role == 'user' %}" + "{{ message.content }}" + "<|im_end|>\n" + "{% endif %}" + # Assistant message + "{% if message.role == 'assistant' %}" + ## Reglar message + "{% if message.content and message.content | length > 0 %}" + "{% if tool_calls %}" + "message:\n" + "{% endif %}" + "{{ message.content }}" + "<|im_end|>\n" + "{% endif %}" + ## Function calls + "{% if tool_calls %}" + "{% for tool_call in message.tool_calls %}" + "functions.{{ tool_call.function.name }}:\n" + "{{ (tool_call.arguments | default('{}') | tojson) }}" + "{% if not loop.last %};{% endif %}" # Semicolon separator if not the last function call + "{% endfor %}" + "<|im_end|>\n" + "{% endif %}" + "{% endif %}" + # Tool message (treated as Assistant response) + "{% if message.role == 'tool' %}" + "ASSISTANT:\n" + "Function response: {{ message.content | default('No response available') }}" + "<|im_end|>\n" + "{% endif %}" + "{% endfor %}" + "{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}" + ) + return base_function_calling(end_token="<|im_end|>", + **locals()) + +@register_chat_completion_handler("vicuna-function-calling") +def vicuna_function_calling( + llama: llama.Llama, + messages: List[llama_types.ChatCompletionRequestMessage], + functions: Optional[List[llama_types.ChatCompletionFunction]] = None, + function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None, + tools: Optional[List[llama_types.ChatCompletionTool]] = None, + tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None, + temperature: float = 0.2, + top_p: float = 0.95, + top_k: int = 40, + min_p: float = 0.05, + typical_p: float = 1.0, + stream: bool = False, + stop: Optional[Union[str, List[str]]] = [], + response_format: Optional[llama_types.ChatCompletionRequestResponseFormat] = None, + max_tokens: Optional[int] = None, + presence_penalty: float = 0.0, + frequency_penalty: float = 0.0, + repeat_penalty: float = 1.1, + tfs_z: float = 1.0, + mirostat_mode: int = 0, + mirostat_tau: float = 5.0, + mirostat_eta: float = 0.1, + model: Optional[str] = None, + logits_processor: Optional[llama.LogitsProcessorList] = None, + 
grammar: Optional[llama.LlamaGrammar] = None, + logprobs: Optional[bool] = None, + top_logprobs: Optional[int] = None, + **kwargs, # type: ignore +) -> Union[ + llama_types.CreateChatCompletionResponse, + Iterator[llama_types.CreateChatCompletionStreamResponse], +]: + function_calling_template = ( + "{% for message in messages %}" + "{{ message.role.upper() }}\n" # Vicuna uses upper case for roles + # System message + "{% if message.role == 'system' %}" + "{{ message.content }}" + "{% if tool_calls %}" + "\n\nYou have access to the following functions:\n" + "{% for tool in tools %}" + "\nfunctions.{{ tool.function.name }}:\n" + "{{ tool.function.parameters | tojson }}" + "\n{% endfor %}" + "\n\nYou can respond to users messages with either a single message or multiple function calls." + "\n\nTo respond with a message begin the message with 'message:', use the following format:" + "\n\nmessage:" + "\n" + "\n\nTo respond with one or more function calls begin the message with 'functions.:', use the following format:" + "\n\nfunctions.:" + '\n{ "arg1": "value1", "arg2": "value2" };' + "\nfunctions.:" + '\n{ "arg1": "value3", "arg2": "value4" }' + "\n\nWhen you are done with the function calls, end the message with ." + "{% endif %}" + "\n" + "{% endif %}" + # User message + "{% if message.role == 'user' %}" + "{{ message.content }}" + "\n" + "{% endif %}" + # Assistant message + "{% if message.role == 'assistant' %}" + ## Regular message + "{% if message.content and message.content | length > 0 %}" + "{% if tool_calls %}" + "message:\n" + "{% endif %}" + "{{ message.content }}" + "\n" + "{% endif %}" + ## Function calls + "{% if tool_calls %}" + "{% for tool_call in message.tool_calls %}" + "functions.{{ tool_call.function.name }}:\n" + "{{ (tool_call.arguments | default('{}') | tojson) }}" + "{% if not loop.last %};{% endif %}" # Semicolon separator if not the last function call + "{% endfor %}" + "\n" + "{% endif %}" + "{% endif %}" + # Tool message (treated as Assistant response) + "{% if message.role == 'tool' %}" + "ASSISTANT:\n" + "Function response: {{ message.content | default('No response available') }}" + "\n" + "{% endif %}" + "{% endfor %}" + "{% if add_generation_prompt %}ASSISTANT\n{% endif %}" # Vicuna adds the role for prompt continuation + ) + return base_function_calling(end_token="", + **locals()) \ No newline at end of file diff --git a/llama_cpp/llama_grammar.py b/llama_cpp/llama_grammar.py index 9cc48a93b..8c0f8aa09 100644 --- a/llama_cpp/llama_grammar.py +++ b/llama_cpp/llama_grammar.py @@ -5,11 +5,12 @@ import sys from ctypes import * # type: ignore from enum import Enum -from itertools import islice +from itertools import islice, groupby from typing import ( Any, Callable, Dict, + Set, Generic, List, Optional, @@ -1391,139 +1392,552 @@ def print_grammar(file: TextIO, state: parse_state) -> None: # whitespace. Also maybe improves generation quality? SPACE_RULE = '" "?' -PRIMITIVE_RULES = { - "boolean": '("true" | "false") space', - "number": '("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space', - "integer": '("-"? 
([0-9] | [1-9] [0-9]*)) space', - "string": r""" "\"" ( - [^"\\] | - "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) - )* "\"" space """, - "null": '"null" space', -} INVALID_RULE_CHARS_RE = re.compile(r"[^a-zA-Z0-9-]+") GRAMMAR_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"]') GRAMMAR_LITERAL_ESCAPES = {"\r": "\\r", "\n": "\\n", '"': '\\"'} +# whitespace is constrained to a single space char to prevent model "running away" in +# whitespace. Also maybe improves generation quality? +SPACE_RULE = '" "?' + + +def _build_repetition(item_rule, min_items, max_items, separator_rule=None, item_rule_is_literal=False): + if not separator_rule: + if min_items == 0 and max_items == 1: + return f'{item_rule}?' + elif min_items == 1 and max_items is None: + return f'{item_rule}+' + + result = '' + + if min_items > 0: + if item_rule_is_literal and separator_rule is None: + result = '"' + (item_rule[1:-1] * min_items) + '"' + else: + result = (f' {separator_rule} ' if separator_rule else ' ').join([item_rule] * min_items) + + def opt_repetitions(up_to_n, prefix_with_sep=False): + ''' + - n=4, no sep: '(a (a (a (a)?)?)?)?' + - n=4, sep=',', prefix: '("," a ("," a ("," a ("," a)?)?)?)?' + - n=4, sep=',', no prefix: '(a ("," a ("," a ("," a)?)?)?)?' + ''' + + content = f'{separator_rule} {item_rule}' if prefix_with_sep and separator_rule else item_rule + if up_to_n == 0: + return '' + elif up_to_n == 1: + return f'({content})?' + elif separator_rule and not prefix_with_sep: + return f'({content} {opt_repetitions(up_to_n - 1, prefix_with_sep=True)})?' + else: + return (f'({content} ' * up_to_n).rstrip() + (')?' * up_to_n) + + if min_items > 0 and max_items != min_items: + result += ' ' + + if max_items is not None: + result += opt_repetitions(max_items - min_items, prefix_with_sep=min_items > 0) + else: + item_operator = f'({separator_rule + " " if separator_rule else ""}{item_rule})' + + if min_items == 0 and separator_rule: + result = f'({item_rule} {item_operator}*)?' + else: + result += f'{item_operator}*' + + return result + + + +class BuiltinRule: + def __init__(self, content: str, deps: list = None): + self.content = content + self.deps = deps or [] + +_up_to_15_digits = _build_repetition('[0-9]', 0, 15) + +PRIMITIVE_RULES = { + 'boolean' : BuiltinRule('("true" | "false") space', []), + 'decimal-part' : BuiltinRule('[0-9] ' + _up_to_15_digits, []), + 'integral-part': BuiltinRule('[0-9] | [1-9] ' + _up_to_15_digits, []), + 'number' : BuiltinRule('("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space', ['integral-part', 'decimal-part']), + 'integer' : BuiltinRule('("-"? integral-part) space', ['integral-part']), + 'value' : BuiltinRule('object | array | string | number | boolean | null', ['object', 'array', 'string', 'number', 'boolean', 'null']), + 'object' : BuiltinRule('"{" space ( string ":" space value ("," space string ":" space value)* )? "}" space', ['string', 'value']), + 'array' : BuiltinRule('"[" space ( value ("," space value)* )? 
"]" space', ['value']), + 'uuid' : BuiltinRule(r'"\"" ' + ' "-" '.join('[0-9a-fA-F]' * n for n in [8, 4, 4, 4, 12]) + r' "\"" space', []), + 'char' : BuiltinRule(r'[^"\\] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])', []), + 'string' : BuiltinRule(r'"\"" char* "\"" space', ['char']), + 'null' : BuiltinRule('"null" space', []), +} + +# TODO: support "uri", "email" string formats +STRING_FORMAT_RULES = { + 'date' : BuiltinRule('[0-9] [0-9] [0-9] [0-9] "-" ( "0" [1-9] | "1" [0-2] ) "-" ( \"0\" [1-9] | [1-2] [0-9] | "3" [0-1] )', []), + 'time' : BuiltinRule('([01] [0-9] | "2" [0-3]) ":" [0-5] [0-9] ":" [0-5] [0-9] ( "." [0-9] [0-9] [0-9] )? ( "Z" | ( "+" | "-" ) ( [01] [0-9] | "2" [0-3] ) ":" [0-5] [0-9] )', []), + 'date-time' : BuiltinRule('date "T" time', ['date', 'time']), + 'date-string' : BuiltinRule('"\\"" date "\\"" space', ['date']), + 'time-string' : BuiltinRule('"\\"" time "\\"" space', ['time']), + 'date-time-string': BuiltinRule('"\\"" date-time "\\"" space', ['date-time']), +} + +DOTALL = '[\\U00000000-\\U0010FFFF]' +DOT = '[^\\x0A\\x0D]' + +RESERVED_NAMES = set(["root", "dot", *PRIMITIVE_RULES.keys(), *STRING_FORMAT_RULES.keys()]) + + +NON_LITERAL_SET = set('|.()[]{}*+?') +ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = set('[]()|{}*+?') + + + class SchemaConverter: def __init__(self, prop_order): self._prop_order = prop_order self._rules = {"space": SPACE_RULE} self._defs: Dict[str, Any] = {} + self._refs = {} + self._refs_being_resolved = set() - def _format_literal(self, literal: str): - escaped: str = GRAMMAR_LITERAL_ESCAPE_RE.sub( - lambda m: GRAMMAR_LITERAL_ESCAPES.get(m.group(0)), json.dumps(literal) + def _format_literal(self, literal): + escaped = GRAMMAR_LITERAL_ESCAPE_RE.sub( + lambda m: GRAMMAR_LITERAL_ESCAPES.get(m.group(0)), literal ) return f'"{escaped}"' - def _add_rule(self, name: str, rule: str): - esc_name = INVALID_RULE_CHARS_RE.sub("-", name) + def not_literal(self, literal: str, dotall: bool = True, maybe_escaped_underscores = False) -> str: + ''' + not_literal('a') -> '[^a]' + not_literal('abc') -> '([^a] | "a" ([^b] | "b" ([^c])?)?)?' + ''' + assert len(literal) > 0, 'Empty literal not supported' + def recurse(i: int): + c = literal[i] + if maybe_escaped_underscores and c == '_': + yield f'[^{c}\\\\]' + yield ' | ' + yield f'"\\\\"? "{c}"' + else: + yield f'[^{c}]' + if i < len(literal) - 1: + yield ' | ' + yield self._format_literal(c) + yield ' (' + yield from recurse(i + 1) + yield ')?' + + return ''.join(('(', *recurse(0), ')')) + + def _add_rule(self, name, rule): + esc_name = INVALID_RULE_CHARS_RE.sub('-', name) if esc_name not in self._rules or self._rules[esc_name] == rule: key = esc_name else: i = 0 - while f"{esc_name}{i}" in self._rules: + while f'{esc_name}{i}' in self._rules and self._rules[f'{esc_name}{i}'] != rule: i += 1 - key = f"{esc_name}{i}" + key = f'{esc_name}{i}' self._rules[key] = rule return key - def visit(self, schema: Dict[str, Any], name: str) -> str: - rule_name = name or "root" - - if "$defs" in schema: - # add defs to self._defs for later inlining - for def_name, def_schema in schema["$defs"].items(): - self._defs[def_name] = def_schema + def resolve_refs(self, schema: dict, url: str): + ''' + Resolves all $ref fields in the given schema, fetching any remote schemas, + replacing $ref with absolute reference URL and populating self._refs with the + respective referenced (sub)schema dictionaries. 
+ ''' + def visit(n: dict): + if isinstance(n, list): + return [visit(x) for x in n] + elif isinstance(n, dict): + ref = n.get('$ref') + if ref is not None and ref not in self._refs: + if ref.startswith('https://'): + assert self._allow_fetch, 'Fetching remote schemas is not allowed (use --allow-fetch for force)' + import requests + + frag_split = ref.split('#') + base_url = frag_split[0] + + target = self._refs.get(base_url) + if target is None: + target = self.resolve_refs(requests.get(ref).json(), base_url) + self._refs[base_url] = target + + if len(frag_split) == 1 or frag_split[-1] == '': + return target + elif ref.startswith('#/'): + target = schema + ref = f'{url}{ref}' + n['$ref'] = ref + else: + raise ValueError(f'Unsupported ref {ref}') + + for sel in ref.split('#')[-1].split('/')[1:]: + assert target is not None and sel in target, f'Error resolving ref {ref}: {sel} not in {target}' + target = target[sel] + + self._refs[ref] = target + else: + for v in n.values(): + visit(v) + + return n + return visit(schema) + + def _generate_union_rule(self, name, alt_schemas): + return ' | '.join(( + self.visit(alt_schema, f'{name}{"-" if name else "alternative-"}{i}') + for i, alt_schema in enumerate(alt_schemas) + )) + + def _visit_pattern(self, pattern, name): + ''' + Transforms a regular expression pattern into a GBNF rule. + + Input: https://json-schema.org/understanding-json-schema/reference/regular_expressions + Output: https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md + + Unsupported features: negative/positive lookaheads, greedy/non-greedy modifiers. + + Mostly a 1:1 translation, except for {x} / {x,} / {x,y} quantifiers for which + we define sub-rules to keep the output lean. + ''' + + assert pattern.startswith('^') and pattern.endswith('$'), 'Pattern must start with "^" and end with "$"' + pattern = pattern[1:-1] + sub_rule_ids = {} + + i = 0 + length = len(pattern) + + def to_rule(s: Tuple[str, bool]) -> str: + (txt, is_literal) = s + return "\"" + txt + "\"" if is_literal else txt + + def transform() -> Tuple[str, bool]: + ''' + Parse a unit at index i (advancing it), and return its string representation + whether it's a literal. + ''' + nonlocal i + nonlocal pattern + nonlocal sub_rule_ids + + start = i + # For each component of this sequence, store its string representation and whether it's a literal. + # We only need a flat structure here to apply repetition operators to the last item, and + # to merge literals at the and (we're parsing grouped ( sequences ) recursively and don't treat '|' specially + # (GBNF's syntax is luckily very close to regular expressions!) + seq: list[Tuple[str, bool]] = [] + + def get_dot(): + if self._dotall: + rule = DOTALL + else: + # Accept any character... 
except \n and \r line break chars (\x0A and \xOD) + rule = DOT + return self._add_rule(f'dot', rule) + + def join_seq(): + nonlocal seq + ret = [] + for is_literal, g in groupby(seq, lambda x: x[1]): + if is_literal: + ret.append((''.join(x[0] for x in g), True)) + else: + ret.extend(g) + if len(ret) == 1: + return ret[0] + return (' '.join(to_rule(x) for x in seq), False) + + while i < length: + c = pattern[i] + if c == '.': + seq.append((get_dot(), False)) + i += 1 + elif c == '(': + i += 1 + if i < length: + assert pattern[i] != '?', f'Unsupported pattern syntax "{pattern[i]}" at index {i} of /{pattern}/' + seq.append((f'({to_rule(transform())})', False)) + elif c == ')': + i += 1 + assert start > 0 and pattern[start-1] == '(', f'Unbalanced parentheses; start = {start}, i = {i}, pattern = {pattern}' + return join_seq() + elif c == '[': + square_brackets = c + i += 1 + while i < length and pattern[i] != ']': + if pattern[i] == '\\': + square_brackets += pattern[i:i+2] + i += 2 + else: + square_brackets += pattern[i] + i += 1 + assert i < length, f'Unbalanced square brackets; start = {start}, i = {i}, pattern = {pattern}' + square_brackets += ']' + i += 1 + seq.append((square_brackets, False)) + elif c == '|': + seq.append(('|', False)) + i += 1 + elif c in ('*', '+', '?'): + seq[-1] = (to_rule(seq[-1]) + c, False) + i += 1 + elif c == '{': + curly_brackets = c + i += 1 + while i < length and pattern[i] != '}': + curly_brackets += pattern[i] + i += 1 + assert i < length, f'Unbalanced curly brackets; start = {start}, i = {i}, pattern = {pattern}' + curly_brackets += '}' + i += 1 + nums = [s.strip() for s in curly_brackets[1:-1].split(',')] + min_times = 0 + max_times = None + try: + if len(nums) == 1: + min_times = int(nums[0]) + max_times = min_times + else: + assert len(nums) == 2 + min_times = int(nums[0]) if nums[0] else 0 + max_times = int(nums[1]) if nums[1] else None + except ValueError: + raise ValueError(f'Invalid quantifier {curly_brackets} in /{pattern}/') + + (sub, sub_is_literal) = seq[-1] + + if not sub_is_literal: + id = sub_rule_ids.get(sub) + if id is None: + id = self._add_rule(f'{name}-{len(sub_rule_ids) + 1}', sub) + sub_rule_ids[sub] = id + sub = id + + seq[-1] = (_build_repetition(f'"{sub}"' if sub_is_literal else sub, min_times, max_times, item_rule_is_literal=sub_is_literal), False) + else: + literal = '' + while i < length: + if pattern[i] == '\\' and i < length - 1: + next = pattern[i + 1] + if next in ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS: + i += 1 + literal += pattern[i] + i += 1 + else: + literal += pattern[i:i+2] + i += 2 + elif pattern[i] == '"' and not self._raw_pattern: + literal += '\\"' + i += 1 + elif pattern[i] not in NON_LITERAL_SET and \ + (i == length - 1 or literal == '' or pattern[i+1] == '.' 
or pattern[i+1] not in NON_LITERAL_SET): + literal += pattern[i] + i += 1 + else: + break + if literal: + seq.append((literal, True)) + + return join_seq() + + return self._add_rule( + name, + to_rule(transform()) if self._raw_pattern \ + else "\"\\\"\" " + to_rule(transform()) + " \"\\\"\" space") + + + def _resolve_ref(self, ref): + ref_name = ref.split('/')[-1] + if ref_name not in self._rules and ref not in self._refs_being_resolved: + self._refs_being_resolved.add(ref) + resolved = self._refs[ref] + ref_name = self.visit(resolved, ref_name) + self._refs_being_resolved.remove(ref) + return ref_name + + def _generate_constant_rule(self, value): + return self._format_literal(json.dumps(value)) + + def visit(self, schema, name): + schema_type = schema.get('type') + schema_format = schema.get('format') + rule_name = name + '-' if name in RESERVED_NAMES else name or 'root' + + if (ref := schema.get('$ref')) is not None: + return self._add_rule(rule_name, self._resolve_ref(ref)) + + elif 'oneOf' in schema or 'anyOf' in schema: + return self._add_rule(rule_name, self._generate_union_rule(name, schema.get('oneOf') or schema['anyOf'])) + + elif isinstance(schema_type, list): + return self._add_rule(rule_name, self._generate_union_rule(name, [{'type': t} for t in schema_type])) + + elif 'const' in schema: + return self._add_rule(rule_name, self._generate_constant_rule(schema['const'])) + + elif 'enum' in schema: + rule = ' | '.join((self._generate_constant_rule(v) for v in schema['enum'])) + return self._add_rule(rule_name, rule) - if "oneOf" in schema or "anyOf" in schema: - rule = " | ".join( - ( - self.visit(alt_schema, f'{name}{"-" if name else ""}{i}') - for i, alt_schema in enumerate( - schema.get("oneOf") or schema["anyOf"] - ) - ) + elif schema_type in (None, 'object') and \ + ('properties' in schema or \ + ('additionalProperties' in schema and schema['additionalProperties'] is not True)): + required = set(schema.get('required', [])) + properties = list(schema.get('properties', {}).items()) + return self._add_rule(rule_name, self._build_object_rule(properties, required, name, schema.get('additionalProperties'))) + + elif schema_type in (None, 'object') and 'allOf' in schema: + required = set() + properties = [] + hybrid_name = name + def add_component(comp_schema, is_required): + if (ref := comp_schema.get('$ref')) is not None: + comp_schema = self._refs[ref] + + if 'properties' in comp_schema: + for prop_name, prop_schema in comp_schema['properties'].items(): + properties.append((prop_name, prop_schema)) + if is_required: + required.add(prop_name) + + for t in schema['allOf']: + if 'anyOf' in t: + for tt in t['anyOf']: + add_component(tt, is_required=False) + else: + add_component(t, is_required=True) + + return self._add_rule(rule_name, self._build_object_rule(properties, required, hybrid_name, additional_properties=[])) + + elif schema_type in (None, 'array') and ('items' in schema or 'prefixItems' in schema): + items = schema.get('items') or schema['prefixItems'] + if isinstance(items, list): + return self._add_rule( + rule_name, + '"[" space ' + + ' "," space '.join( + self.visit(item, f'{name}{"-" if name else ""}tuple-{i}') + for i, item in enumerate(items)) + + ' "]" space') + else: + item_rule_name = self.visit(items, f'{name}{"-" if name else ""}item') + min_items = schema.get("minItems", 0) + max_items = schema.get("maxItems") + return self._add_rule(rule_name, '"[" space ' + _build_repetition(item_rule_name, min_items, max_items, separator_rule='"," space') + ' "]" space') + 
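The minItems/maxItems branch above delegates the bounds to _build_repetition with a comma separator. A quick runnable check through the unchanged module-level entry point (output shape hand-expanded from the definitions above):

    import json
    from llama_cpp.llama_grammar import json_schema_to_gbnf

    schema = {"type": "array", "items": {"type": "integer"},
              "minItems": 1, "maxItems": 3}
    print(json_schema_to_gbnf(json.dumps(schema)))
    # Root rule:
    #   root ::= "[" space integer ("," space integer ("," space integer)?)? "]" space
    # plus the integer, integral-part, and space rules pulled in by _add_primitive.

The one mandatory item is emitted literally; each further item becomes a nested optional ("," space integer) group, so the grammar accepts exactly one to three elements.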
+ elif schema_type in (None, 'string') and 'pattern' in schema: + return self._visit_pattern(schema['pattern'], rule_name) + + elif schema_type in (None, 'string') and re.match(r'^uuid[1-5]?$', schema_format or ''): + return self._add_primitive( + 'root' if rule_name == 'root' else schema_format, + PRIMITIVE_RULES['uuid'] ) - return self._add_rule(rule_name, rule) - elif "const" in schema: - return self._add_rule(rule_name, self._format_literal(schema["const"])) + elif schema_type in (None, 'string') and f'{schema_format}-string' in STRING_FORMAT_RULES: + prim_name = f'{schema_format}-string' + return self._add_rule(rule_name, self._add_primitive(prim_name, STRING_FORMAT_RULES[prim_name])) - elif "enum" in schema: - rule = " | ".join((self._format_literal(v) for v in schema["enum"])) - return self._add_rule(rule_name, rule) + elif schema_type == 'string' and ('minLength' in schema or 'maxLength' in schema): + char_rule = self._add_primitive('char', PRIMITIVE_RULES['char']) + min_len = schema.get('minLength', 0) + max_len = schema.get('maxLength') - elif "$ref" in schema: - ref = schema["$ref"] - assert ref.startswith("#/$defs/"), f"Unrecognized schema: {schema}" - # inline $defs - def_name = ref[len("#/$defs/") :] - def_schema = self._defs[def_name] - return self.visit(def_schema, f'{name}{"-" if name else ""}{def_name}') - - - schema_type: Optional[str] = schema.get("type") # type: ignore - assert isinstance(schema_type, str), f"Unrecognized schema: {schema}" - - if schema_type == "object" and "properties" in schema: - # TODO: `required` keyword - if self._prop_order: - prop_order = self._prop_order - prop_pairs = sorted( - schema["properties"].items(), - # sort by position in prop_order (if specified) then by key - key=lambda kv: (prop_order.get(kv[0], len(prop_order)), kv[0]), - ) - else: - prop_pairs = schema["properties"].items() + return self._add_rule(rule_name, r'"\"" ' + _build_repetition(char_rule, min_len, max_len) + r' "\"" space') - rule = '"{" space' - for i, (prop_name, prop_schema) in enumerate(prop_pairs): - prop_rule_name = self.visit( - prop_schema, f'{name}{"-" if name else ""}{prop_name}' - ) - if i > 0: - rule += ' "," space' - rule += rf' {self._format_literal(prop_name)} space ":" space {prop_rule_name}' - rule += ' "}" space' - - return self._add_rule(rule_name, rule) + elif (schema_type == 'object') or (len(schema) == 0): + return self._add_rule(rule_name, self._add_primitive('object', PRIMITIVE_RULES['object'])) - elif schema_type == "array" and "items" in schema: - # TODO `prefixItems` keyword - item_rule_name = self.visit( - schema["items"], f'{name}{"-" if name else ""}item' + else: + assert schema_type in PRIMITIVE_RULES, f'Unrecognized schema: {schema}' + # TODO: support minimum, maximum, exclusiveMinimum, exclusiveMaximum at least for zero + return self._add_primitive('root' if rule_name == 'root' else schema_type, PRIMITIVE_RULES[schema_type]) + + def _add_primitive(self, name: str, rule: BuiltinRule): + n = self._add_rule(name, rule.content) + + for dep in rule.deps: + dep_rule = PRIMITIVE_RULES.get(dep) or STRING_FORMAT_RULES.get(dep) + assert dep_rule, f'Rule {dep} not known' + if dep not in self._rules: + self._add_primitive(dep, dep_rule) + return n + + def _build_object_rule(self, properties: List[Tuple[str, Any]], required: Set[str], name: str, additional_properties: Union[bool, Any]): + prop_order = self._prop_order + # sort by position in prop_order (if specified) then by original order + sorted_props = [kv[0] for _, kv in 
sorted(enumerate(properties), key=lambda ikv: (prop_order.get(ikv[1][0], len(prop_order)), ikv[0]))] + + prop_kv_rule_names = {} + for prop_name, prop_schema in properties: + prop_rule_name = self.visit(prop_schema, f'{name}{"-" if name else ""}{prop_name}') + prop_kv_rule_names[prop_name] = self._add_rule( + f'{name}{"-" if name else ""}{prop_name}-kv', + fr'{self._format_literal(json.dumps(prop_name))} space ":" space {prop_rule_name}' ) - list_item_operator = f'("," space {item_rule_name})' - successive_items = "" - min_items = schema.get("minItems", 0) - if min_items > 0: - first_item = f"({item_rule_name})" - successive_items = list_item_operator * (min_items - 1) - min_items -= 1 - else: - first_item = f"({item_rule_name})?" - max_items = schema.get("maxItems") - if max_items is not None and max_items > min_items: - successive_items += (list_item_operator + "?") * (max_items - min_items - 1) - else: - successive_items += list_item_operator + "*" - rule = f'"[" space {first_item} {successive_items} "]" space' - return self._add_rule(rule_name, rule) + required_props = [k for k in sorted_props if k in required] + optional_props = [k for k in sorted_props if k not in required] + + if additional_properties == True or isinstance(additional_properties, dict): + sub_name = f'{name}{"-" if name else ""}additional' + value_rule = self.visit({} if additional_properties == True else additional_properties, f'{sub_name}-value') + prop_kv_rule_names["*"] = self._add_rule( + f'{sub_name}-kv', + self._add_primitive('string', PRIMITIVE_RULES['string']) + f' ":" space {value_rule}' + ) + optional_props.append("*") + + rule = '"{" space ' + rule += ' "," space '.join(prop_kv_rule_names[k] for k in required_props) + + if optional_props: + rule += ' (' + if required_props: + rule += ' "," space ( ' + + def get_recursive_refs(ks, first_is_optional): + [k, *rest] = ks + kv_rule_name = prop_kv_rule_names[k] + if k == '*': + res = self._add_rule( + f'{name}{"-" if name else ""}additional-kvs', + f'{kv_rule_name} ( "," space ' + kv_rule_name + ' )*' + ) + elif first_is_optional: + res = f'( "," space {kv_rule_name} )?' + else: + res = kv_rule_name + if len(rest) > 0: + res += ' ' + self._add_rule( + f'{name}{"-" if name else ""}{k}-rest', + get_recursive_refs(rest, first_is_optional=True) + ) + return res - else: - assert schema_type in PRIMITIVE_RULES, f"Unrecognized schema: {schema}" - return self._add_rule( - "root" if rule_name == "root" else schema_type, - PRIMITIVE_RULES[schema_type], + rule += ' | '.join( + get_recursive_refs(optional_props[i:], first_is_optional=False) + for i in range(len(optional_props)) ) + if required_props: + rule += ' )' + rule += ' )?' 
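_build_object_rule thus emits required properties as a fixed comma-separated sequence and chains the optional ones through nested ( "," space ... )? alternatives, so any subset of optional keys still parses as valid JSON. A hand-traced sketch for one required and one optional property (whitespace approximate):

    import json
    from llama_cpp.llama_grammar import json_schema_to_gbnf

    schema = {
        "type": "object",
        "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
        "required": ["name"],
    }
    print(json_schema_to_gbnf(json.dumps(schema)))
    # root    ::= "{" space name-kv ( "," space ( age-kv ) )? "}" space
    # name-kv ::= "\"name\"" space ":" space string
    # age-kv  ::= "\"age\"" space ":" space integer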
+ + rule += ' "}" space' + + return rule def format_grammar(self): - return "\n".join((f"{name} ::= {rule}" for name, rule in self._rules.items())) + return '\n'.join( + f'{name} ::= {rule}' + for name, rule in sorted(self._rules.items(), key=lambda kv: kv[0]) + ) def json_schema_to_gbnf(schema: str, prop_order: Optional[List[str]] = None): From 1b696a8b109880a095aa31e4a53630a79daa4136 Mon Sep 17 00:00:00 2001 From: lucca Date: Wed, 17 Apr 2024 16:08:44 -0300 Subject: [PATCH 02/11] debug --- .gitignore | 2 +- llama_cpp/llama_chat_format.py | 437 ++++++++++++++++++++++++++++++++- 2 files changed, 437 insertions(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index 9d68dbcd9..114cb8346 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,5 @@ *.local - +test.ipynb .python-version .vscode/ diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py index 70efef866..64e1f92c4 100644 --- a/llama_cpp/llama_chat_format.py +++ b/llama_cpp/llama_chat_format.py @@ -2852,4 +2852,439 @@ def vicuna_function_calling( "{% if add_generation_prompt %}ASSISTANT\n{% endif %}" # Vicuna adds the role for prompt continuation ) return base_function_calling(end_token="", - **locals()) \ No newline at end of file + **locals()) + +@register_chat_completion_handler("mixtral-function-calling") +def mixtral_function_calling( + llama: llama.Llama, + messages: List[llama_types.ChatCompletionRequestMessage], + functions: Optional[List[llama_types.ChatCompletionFunction]] = None, + function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None, + tools: Optional[List[llama_types.ChatCompletionTool]] = None, + tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None, + temperature: float = 0.2, + top_p: float = 0.95, + top_k: int = 40, + min_p: float = 0.05, + typical_p: float = 1.0, + stream: bool = False, + stop: Optional[Union[str, List[str]]] = [], + response_format: Optional[llama_types.ChatCompletionRequestResponseFormat] = None, + max_tokens: Optional[int] = None, + presence_penalty: float = 0.0, + frequency_penalty: float = 0.0, + repeat_penalty: float = 1.1, + tfs_z: float = 1.0, + mirostat_mode: int = 0, + mirostat_tau: float = 5.0, + mirostat_eta: float = 0.1, + model: Optional[str] = None, + logits_processor: Optional[llama.LogitsProcessorList] = None, + grammar: Optional[llama.LlamaGrammar] = None, + logprobs: Optional[bool] = None, + top_logprobs: Optional[int] = None, + **kwargs, # type: ignore +) -> Union[ + llama_types.CreateChatCompletionResponse, + Iterator[llama_types.CreateChatCompletionStreamResponse], +]: + end_token = "" + function_calling_template = ( + "{% for message in messages %}\n" + "{% if message.role == 'user' %}\n" + "[INST] \n" + "{{ message.content }}\n" + "[/INST]\n" + "{% elif message.role == 'assistant' %}\n" + "[TOOL_CALLS] \n" + "[\n" + " {% for tool_call in message.tool_calls %}\n" + " {\n" + " \"name\": \"{{ tool_call.function.name }}\",\n" + " \"arguments\": {\n" + " {% for arg_key, arg_val in tool_call.arguments.items() %}\n" + " \"{{ arg_key }}\": \"{{ arg_val }}\"{% if not loop.last %},{% endif %}\n" + " {% endfor %}\n" + " },\n" + " \"id\": \"{{ tool_call.id }}\"\n" + " }{% if not loop.last %},{% endif %}\n" + " {% endfor %}\n" + "]\n\n" + "{% elif message.role == 'tool' %}\n" + "[TOOL_RESULTS] \n" + "{\n" + " \"call_id\": \"{{ message.tool_call_id }}\",\n" + " \"content\": {{ message.content }}\n" + "}\n" + "[/TOOL_RESULTS] \n" + "The current temperature in {{ message.location }} is {{ message.content }} degrees 
Celsius.\n" + "{% endif %}\n" + "{% endfor %}\n" + "[AVAILABLE_TOOLS]\n" + "[\n" + " {% for tool in tools %}\n" + " {\n" + " \"type\": \"function\",\n" + " \"function\": {\n" + " \"name\": \"{{ tool.function.name }}\",\n" + " \"description\": \"{{ tool.function.description }}\",\n" + " \"parameters\": {\n" + " \"type\": \"object\",\n" + " \"properties\": {\n" + " {% for param_key, param_spec in tool.function.parameters.items() %}\n" + " \"{{ param_key }}\": {\n" + " \"type\": \"{{ param_spec.type }}\",\n" + " \"description\": \"{{ param_spec.description }}\",\n" + " {% if param_spec.enum %}\"enum\": [{{ param_spec.enum | join(', ') }}],{% endif %}\n" + " }{% if not loop.last %},{% endif %}\n" + " {% endfor %}\n" + " },\n" + " \"required\": [{{ tool.function.required | join(', ') }}]\n" + " }\n" + " }\n" + " }{% if not loop.last %},{% endif %}\n" + " {% endfor %}\n" + "]\n" + "[/AVAILABLE_TOOLS]\n" + ) + template_renderer = jinja2.Environment( + loader=jinja2.BaseLoader(), + autoescape=jinja2.select_autoescape(["html", "xml"]), + undefined=jinja2.DebugUndefined, + ).from_string(function_calling_template) + # Convert legacy functions to tools + if functions is not None: + tools = [ + { + "type": "function", + "function": function, + } + for function in functions + ] + + # Convert legacy function_call to tool_choice + if function_call is not None: + if isinstance(function_call, str) and ( + function_call == "none" or function_call == "auto" + ): + tool_choice = function_call + if isinstance(function_call, dict) and "name" in function_call: + tool_choice = { + "type": "function", + "function": { + "name": function_call["name"], + }, + } + + stop = [stop, end_token] if isinstance(stop, str) else stop + [end_token] if stop else [end_token] + # Case 1: No tool choice by user + if ( + tool_choice is None + or (isinstance(tool_choice, str) and tool_choice == "none") + or tools is None + or len(tools) == 0 + ): + prompt = template_renderer.render( + messages=messages, + tools=[], + tool_calls=None, + add_generation_prompt=True, + ) + if response_format is not None and response_format["type"] == "json_object": + grammar = _grammar_for_response_format(response_format) + + return _convert_completion_to_chat( + llama.create_completion( + prompt=prompt, + temperature=temperature, + top_p=top_p, + top_k=top_k, + min_p=min_p, + typical_p=typical_p, + stream=stream, + stop=stop, + max_tokens=max_tokens, + presence_penalty=presence_penalty, + frequency_penalty=frequency_penalty, + repeat_penalty=repeat_penalty, + tfs_z=tfs_z, + mirostat_mode=mirostat_mode, + mirostat_tau=mirostat_tau, + mirostat_eta=mirostat_eta, + model=model, + logits_processor=logits_processor, + grammar=grammar, + logprobs=top_logprobs if logprobs else None, + ), + stream=stream, + ) + + # Case 2: Tool choice by user + if isinstance(tool_choice, dict): + tool_name = tool_choice["function"]["name"] + tool = next( + (tool for tool in tools if tool["function"]["name"] == tool_name), None + ) + if tool is None: + raise ValueError(f"Tool with name '{tool_name}' not found in tools") + prompt = template_renderer.render( + messages=messages, + tools=tools, + tool_calls=True, + add_generation_prompt=True, + ) + prompt += f"functions.{tool_name}:\n" + try: + grammar = llama_grammar.LlamaGrammar.from_json_schema( + json.dumps(tool["function"]["parameters"]), verbose=llama.verbose + ) + except Exception as e: + grammar = llama_grammar.LlamaGrammar.from_string( + llama_grammar.JSON_GBNF, verbose=llama.verbose + ) + if llama.verbose: + print( + 
"Failed to parse function body as JSON schema, falling back to default grammar" + ) + print(e) + completion_or_chunks = llama.create_completion( + prompt=prompt, + temperature=temperature, + top_p=top_p, + top_k=top_k, + min_p=min_p, + typical_p=typical_p, + stream=stream, + stop=stop, + max_tokens=max_tokens, + presence_penalty=presence_penalty, + frequency_penalty=frequency_penalty, + repeat_penalty=repeat_penalty, + tfs_z=tfs_z, + mirostat_mode=mirostat_mode, + mirostat_tau=mirostat_tau, + mirostat_eta=mirostat_eta, + model=model, + logits_processor=logits_processor, + grammar=grammar, + ) + return _convert_completion_to_chat_function( + tool_name, completion_or_chunks, stream + ) + + # Case 3: Automatic tool choice + assert isinstance(tool_choice, str) and tool_choice == "auto" + function_names = " | ".join( + [f'''[{{"name":"{tool['function']['name']}"''' for tool in tools] + ) + + initial_gbnf_tool_grammar = ( + """root ::= functions | [INST]\n""" + f"""functions ::= [TOOL_CALLS] {function_names}\n""" + ) + + follow_up_gbnf_tool_grammar = ( + f"""root ::= functions | ""\n""" + f"""functions ::= {function_names}\n""" + ) + + + prompt = template_renderer.render( + messages=messages, + tools=tools, + tool_calls=True, + add_generation_prompt=True, + ) + print(prompt) + completion_or_chunks = llama.create_completion( + prompt=prompt, + temperature=temperature, + top_p=top_p, + top_k=top_k, + min_p=min_p, + typical_p=typical_p, + stream=False, + stop=stop, + max_tokens=max_tokens, + presence_penalty=presence_penalty, + frequency_penalty=frequency_penalty, + repeat_penalty=repeat_penalty, + tfs_z=tfs_z, + mirostat_mode=mirostat_mode, + mirostat_tau=mirostat_tau, + mirostat_eta=mirostat_eta, + model=model, + logits_processor=logits_processor, + grammar=llama_grammar.LlamaGrammar.from_string( + initial_gbnf_tool_grammar, verbose=llama.verbose + ), + ) + completion: llama_types.CreateCompletionResponse = completion_or_chunks # type: ignore + text = completion["choices"][0]["text"] + print(text) + if "[INST]" in text: + return _convert_completion_to_chat( + llama.create_completion( + prompt=prompt + "message:\n", + temperature=temperature, + top_p=top_p, + top_k=top_k, + min_p=min_p, + typical_p=typical_p, + stream=stream, + stop=[""], + logprobs=top_logprobs if logprobs else None, + max_tokens=None, + presence_penalty=presence_penalty, + frequency_penalty=frequency_penalty, + repeat_penalty=repeat_penalty, + tfs_z=tfs_z, + mirostat_mode=mirostat_mode, + mirostat_tau=mirostat_tau, + mirostat_eta=mirostat_eta, + model=model, + logits_processor=logits_processor, + ), + stream=stream, + ) + + # One or more function calls + tool_name = text[len("functions.") :].replace(":", "") + tool = next((tool for tool in tools if tool["function"]["name"] == tool_name), None) + if not stream: + completions: List[llama_types.CreateCompletionResponse] = [] + completions_tool_name: List[str] = [] + while tool is not None: + prompt += f"functions.{tool_name}:\n" + try: + grammar = llama_grammar.LlamaGrammar.from_json_schema( + json.dumps(tool["function"]["parameters"]), verbose=llama.verbose + ) + except Exception as e: + grammar = llama_grammar.LlamaGrammar.from_string( + llama_grammar.JSON_GBNF, verbose=llama.verbose + ) + if llama.verbose: + print( + "Failed to parse function body as JSON schema, falling back to default grammar" + ) + print(e) + completion_or_chunks = llama.create_completion( + prompt=prompt, + temperature=temperature, + top_p=top_p, + top_k=top_k, + min_p=min_p, + typical_p=typical_p, + 
stream=False, + stop=stop, + max_tokens=None, + presence_penalty=presence_penalty, + frequency_penalty=frequency_penalty, + repeat_penalty=repeat_penalty, + tfs_z=tfs_z, + mirostat_mode=mirostat_mode, + mirostat_tau=mirostat_tau, + mirostat_eta=mirostat_eta, + model=model, + logits_processor=logits_processor, + grammar=grammar, + ) + completion_or_chunks = cast(llama_types.CreateCompletionResponse, completion_or_chunks) + completions.append(completion_or_chunks) + completions_tool_name.append(tool_name) + prompt += completion_or_chunks["choices"][0]["text"] + prompt += "\n" + response = llama.create_completion( + prompt=prompt, + temperature=temperature, + top_p=top_p, + top_k=top_k, + min_p=min_p, + typical_p=typical_p, + stream=False, + stop=stop, + max_tokens=None, + presence_penalty=presence_penalty, + frequency_penalty=frequency_penalty, + repeat_penalty=repeat_penalty, + tfs_z=tfs_z, + mirostat_mode=mirostat_mode, + mirostat_tau=mirostat_tau, + mirostat_eta=mirostat_eta, + model=model, + logits_processor=logits_processor, + grammar=llama_grammar.LlamaGrammar.from_string( + follow_up_gbnf_tool_grammar, verbose=llama.verbose + ), + + ) + + response = cast(llama_types.CreateCompletionResponse, response) + + if response["choices"][0]["text"] == "": + break + tool_name = response["choices"][0]["text"][len("functions.") :].replace(":", "") + tool = next( + (tool for tool in tools if tool["function"]["name"] == tool_name), None + ) + + # Merge completions + function_call_dict: Union[Dict[str, str], Dict[Literal["function_call"], llama_types.ChatCompletionRequestAssistantMessageFunctionCall]] = { + "function_call": { + "name": tool_name, + "arguments": completions[0]["choices"][0]["text"], + } + } if len(completions) == 1 else {} + return { + "id": "chat" + completion["id"], + "object": "chat.completion", + "created": completion["created"], + "model": completion["model"], + "choices": [ + { + "finish_reason": "tool_calls", + "index": 0, + "logprobs": completion["choices"][0]["logprobs"], + "message": { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "call_" + + f"_{i}_" + + tool_name + + "_" + + completion["id"], + "type": "function", + "function": { + "name": tool_name, + "arguments": completion["choices"][0]["text"], + }, + } + for i, (tool_name, completion) in enumerate( + zip(completions_tool_name, completions) + ) + ], + **function_call_dict + }, + } + ], + "usage": { + "completion_tokens": sum( + completion["usage"]["completion_tokens"] if "usage" in completion else 0 + for completion in completions + ), + "prompt_tokens": sum( + completion["usage"]["prompt_tokens"] if "usage" in completion else 0 + for completion in completions + ), + "total_tokens": sum( + completion["usage"]["total_tokens"] if "usage" in completion else 0 + for completion in completions + ), + }, + } + + raise ValueError("Automatic streaming tool choice is not supported") + \ No newline at end of file From ca55725ef3ff86c4de08a8ff2825617c6dd92b0d Mon Sep 17 00:00:00 2001 From: lucca Date: Wed, 17 Apr 2024 18:14:04 -0300 Subject: [PATCH 03/11] debug --- llama_cpp/llama_chat_format.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py index 64e1f92c4..51184e65a 100644 --- a/llama_cpp/llama_chat_format.py +++ b/llama_cpp/llama_chat_format.py @@ -2910,6 +2910,7 @@ def mixtral_function_calling( " }{% if not loop.last %},{% endif %}\n" " {% endfor %}\n" "]\n\n" + "When you are done with the function calls, end the 
message with [/TOOL_CALLS]\n" "{% elif message.role == 'tool' %}\n" "[TOOL_RESULTS] \n" "{\n" @@ -3082,9 +3083,10 @@ def mixtral_function_calling( """root ::= functions | [INST]\n""" f"""functions ::= [TOOL_CALLS] {function_names}\n""" ) + print(initial_gbnf_tool_grammar) follow_up_gbnf_tool_grammar = ( - f"""root ::= functions | ""\n""" + f"""root ::= functions | "[/TOOL_CALLS]"\n""" f"""functions ::= {function_names}\n""" ) @@ -3222,7 +3224,7 @@ def mixtral_function_calling( response = cast(llama_types.CreateCompletionResponse, response) - if response["choices"][0]["text"] == "": + if response["choices"][0]["text"] == "[/TOOL_CALLS]": break tool_name = response["choices"][0]["text"][len("functions.") :].replace(":", "") tool = next( From 9969526698dabf4a16fd5152fc5ab712077cc358 Mon Sep 17 00:00:00 2001 From: lucca Date: Wed, 17 Apr 2024 19:00:11 -0300 Subject: [PATCH 04/11] debug --- llama_cpp/llama_chat_format.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py index 51184e65a..aac6e03c6 100644 --- a/llama_cpp/llama_chat_format.py +++ b/llama_cpp/llama_chat_format.py @@ -3076,7 +3076,7 @@ def mixtral_function_calling( # Case 3: Automatic tool choice assert isinstance(tool_choice, str) and tool_choice == "auto" function_names = " | ".join( - [f'''[{{"name":"{tool['function']['name']}"''' for tool in tools] + [f'''"name:{tool['function']['name']}:"''' for tool in tools] ) initial_gbnf_tool_grammar = ( From 6197f6223c2acdb4c920af257493be8667b5a563 Mon Sep 17 00:00:00 2001 From: lucca Date: Wed, 17 Apr 2024 19:11:51 -0300 Subject: [PATCH 05/11] debug --- llama_cpp/llama_chat_format.py | 51 +++++++++++++++++----------------- 1 file changed, 25 insertions(+), 26 deletions(-) diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py index aac6e03c6..a501ed03b 100644 --- a/llama_cpp/llama_chat_format.py +++ b/llama_cpp/llama_chat_format.py @@ -2921,32 +2921,31 @@ def mixtral_function_calling( "The current temperature in {{ message.location }} is {{ message.content }} degrees Celsius.\n" "{% endif %}\n" "{% endfor %}\n" - "[AVAILABLE_TOOLS]\n" - "[\n" - " {% for tool in tools %}\n" - " {\n" - " \"type\": \"function\",\n" - " \"function\": {\n" - " \"name\": \"{{ tool.function.name }}\",\n" - " \"description\": \"{{ tool.function.description }}\",\n" - " \"parameters\": {\n" - " \"type\": \"object\",\n" - " \"properties\": {\n" - " {% for param_key, param_spec in tool.function.parameters.items() %}\n" - " \"{{ param_key }}\": {\n" - " \"type\": \"{{ param_spec.type }}\",\n" - " \"description\": \"{{ param_spec.description }}\",\n" - " {% if param_spec.enum %}\"enum\": [{{ param_spec.enum | join(', ') }}],{% endif %}\n" - " }{% if not loop.last %},{% endif %}\n" - " {% endfor %}\n" - " },\n" - " \"required\": [{{ tool.function.required | join(', ') }}]\n" - " }\n" - " }\n" - " }{% if not loop.last %},{% endif %}\n" - " {% endfor %}\n" - "]\n" - "[/AVAILABLE_TOOLS]\n" + "[AVAILABLE_TOOLS]" + "[" + " {% for tool in tools %}" + " {" + " \"type\": \"function\"," + " \"function\": {" + " \"name\": \"{{ tool.function.name }}\"," + " \"description\": \"{{ tool.function.description }}\"," + " \"parameters\": {" + " \"type\": \"object\"," + " \"properties\": {" + " {% for param_key, param_spec in tool.function.parameters.properties.items() %}" + " \"{{ param_key }}\": {" + " \"type\": \"{{ param_spec.type }}\"," + " \"description\": \"{{ param_spec.description }}\"" + " }{% if not loop.last %},{% endif %}" + 
" {% endfor %}" + " }," + " \"required\": [{{ tool.function.parameters.required | join(', ') }}]" + " }" + " }" + " }{% if not loop.last %},{% endif %}" + " {% endfor %}" + "]" + "[/AVAILABLE_TOOLS]" ) template_renderer = jinja2.Environment( loader=jinja2.BaseLoader(), From 6e0369254295f5e32672bc23dd47c58a75b4f82d Mon Sep 17 00:00:00 2001 From: lucca Date: Wed, 17 Apr 2024 19:16:19 -0300 Subject: [PATCH 06/11] debug --- llama_cpp/llama_chat_format.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py index a501ed03b..b75805b39 100644 --- a/llama_cpp/llama_chat_format.py +++ b/llama_cpp/llama_chat_format.py @@ -3075,7 +3075,7 @@ def mixtral_function_calling( # Case 3: Automatic tool choice assert isinstance(tool_choice, str) and tool_choice == "auto" function_names = " | ".join( - [f'''"name:{tool['function']['name']}:"''' for tool in tools] + [f'''"name:{tool['function']['name']}"''' for tool in tools] ) initial_gbnf_tool_grammar = ( From 7ce06677ac2b4a890404049c1c369b80fb215dfa Mon Sep 17 00:00:00 2001 From: Lucca Zenobio Date: Wed, 17 Apr 2024 19:19:27 -0300 Subject: [PATCH 07/11] debug --- llama_cpp/llama_chat_format.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py index b75805b39..98b657fd3 100644 --- a/llama_cpp/llama_chat_format.py +++ b/llama_cpp/llama_chat_format.py @@ -3116,9 +3116,9 @@ def mixtral_function_calling( mirostat_eta=mirostat_eta, model=model, logits_processor=logits_processor, - grammar=llama_grammar.LlamaGrammar.from_string( - initial_gbnf_tool_grammar, verbose=llama.verbose - ), + # grammar=llama_grammar.LlamaGrammar.from_string( + # initial_gbnf_tool_grammar, verbose=llama.verbose + # ), ) completion: llama_types.CreateCompletionResponse = completion_or_chunks # type: ignore text = completion["choices"][0]["text"] From 15dd7aa9dab3a1dd19069b1a625263aa8bae19bc Mon Sep 17 00:00:00 2001 From: lucca Date: Wed, 17 Apr 2024 19:28:04 -0300 Subject: [PATCH 08/11] debug --- llama_cpp/llama_chat_format.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py index 98b657fd3..19e8f2ebf 100644 --- a/llama_cpp/llama_chat_format.py +++ b/llama_cpp/llama_chat_format.py @@ -3074,9 +3074,10 @@ def mixtral_function_calling( # Case 3: Automatic tool choice assert isinstance(tool_choice, str) and tool_choice == "auto" + function_names = " | ".join( - [f'''"name:{tool['function']['name']}"''' for tool in tools] - ) + [f'''"{{ "name":{tool['function']['name']}"''' for tool in tools] +) initial_gbnf_tool_grammar = ( """root ::= functions | [INST]\n""" From 014575f73901ef379b3136c8b64a1189086fe6ae Mon Sep 17 00:00:00 2001 From: Lucca Zenobio Date: Wed, 17 Apr 2024 19:28:58 -0300 Subject: [PATCH 09/11] debug --- llama_cpp/llama_chat_format.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py index 19e8f2ebf..14eafb7d8 100644 --- a/llama_cpp/llama_chat_format.py +++ b/llama_cpp/llama_chat_format.py @@ -3117,9 +3117,9 @@ def mixtral_function_calling( mirostat_eta=mirostat_eta, model=model, logits_processor=logits_processor, - # grammar=llama_grammar.LlamaGrammar.from_string( - # initial_gbnf_tool_grammar, verbose=llama.verbose - # ), + grammar=llama_grammar.LlamaGrammar.from_string( + initial_gbnf_tool_grammar, verbose=llama.verbose + ), ) completion: 
llama_types.CreateCompletionResponse = completion_or_chunks # type: ignore text = completion["choices"][0]["text"] From 454b5e31e93234d471adef6942264ac3b54c7086 Mon Sep 17 00:00:00 2001 From: Lucca Zenobio Date: Wed, 17 Apr 2024 19:47:46 -0300 Subject: [PATCH 10/11] debug --- llama_cpp/llama_chat_format.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py index 14eafb7d8..19e8f2ebf 100644 --- a/llama_cpp/llama_chat_format.py +++ b/llama_cpp/llama_chat_format.py @@ -3117,9 +3117,9 @@ def mixtral_function_calling( mirostat_eta=mirostat_eta, model=model, logits_processor=logits_processor, - grammar=llama_grammar.LlamaGrammar.from_string( - initial_gbnf_tool_grammar, verbose=llama.verbose - ), + # grammar=llama_grammar.LlamaGrammar.from_string( + # initial_gbnf_tool_grammar, verbose=llama.verbose + # ), ) completion: llama_types.CreateCompletionResponse = completion_or_chunks # type: ignore text = completion["choices"][0]["text"] From 08cf4f7345fdce6fa117f891c85636ab9a5790ac Mon Sep 17 00:00:00 2001 From: lucca Date: Wed, 17 Apr 2024 19:55:22 -0300 Subject: [PATCH 11/11] up --- llama_cpp/llama_chat_format.py | 51 +++++++++++++++++----------------- 1 file changed, 26 insertions(+), 25 deletions(-) diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py index 19e8f2ebf..7d9211180 100644 --- a/llama_cpp/llama_chat_format.py +++ b/llama_cpp/llama_chat_format.py @@ -2890,6 +2890,31 @@ def mixtral_function_calling( ]: end_token = "" function_calling_template = ( + "[AVAILABLE_TOOLS]" + "[" + " {% for tool in tools %}" + " {" + " \"type\": \"function\"," + " \"function\": {" + " \"name\": \"{{ tool.function.name }}\"," + " \"description\": \"{{ tool.function.description }}\"," + " \"parameters\": {" + " \"type\": \"object\"," + " \"properties\": {" + " {% for param_key, param_spec in tool.function.parameters.properties.items() %}" + " \"{{ param_key }}\": {" + " \"type\": \"{{ param_spec.type }}\"," + " \"description\": \"{{ param_spec.description }}\"" + " }{% if not loop.last %},{% endif %}" + " {% endfor %}" + " }," + " \"required\": [{{ tool.function.parameters.required | join(', ') }}]" + " }" + " }" + " }{% if not loop.last %},{% endif %}" + " {% endfor %}" + "]" + "[/AVAILABLE_TOOLS]" "{% for message in messages %}\n" "{% if message.role == 'user' %}\n" "[INST] \n" @@ -2921,31 +2946,7 @@ def mixtral_function_calling( "The current temperature in {{ message.location }} is {{ message.content }} degrees Celsius.\n" "{% endif %}\n" "{% endfor %}\n" - "[AVAILABLE_TOOLS]" - "[" - " {% for tool in tools %}" - " {" - " \"type\": \"function\"," - " \"function\": {" - " \"name\": \"{{ tool.function.name }}\"," - " \"description\": \"{{ tool.function.description }}\"," - " \"parameters\": {" - " \"type\": \"object\"," - " \"properties\": {" - " {% for param_key, param_spec in tool.function.parameters.properties.items() %}" - " \"{{ param_key }}\": {" - " \"type\": \"{{ param_spec.type }}\"," - " \"description\": \"{{ param_spec.description }}\"" - " }{% if not loop.last %},{% endif %}" - " {% endfor %}" - " }," - " \"required\": [{{ tool.function.parameters.required | join(', ') }}]" - " }" - " }" - " }{% if not loop.last %},{% endif %}" - " {% endfor %}" - "]" - "[/AVAILABLE_TOOLS]" + ) template_renderer = jinja2.Environment( loader=jinja2.BaseLoader(),
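Taken together, the series is exercised end to end like this (a sketch: the model path and the get_weather tool are hypothetical, the chat_format names are the ones registered above, and streaming with automatic tool choice still raises ValueError):

    from llama_cpp import Llama

    llm = Llama(
        model_path="./mixtral-8x7b-instruct.Q4_K_M.gguf",  # hypothetical path
        chat_format="mixtral-function-calling",  # or "chatml-..." / "vicuna-..."
        n_ctx=4096,
    )

    response = llm.create_chat_completion(
        messages=[{"role": "user", "content": "What's the weather in Paris?"}],
        tools=[{
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Get the current weather for a city",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "city": {"type": "string", "description": "City name"},
                    },
                    "required": ["city"],
                },
            },
        }],
        tool_choice="auto",
    )
    print(response["choices"][0]["message"].get("tool_calls"))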