Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

limit rows for queries #812

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 119 additions & 24 deletions src/vanna/base/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
import sqlparse

from ..exceptions import DependencyError, ImproperlyConfigured, ValidationError
from ..types import TrainingPlan, TrainingPlanItem
from ..types import TrainingPlan, TrainingPlanItem, TableMetadata
from ..utils import validate_config_path


Expand Down Expand Up @@ -210,6 +210,54 @@ def extract_sql(self, llm_response: str) -> str:

return llm_response

def extract_table_metadata(self, ddl: str) -> TableMetadata:
    """
    Example:
    ```python
    vn.extract_table_metadata("CREATE TABLE hive.bi_ads.customers (id INT, name TEXT, sales DECIMAL)")
    ```

    Extracts the table metadata (catalog, schema, table name) from a DDL
    statement. This is useful in case the DDL statement contains other
    information besides the table metadata.
    Override this function if your DDL statements need custom extraction logic.

    NOTE(review): `self` was missing from the original signature, so the
    documented instance call `vn.extract_table_metadata(ddl)` would have
    bound the instance to `ddl`; it is added here.

    Args:
        ddl (str): The DDL statement.

    Returns:
        TableMetadata: The extracted table metadata. Qualifiers absent from
        the DDL are passed as None; an empty TableMetadata is returned when
        no CREATE TABLE statement is found.
    """
    # Try the most-qualified form first so `catalog.schema.table` is not
    # partially matched by the `schema.table` or bare-table patterns.
    patterns = (
        re.compile(r'CREATE TABLE\s+(\w+)\.(\w+)\.(\w+)\s*\(', re.IGNORECASE),
        re.compile(r'CREATE TABLE\s+(\w+)\.(\w+)\s*\(', re.IGNORECASE),
        re.compile(r'CREATE TABLE\s+(\w+)\s*\(', re.IGNORECASE),
    )

    for pattern in patterns:
        match = pattern.search(ddl)
        if match:
            parts = match.groups()
            # Left-pad with None so the tuple is always
            # (catalog, schema, table_name) regardless of which pattern hit.
            padded = (None,) * (3 - len(parts)) + parts
            return TableMetadata(*padded)

    return TableMetadata()

def is_sql_valid(self, sql: str) -> bool:
"""
Example:
Expand Down Expand Up @@ -306,7 +354,7 @@ def generate_followup_questions(

message_log = [
self.system_message(
f"You are a helpful data assistant. The user asked the question: '{question}'\n\nThe SQL query for this question was: {sql}\n\nThe following is a pandas DataFrame with the results of the query: \n{df.head(25).to_markdown()}\n\n"
f"You are a helpful data assistant. The user asked the question: '{question}'\n\nThe SQL query for this question was: {sql}\n\nThe following is a pandas DataFrame with the results of the query: \n{df.to_markdown()}\n\n"
),
self.user_message(
f"Generate a list of {n_questions} followup questions that the user might ask about this data. Respond with a list of questions, one per line. Do not answer with any explanations -- just the questions. Remember that there should be an unambiguous SQL query that can be generated from the question. Prefer questions that are answerable outside of the context of this conversation. Prefer questions that are slight modifications of the SQL query that was generated that allow digging deeper into the data. Each question will be turned into a button that the user can click to generate a new SQL query so don't use 'example' type questions. Each question must have a one-to-one correspondence with an instantiated SQL query." +
Expand Down Expand Up @@ -395,6 +443,31 @@ def get_related_ddl(self, question: str, **kwargs) -> list:
"""
pass

@abstractmethod
def search_tables_metadata(self,
                           engine: str = None,
                           catalog: str = None,
                           schema: str = None,
                           table_name: str = None,
                           ddl: str = None,
                           size: int = 10,
                           **kwargs) -> list:
    """
    Search the training data for metadata of tables similar to the given
    criteria. All criteria default to None; presumably implementations
    treat each non-None argument as a filter — confirm against concrete
    vector-store implementations.

    Args:
        engine (str): The database engine (e.g. the value passed to `add_ddl`).
        catalog (str): The catalog name.
        schema (str): The schema name.
        table_name (str): The table name.
        ddl (str): The DDL statement to match against.
        size (int): The maximum number of tables to return (default 10).

    Returns:
        list: A list of tables metadata.
    """
    pass

