diff --git a/defog/llm/utils_function_calling.py b/defog/llm/utils_function_calling.py index e9c5dda..2b027de 100644 --- a/defog/llm/utils_function_calling.py +++ b/defog/llm/utils_function_calling.py @@ -115,32 +115,43 @@ def convert_tool_choice(tool_choice: str, tool_name_list: List[str], model: str) }, "custom": {"type": "tool", "name": tool_choice}, }, - "gemini": { - "prefixes": ["gemini"], - "choices": { - "auto": types.ToolConfig( - function_calling_config=types.FunctionCallingConfig(mode="AUTO") - ), - "required": types.ToolConfig( - function_calling_config=types.FunctionCallingConfig(mode="ANY") - ), - "any": types.ToolConfig( - function_calling_config=types.FunctionCallingConfig(mode="ANY") - ), - "none": types.ToolConfig( - function_calling_config=types.FunctionCallingConfig(mode="NONE") - ), - }, - "custom": types.ToolConfig( - function_calling_config=types.FunctionCallingConfig( - mode="ANY", allowed_function_names=[tool_choice] - ) - ), - }, + "gemini": {"prefixes": ["gemini"]}, } for model_type, config in model_map.items(): if any(model.startswith(prefix) for prefix in config["prefixes"]): + if model_type == "gemini": + from google.genai import types + + config = { + "choices": { + "auto": types.ToolConfig( + function_calling_config=types.FunctionCallingConfig( + mode="AUTO" + ) + ), + "required": types.ToolConfig( + function_calling_config=types.FunctionCallingConfig( + mode="ANY" + ) + ), + "any": types.ToolConfig( + function_calling_config=types.FunctionCallingConfig( + mode="ANY" + ) + ), + "none": types.ToolConfig( + function_calling_config=types.FunctionCallingConfig( + mode="NONE" + ) + ), + }, + "custom": types.ToolConfig( + function_calling_config=types.FunctionCallingConfig( + mode="ANY", allowed_function_names=[tool_choice] + ) + ), + } if tool_choice not in config["choices"]: # Validate custom tool_choice if tool_choice not in tool_name_list: