diff --git a/src/vanna/base/base.py b/src/vanna/base/base.py
index 4c05de58..17983e81 100644
--- a/src/vanna/base/base.py
+++ b/src/vanna/base/base.py
@@ -65,7 +65,7 @@ import sqlparse

 from ..exceptions import DependencyError, ImproperlyConfigured, ValidationError
-from ..types import TrainingPlan, TrainingPlanItem
+from ..types import TrainingPlan, TrainingPlanItem, TableMetadata
 from ..utils import validate_config_path

@@ -210,6 +210,54 @@ def extract_sql(self, llm_response: str) -> str:

         return llm_response

+    def extract_table_metadata(self, ddl: str) -> TableMetadata:
+        """
+        Example:
+        ```python
+        vn.extract_table_metadata("CREATE TABLE hive.bi_ads.customers (id INT, name TEXT, sales DECIMAL)")
+        ```
+
+        Extracts the table metadata from a DDL statement. This is useful in case the DDL statement contains other information besides the table metadata.
+        Override this function if your DDL statements need custom extraction logic.
+
+        Args:
+            ddl (str): The DDL statement.
+
+        Returns:
+            TableMetadata: The extracted table metadata.
+        """
+        pattern_with_catalog_schema = re.compile(
+            r'CREATE TABLE\s+(\w+)\.(\w+)\.(\w+)\s*\(',
+            re.IGNORECASE
+        )
+        pattern_with_schema = re.compile(
+            r'CREATE TABLE\s+(\w+)\.(\w+)\s*\(',
+            re.IGNORECASE
+        )
+        pattern_with_table = re.compile(
+            r'CREATE TABLE\s+(\w+)\s*\(',
+            re.IGNORECASE
+        )
+
+        match_with_catalog_schema = pattern_with_catalog_schema.search(ddl)
+        match_with_schema = pattern_with_schema.search(ddl)
+        match_with_table = pattern_with_table.search(ddl)
+
+        if match_with_catalog_schema:
+            catalog = match_with_catalog_schema.group(1)
+            schema = match_with_catalog_schema.group(2)
+            table_name = match_with_catalog_schema.group(3)
+            return TableMetadata(catalog, schema, table_name)
+        elif match_with_schema:
+            schema = match_with_schema.group(1)
+            table_name = match_with_schema.group(2)
+            return TableMetadata(None, schema, table_name)
+        elif match_with_table:
+            table_name = match_with_table.group(1)
+            return TableMetadata(None, None, table_name)
+        else:
+            return TableMetadata()
+
     def is_sql_valid(self, sql: str) -> bool:
         """
         Example:
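The three patterns above fall through from most- to least-qualified table names. A quick illustration of the intended behavior (a sketch, not part of the diff; it assumes `TableMetadata` stores `catalog`, `schema`, and `table_name` positionally, as the constructor calls above imply, and that `vn` is a Vanna instance as in the docstrings):

```python
# Sketch only -- exercises the three regex branches of extract_table_metadata.
vn.extract_table_metadata("CREATE TABLE hive.bi_ads.customers (id INT)")
# -> TableMetadata(catalog='hive', schema='bi_ads', table_name='customers')
vn.extract_table_metadata("CREATE TABLE bi_ads.customers (id INT)")
# -> TableMetadata(catalog=None, schema='bi_ads', table_name='customers')
vn.extract_table_metadata("CREATE TABLE customers (id INT)")
# -> TableMetadata(catalog=None, schema=None, table_name='customers')
```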
@@ -306,7 +354,7 @@ def generate_followup_questions(

         message_log = [
             self.system_message(
-                f"You are a helpful data assistant. The user asked the question: '{question}'\n\nThe SQL query for this question was: {sql}\n\nThe following is a pandas DataFrame with the results of the query: \n{df.head(25).to_markdown()}\n\n"
+                f"You are a helpful data assistant. The user asked the question: '{question}'\n\nThe SQL query for this question was: {sql}\n\nThe following is a pandas DataFrame with the results of the query: \n{df.to_markdown()}\n\n"
             ),
             self.user_message(
                 f"Generate a list of {n_questions} followup questions that the user might ask about this data. Respond with a list of questions, one per line. Do not answer with any explanations -- just the questions. Remember that there should be an unambiguous SQL query that can be generated from the question. Prefer questions that are answerable outside of the context of this conversation. Prefer questions that are slight modifications of the SQL query that was generated that allow digging deeper into the data. Each question will be turned into a button that the user can click to generate a new SQL query so don't use 'example' type questions. Each question must have a one-to-one correspondence with an instantiated SQL query."
@@ -395,6 +443,31 @@ def get_related_ddl(self, question: str, **kwargs) -> list:
         """
         pass

+    @abstractmethod
+    def search_tables_metadata(self,
+                               engine: str = None,
+                               catalog: str = None,
+                               schema: str = None,
+                               table_name: str = None,
+                               ddl: str = None,
+                               size: int = 10,
+                               **kwargs) -> list:
+        """
+        This method is used to retrieve metadata for similar tables.
+
+        Args:
+            engine (str): The database engine.
+            catalog (str): The catalog.
+            schema (str): The schema.
+            table_name (str): The table name.
+            ddl (str): The DDL statement.
+            size (int): The number of tables to return.
+
+        Returns:
+            list: A list of table metadata.
+        """
+        pass
+
     @abstractmethod
     def get_related_documentation(self, question: str, **kwargs) -> list:
         """
@@ -423,12 +496,13 @@ def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
         pass

     @abstractmethod
-    def add_ddl(self, ddl: str, **kwargs) -> str:
+    def add_ddl(self, ddl: str, engine: str = None, **kwargs) -> str:
         """
         This method is used to add a DDL statement to the training data.

         Args:
             ddl (str): The DDL statement to add.
+            engine (str): The database engine that the DDL statement applies to.

         Returns:
             str: The ID of the training data that was added.
@@ -689,9 +763,6 @@ def generate_question(self, sql: str, **kwargs) -> str:

         return response

     def _extract_python_code(self, markdown_string: str) -> str:
-        # Strip whitespace to avoid indentation errors in LLM-generated code
-        markdown_string = markdown_string.strip()
-
         # Regex pattern to match Python code blocks
         pattern = r"```[\w\s]*python\n([\s\S]*?)```|```([\s\S]*?)```"
@@ -800,7 +871,7 @@ def connect_to_snowflake(
             **kwargs
         )

-        def run_sql_snowflake(sql: str) -> pd.DataFrame:
+        def run_sql_snowflake(sql: str, rows_limit: int = None) -> pd.DataFrame:
             cs = conn.cursor()

             if role is not None:
@@ -809,13 +880,15 @@ def run_sql_snowflake(sql: str) -> pd.DataFrame:
             if warehouse is not None:
                 cs.execute(f"USE WAREHOUSE {warehouse}")
             cs.execute(f"USE DATABASE {database}")
+            cs.execute(sql)

-            cur = cs.execute(sql)
-
-            results = cur.fetchall()
+            if rows_limit is not None:
+                results = cs.fetchmany(size=rows_limit)
+            else:
+                results = cs.fetchall()

             # Create a pandas dataframe from the results
-            df = pd.DataFrame(results, columns=[desc[0] for desc in cur.description])
+            df = pd.DataFrame(results, columns=[desc[0] for desc in cs.description])

             return df
@@ -949,13 +1022,16 @@ def connect_to_db():
                                 user=user, password=password, port=port, **kwargs)

-        def run_sql_postgres(sql: str) -> Union[pd.DataFrame, None]:
+        def run_sql_postgres(sql: str, rows_limit: int = None) -> Union[pd.DataFrame, None]:
             conn = None
             try:
                 conn = connect_to_db()  # Initial connection attempt
                 cs = conn.cursor()
                 cs.execute(sql)
-                results = cs.fetchall()
+                if rows_limit is not None:
+                    results = cs.fetchmany(size=rows_limit)
+                else:
+                    results = cs.fetchall()

                 # Create a pandas dataframe from the results
                 df = pd.DataFrame(results, columns=[desc[0] for desc in cs.description])
@@ -968,7 +1044,10 @@ def run_sql_postgres(sql: str) -> Union[pd.DataFrame, None]:
                 conn = connect_to_db()
                 cs = conn.cursor()
                 cs.execute(sql)
-                results = cs.fetchall()
+                if rows_limit is not None:
+                    results = cs.fetchmany(size=rows_limit)
+                else:
+                    results = cs.fetchall()

                 # Create a pandas dataframe from the results
                 df = pd.DataFrame(results, columns=[desc[0] for desc in cs.description])
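The `fetchmany`/`fetchall` switch introduced above is repeated verbatim in each connector below (MySQL, Oracle, Presto, Hive). A usage sketch, assuming `vn.run_sql` has been bound to one of these closures by its `connect_to_*` method; the connection parameters are illustrative, not prescribed by the diff:

```python
# Illustrative parameters; any connect_to_* connector behaves the same way.
vn.connect_to_postgres(host="localhost", dbname="test", user="user", password="pw", port=5432)

df_preview = vn.run_sql("SELECT * FROM customers", rows_limit=100)  # fetchmany: at most 100 rows
df_full = vn.run_sql("SELECT * FROM customers")                     # fetchall: unchanged default
```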
@@ -1051,13 +1130,16 @@ def connect_to_mysql(
         except pymysql.Error as e:
             raise ValidationError(e)

-        def run_sql_mysql(sql: str) -> Union[pd.DataFrame, None]:
+        def run_sql_mysql(sql: str, rows_limit: int = None) -> Union[pd.DataFrame, None]:
             if conn:
                 try:
                     conn.ping(reconnect=True)
                     cs = conn.cursor()
                     cs.execute(sql)
-                    results = cs.fetchall()
+                    if rows_limit is not None:
+                        results = cs.fetchmany(size=rows_limit)
+                    else:
+                        results = cs.fetchall()

                     # Create a pandas dataframe from the results
                     df = pd.DataFrame(
@@ -1218,7 +1300,7 @@ def connect_to_oracle(
         except oracledb.Error as e:
             raise ValidationError(e)

-        def run_sql_oracle(sql: str) -> Union[pd.DataFrame, None]:
+        def run_sql_oracle(sql: str, rows_limit: int = None) -> Union[pd.DataFrame, None]:
             if conn:
                 try:
                     sql = sql.rstrip()
@@ -1227,7 +1309,10 @@ def run_sql_oracle(sql: str) -> Union[pd.DataFrame, None]:

                     cs = conn.cursor()
                     cs.execute(sql)
-                    results = cs.fetchall()
+                    if rows_limit is not None:
+                        results = cs.fetchmany(size=rows_limit)
+                    else:
+                        results = cs.fetchall()

                     # Create a pandas dataframe from the results
                     df = pd.DataFrame(
@@ -1519,7 +1604,7 @@ def connect_to_presto(
         except presto.Error as e:
             raise ValidationError(e)

-        def run_sql_presto(sql: str) -> Union[pd.DataFrame, None]:
+        def run_sql_presto(sql: str, rows_limit: int = None) -> Union[pd.DataFrame, None]:
             if conn:
                 try:
                     sql = sql.rstrip()
@@ -1528,7 +1613,10 @@ def run_sql_presto(sql: str) -> Union[pd.DataFrame, None]:
                         sql = sql[:-1]
                     cs = conn.cursor()
                     cs.execute(sql)
-                    results = cs.fetchall()
+                    if rows_limit is not None:
+                        results = cs.fetchmany(size=rows_limit)
+                    else:
+                        results = cs.fetchall()

                     # Create a pandas dataframe from the results
                     df = pd.DataFrame(
@@ -1620,12 +1708,15 @@ def connect_to_hive(
         except hive.Error as e:
             raise ValidationError(e)

-        def run_sql_hive(sql: str) -> Union[pd.DataFrame, None]:
+        def run_sql_hive(sql: str, rows_limit: int = None) -> Union[pd.DataFrame, None]:
             if conn:
                 try:
                     cs = conn.cursor()
                     cs.execute(sql)
-                    results = cs.fetchall()
+                    if rows_limit is not None:
+                        results = cs.fetchmany(size=rows_limit)
+                    else:
+                        results = cs.fetchall()

                     # Create a pandas dataframe from the results
                     df = pd.DataFrame(
@@ -1781,6 +1872,7 @@ def train(
         question: str = None,
         sql: str = None,
         ddl: str = None,
+        engine: str = None,
         documentation: str = None,
         plan: TrainingPlan = None,
     ) -> str:
         """
@@ -1801,8 +1893,11 @@
             question (str): The question to train on.
             sql (str): The SQL query to train on.
             ddl (str):  The DDL statement.
+            engine (str): The database engine.
             documentation (str): The documentation to train on.
             plan (TrainingPlan): The training plan to train on.
+        Returns:
+            str: The ID of the training data that was added.
         """

         if question and not sql:
@@ -1820,12 +1915,12 @@

         if ddl:
             print("Adding ddl:", ddl)
-            return self.add_ddl(ddl)
+            return self.add_ddl(ddl=ddl, engine=engine)

         if plan:
             for item in plan._plan:
                 if item.item_type == TrainingPlanItem.ITEM_TYPE_DDL:
-                    self.add_ddl(item.item_value)
+                    self.add_ddl(ddl=item.item_value, engine=engine)
                 elif item.item_type == TrainingPlanItem.ITEM_TYPE_IS:
                     self.add_documentation(item.item_value)
                 elif item.item_type == TrainingPlanItem.ITEM_TYPE_SQL:
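Taken together, the new `engine` parameter threads from `train()` through `add_ddl()`, where a concrete store can index it so that `search_tables_metadata(engine=...)` lookups can later filter by it. A sketch of the resulting call; the `"postgres"` label is illustrative, not a value the diff prescribes:

```python
# Illustrative: the engine tag is stored alongside the DDL by a concrete
# add_ddl implementation, enabling per-engine filtering at search time.
vn.train(
    ddl="CREATE TABLE public.customers (id INT, name TEXT, sales DECIMAL)",
    engine="postgres",
)
```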