@abstractmethod
def get_related_documentation(self, question: str, **kwargs) -> list:
"""
Expand Down Expand Up @@ -423,12 +496,13 @@ def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
pass

@abstractmethod
def add_ddl(self, ddl: str, **kwargs) -> str:
def add_ddl(self, ddl: str, engine: str = None, **kwargs) -> str:
"""
This method is used to add a DDL statement to the training data.

Args:
ddl (str): The DDL statement to add.
engine (str): The database engine that the DDL statement applies to.

Returns:
str: The ID of the training data that was added.
Expand Down Expand Up @@ -689,9 +763,6 @@ def generate_question(self, sql: str, **kwargs) -> str:
return response

def _extract_python_code(self, markdown_string: str) -> str:
# Strip whitespace to avoid indentation errors in LLM-generated code
markdown_string = markdown_string.strip()

# Regex pattern to match Python code blocks
pattern = r"```[\w\s]*python\n([\s\S]*?)```|```([\s\S]*?)```"

Expand Down Expand Up @@ -800,7 +871,7 @@ def connect_to_snowflake(
**kwargs
)

def run_sql_snowflake(sql: str, rows_limit: int = None) -> pd.DataFrame:
    """
    Execute `sql` on the enclosing Snowflake connection and return the
    results as a pandas DataFrame.

    Args:
        sql (str): The SQL statement to execute.
        rows_limit (int): Maximum number of rows to fetch; fetches all
            rows when None (the default), preserving prior behavior.

    Returns:
        pd.DataFrame: Query results, with columns named from the cursor
        description.
    """
    cs = conn.cursor()

    if role is not None:
        cs.execute(f"USE ROLE {role}")  # restored from collapsed diff context — confirm

    if warehouse is not None:
        cs.execute(f"USE WAREHOUSE {warehouse}")
    cs.execute(f"USE DATABASE {database}")
    cs.execute(sql)

    if rows_limit is not None:
        # DB-API 2.0 `fetchmany` takes the row count as `size` (usually
        # positional). The `numRows=` keyword is cx_Oracle-specific and
        # raises TypeError on the Snowflake connector's cursor.
        results = cs.fetchmany(rows_limit)
    else:
        results = cs.fetchall()

    # Create a pandas dataframe from the results
    df = pd.DataFrame(results, columns=[desc[0] for desc in cs.description])

    return df

Expand Down Expand Up @@ -949,13 +1022,16 @@ def connect_to_db():
user=user, password=password, port=port, **kwargs)


def run_sql_postgres(sql: str) -> Union[pd.DataFrame, None]:
def run_sql_postgres(sql: str, rows_limit: int = None) -> Union[pd.DataFrame, None]:
conn = None
try:
conn = connect_to_db() # Initial connection attempt
cs = conn.cursor()
cs.execute(sql)
results = cs.fetchall()
if rows_limit != None:
results = cs.fetchmany(numRows=rows_limit)
else:
results = cs.fetchall()

# Create a pandas dataframe from the results
df = pd.DataFrame(results, columns=[desc[0] for desc in cs.description])
Expand All @@ -968,7 +1044,10 @@ def run_sql_postgres(sql: str) -> Union[pd.DataFrame, None]:
conn = connect_to_db()
cs = conn.cursor()
cs.execute(sql)
results = cs.fetchall()
if rows_limit != None:
results = cs.fetchmany(numRows=rows_limit)
else:
results = cs.fetchall()

