From 8ce2c538efe9ce1491447f8bae31f78aa756b5ae Mon Sep 17 00:00:00 2001
From: Giuseppe Zileni
Date: Mon, 10 Mar 2025 23:05:10 +0100
Subject: [PATCH] fixed kwargs ask method

---
 src/vanna/base/base.py     |  5 +--
 src/vanna/local.py         | 10 +++++-
 src/vanna/ollama/ollama.py | 12 +++++++
 src/vanna/remote.py        | 64 ++++++++++++++++++++++++++++++++++++++
 src/vanna/utils.py         | 26 ++++++++++++++++--
 5 files changed, 112 insertions(+), 5 deletions(-)

diff --git a/src/vanna/base/base.py b/src/vanna/base/base.py
index 4c05de58..b77bdf69 100644
--- a/src/vanna/base/base.py
+++ b/src/vanna/base/base.py
@@ -54,7 +54,7 @@
 import sqlite3
 import traceback
 from abc import ABC, abstractmethod
-from typing import List, Tuple, Union
+from typing import List, Tuple, Union, Any
 from urllib.parse import urlparse
 
 import pandas as pd
@@ -1670,6 +1670,7 @@ def ask(
         auto_train: bool = True,
         visualize: bool = True,  # if False, will not generate plotly code
         allow_llm_to_see_data: bool = False,
+        **kwargs: Any,
     ) -> Union[
         Tuple[
             Union[str, None],
@@ -1700,7 +1701,7 @@
             question = input("Enter a question: ")
 
         try:
-            sql = self.generate_sql(question=question, allow_llm_to_see_data=allow_llm_to_see_data)
+            sql = self.generate_sql(question=question, allow_llm_to_see_data=allow_llm_to_see_data, **kwargs)
         except Exception as e:
             print(e)
             return None, None, None
diff --git a/src/vanna/local.py b/src/vanna/local.py
index 498fb31c..ead4a1eb 100644
--- a/src/vanna/local.py
+++ b/src/vanna/local.py
@@ -1,8 +1,16 @@
 from .chromadb.chromadb_vector import ChromaDB_VectorStore
 from .openai.openai_chat import OpenAI_Chat
 
-
 class LocalContext_OpenAI(ChromaDB_VectorStore, OpenAI_Chat):
+    """
+    LocalContext_OpenAI class that combines functionalities from ChromaDB_VectorStore and OpenAI_Chat.
+
+    Attributes:
+        config (dict, optional): Configuration dictionary for initializing the parent classes.
+
+    Methods:
+        __init__(config=None): Initializes the LocalContext_OpenAI instance with the given configuration.
+    """
     def __init__(self, config=None):
         ChromaDB_VectorStore.__init__(self, config=config)
         OpenAI_Chat.__init__(self, config=config)
diff --git a/src/vanna/ollama/ollama.py b/src/vanna/ollama/ollama.py
index bcd6a2e8..97790f5c 100644
--- a/src/vanna/ollama/ollama.py
+++ b/src/vanna/ollama/ollama.py
@@ -86,6 +86,18 @@ def extract_sql(self, llm_response):
         return llm_response
 
     def submit_prompt(self, prompt, **kwargs) -> str:
+        """
+        Submits a prompt to the Ollama client and returns the response content.
+        Args:
+            prompt (dict): The prompt message to be sent to the Ollama client.
+            **kwargs: Additional keyword arguments.
+        Returns:
+            str: The content of the response message from the Ollama client.
+        Logs:
+            - Ollama parameters including model, options, and keep_alive status.
+            - The content of the prompt being submitted.
+            - The response received from the Ollama client.
+        """
         self.log(
             f"Ollama parameters:\n"
             f"model={self.model},\n"
diff --git a/src/vanna/remote.py b/src/vanna/remote.py
index c3b0220d..4919979a 100644
--- a/src/vanna/remote.py
+++ b/src/vanna/remote.py
@@ -38,7 +38,36 @@
 
 
 class VannaDefault(VannaDB_VectorStore):
+    """
+    VannaDefault is a class that extends VannaDB_VectorStore and provides methods to interact with the Vanna AI system.
+    Attributes:
+        model (str): The model identifier for the Vanna AI.
+        api_key (str): The API key for authenticating with the Vanna AI.
+        config (dict, optional): Configuration dictionary for additional settings.
+    Methods:
+        __init__(model: str, api_key: str, config=None):
+            Initializes the VannaDefault instance with the specified model, API key, and optional configuration.
+        system_message(message: str) -> any:
+            Creates a system message dictionary with the specified message content.
+        user_message(message: str) -> any:
+            Creates a user message dictionary with the specified message content.
+        assistant_message(message: str) -> any:
+            Creates an assistant message dictionary with the specified message content.
+        submit_prompt(prompt, **kwargs) -> str:
+            Submits a prompt to the Vanna AI system and returns the response as a string.
+    """
     def __init__(self, model: str, api_key: str, config=None):
+        """
+        Initialize the remote Vanna model.
+        Args:
+            model (str): The model identifier.
+            api_key (str): The API key for authentication.
+            config (dict, optional): Configuration dictionary. Defaults to None.
+        Attributes:
+            _model (str): The model identifier.
+            _api_key (str): The API key for authentication.
+            _endpoint (str): The endpoint URL for the Vanna API.
+        """
         VannaBase.__init__(self, config=config)
         VannaDB_VectorStore.__init__(self, vanna_model=model, vanna_api_key=api_key, config=config)
 
@@ -52,15 +81,50 @@ def __init__(self, model: str, api_key: str, config=None):
         )
 
     def system_message(self, message: str) -> any:
+        """
+        Creates a system message dictionary.
+
+        Args:
+            message (str): The content of the system message.
+
+        Returns:
+            dict: A dictionary with the role set to "system" and the content set to the provided message.
+        """
         return {"role": "system", "content": message}
 
     def user_message(self, message: str) -> any:
+        """
+        Constructs a dictionary representing a user message.
+
+        Args:
+            message (str): The content of the user's message.
+
+        Returns:
+            dict: A dictionary with keys 'role' and 'content', where 'role' is set to 'user' and 'content' is set to the provided message.
+        """
         return {"role": "user", "content": message}
 
    def assistant_message(self, message: str) -> any:
+        """
+        Constructs a dictionary representing an assistant message.
+
+        Args:
+            message (str): The message content from the assistant.
+
+        Returns:
+            dict: A dictionary with keys 'role' and 'content', where 'role' is set to 'assistant' and 'content' is the provided message.
+        """
         return {"role": "assistant", "content": message}
 
     def submit_prompt(self, prompt, **kwargs) -> str:
+        """
+        Submits a prompt to a remote service and returns the result.
+        Args:
+            prompt (str): The prompt to be submitted.
+            **kwargs: Additional keyword arguments.
+        Returns:
+            str: The result from the remote service, or None if the result is not present.
+        """
         # JSON-ify the prompt
         json_prompt = json.dumps(prompt, ensure_ascii=False)
 
diff --git a/src/vanna/utils.py b/src/vanna/utils.py
index d3d2d54e..5f9d7b8a 100644
--- a/src/vanna/utils.py
+++ b/src/vanna/utils.py
@@ -3,11 +3,18 @@
 import re
 import uuid
 from typing import Union
-
 from .exceptions import ImproperlyConfigured, ValidationError
 
-
 def validate_config_path(path):
+    """
+    Validates the given configuration file path.
+    This function checks if the provided path exists, is a file, and is readable.
+    If any of these conditions are not met, it raises an ImproperlyConfigured exception.
+    Args:
+        path (str): The path to the configuration file.
+    Raises:
+        ImproperlyConfigured: If the path does not exist, is not a file, or is not readable.
+    """
     if not os.path.exists(path):
         raise ImproperlyConfigured(
             f'No such configuration file: {path}'
@@ -25,6 +32,21 @@ def validate_config_path(path):
 
 
 def sanitize_model_name(model_name):
+    """
+    Sanitizes the given model name by performing the following operations:
+    - Converts the model name to lowercase.
+    - Replaces spaces with hyphens.
+    - Replaces multiple consecutive hyphens with a single hyphen.
+    - Replaces underscores with hyphens if both underscores and hyphens are present.
+    - Removes special characters, allowing only alphanumeric characters, hyphens, and underscores.
+    - Removes hyphens or underscores from the beginning or end of the model name.
+    Args:
+        model_name (str): The model name to sanitize.
+    Returns:
+        str: The sanitized model name.
+    Raises:
+        ValidationError: If an error occurs during sanitization.
+    """
     try:
         model_name = model_name.lower()
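Note (not part of the patch): the headline change above threads **kwargs from ask() into generate_sql(). A minimal usage sketch follows; the model name, API key, and the extra num_documents option are placeholders for illustration and assume the downstream generate_sql()/vector-store call accepts them.

from vanna.remote import VannaDefault

# Placeholder credentials and model name.
vn = VannaDefault(model="my-model", api_key="my-api-key")

# Before this patch, ask() had no **kwargs, so any extra keyword argument raised a
# TypeError. After it, anything beyond ask()'s named parameters is forwarded
# unchanged to generate_sql(question=..., allow_llm_to_see_data=..., **kwargs).
result = vn.ask(
    "How many orders were placed last month?",
    visualize=False,
    allow_llm_to_see_data=False,
    num_documents=5,  # hypothetical pass-through option, shown only to illustrate **kwargs
)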
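Note (illustrative only): the sanitize_model_name() docstring added above lists the cleanup rules, but the function body is truncated in this excerpt. The standalone sketch below re-implements those rules as described, purely to make the documented behavior concrete; it is not the library's exact code.

import re

def sanitize_model_name_sketch(model_name: str) -> str:
    # Lowercase and replace spaces with hyphens.
    name = model_name.lower().replace(" ", "-")
    # Collapse runs of hyphens into a single hyphen.
    name = re.sub(r"-+", "-", name)
    # If both underscores and hyphens are present, prefer hyphens.
    if "_" in name and "-" in name:
        name = name.replace("_", "-")
    # Drop anything that is not alphanumeric, hyphen, or underscore.
    name = re.sub(r"[^a-z0-9_-]", "", name)
    # Trim leading/trailing hyphens or underscores.
    return name.strip("-_")

print(sanitize_model_name_sketch("My  Demo_Model!"))  # -> "my-demo-model"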