# Create a pandas dataframe from the results
df = pd.DataFrame(results, columns=[desc[0] for desc in cs.description])
Expand Down Expand Up @@ -1051,13 +1130,16 @@ def connect_to_mysql(
except pymysql.Error as e:
raise ValidationError(e)

def run_sql_mysql(sql: str) -> Union[pd.DataFrame, None]:
def run_sql_mysql(sql: str, rows_limit: int = None) -> Union[pd.DataFrame, None]:
if conn:
try:
conn.ping(reconnect=True)
cs = conn.cursor()
cs.execute(sql)
results = cs.fetchall()
if rows_limit != None:
results = cs.fetchmany(numRows=rows_limit)
else:
results = cs.fetchall()

# Create a pandas dataframe from the results
df = pd.DataFrame(
Expand Down Expand Up @@ -1218,7 +1300,7 @@ def connect_to_oracle(
except oracledb.Error as e:
raise ValidationError(e)

def run_sql_oracle(sql: str) -> Union[pd.DataFrame, None]:
def run_sql_oracle(sql: str, rows_limit: int = None) -> Union[pd.DataFrame, None]:
if conn:
try:
sql = sql.rstrip()
Expand All @@ -1227,7 +1309,10 @@ def run_sql_oracle(sql: str) -> Union[pd.DataFrame, None]:

cs = conn.cursor()
cs.execute(sql)
results = cs.fetchall()
if rows_limit != None:
results = cs.fetchmany(numRows=rows_limit)
else:
results = cs.fetchall()

# Create a pandas dataframe from the results
df = pd.DataFrame(
Expand Down Expand Up @@ -1519,7 +1604,7 @@ def connect_to_presto(
except presto.Error as e:
raise ValidationError(e)

def run_sql_presto(sql: str) -> Union[pd.DataFrame, None]:
def run_sql_presto(sql: str, rows_limit: int = None) -> Union[pd.DataFrame, None]:
if conn:
try:
sql = sql.rstrip()
Expand All @@ -1528,7 +1613,10 @@ def run_sql_presto(sql: str) -> Union[pd.DataFrame, None]:
sql = sql[:-1]
cs = conn.cursor()
cs.execute(sql)
results = cs.fetchall()
if rows_limit != None:
results = cs.fetchmany(numRows=rows_limit)
else:
results = cs.fetchall()

# Create a pandas dataframe from the results
df = pd.DataFrame(
Expand Down Expand Up @@ -1620,12 +1708,15 @@ def connect_to_hive(
except hive.Error as e:
raise ValidationError(e)

def run_sql_hive(sql: str) -> Union[pd.DataFrame, None]:
def run_sql_hive(sql: str, rows_limit: int = None) -> Union[pd.DataFrame, None]:
if conn:
try:
cs = conn.cursor()
cs.execute(sql)
results = cs.fetchall()
if rows_limit != None:
results = cs.fetchmany(numRows=rows_limit)
else:
results = cs.fetchall()

# Create a pandas dataframe from the results
df = pd.DataFrame(
Expand Down Expand Up @@ -1781,6 +1872,7 @@ def train(
question: str = None,
sql: str = None,
ddl: str = None,
engine: str = None,
documentation: str = None,
plan: TrainingPlan = None,
) -> str:
Expand All @@ -1801,8 +1893,11 @@ def train(
question (str): The question to train on.
sql (str): The SQL query to train on.
ddl (str): The DDL statement.
engine (str): The database engine.
documentation (str): The documentation to train on.
plan (TrainingPlan): The training plan to train on.
Returns:
str: The ID of the training data that was added.
"""

if question and not sql:
Expand All @@ -1820,12 +1915,12 @@ def train(

if ddl:
print("Adding ddl:", ddl)
return self.add_ddl(ddl)
return self.add_ddl(ddl=ddl, engine=engine)

if plan:
for item in plan._plan:
if item.item_type == TrainingPlanItem.ITEM_TYPE_DDL:
self.add_ddl(item.item_value)
self.add_ddl(ddl=item.item_value, engine=engine)
elif item.item_type == TrainingPlanItem.ITEM_TYPE_IS:
self.add_documentation(item.item_value)
elif item.item_type == TrainingPlanItem.ITEM_TYPE_SQL:
Expand Down