Skip to content

Commit b24d9c8

Browse files
DH-5735/add support for multiple schemas for agents
1 parent 4eb7a3e commit b24d9c8

12 files changed

+178
-54
lines changed

Diff for: dataherald/api/types/requests.py

+1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
class PromptRequest(BaseModel):
77
text: str
88
db_connection_id: str
9+
schemas: list[str] | None
910
metadata: dict | None
1011

1112

Diff for: dataherald/api/types/responses.py

+1
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def created_at_as_string(cls, v):
2525
class PromptResponse(BaseResponse):
2626
text: str
2727
db_connection_id: str
28+
schemas: list[str] | None
2829

2930

3031
class SQLGenerationResponse(BaseResponse):

Diff for: dataherald/services/prompts.py

+8
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
DatabaseConnectionRepository,
55
)
66
from dataherald.repositories.prompts import PromptNotFoundError, PromptRepository
7+
from dataherald.sql_database.services.database_connection import SchemaNotSupportedError
78
from dataherald.types import Prompt
89

910

@@ -22,9 +23,16 @@ def create(self, prompt_request: PromptRequest) -> Prompt:
2223
f"Database connection {prompt_request.db_connection_id} not found"
2324
)
2425

26+
if not db_connection.schemas and prompt_request.schemas:
27+
raise SchemaNotSupportedError(
28+
"Schema not supported for this db",
29+
description=f"The {db_connection.dialect} dialect doesn't support schemas",
30+
)
31+
2532
prompt = Prompt(
2633
text=prompt_request.text,
2734
db_connection_id=prompt_request.db_connection_id,
35+
schemas=prompt_request.schemas,
2836
metadata=prompt_request.metadata,
2937
)
3038
return self.prompt_repository.insert(prompt)

Diff for: dataherald/sql_generator/__init__.py

+16-1
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@
1111
from langchain.agents.agent import AgentExecutor
1212
from langchain.callbacks.base import BaseCallbackHandler
1313
from langchain.schema import AgentAction, LLMResult
14-
from langchain.schema.messages import BaseMessage
1514
from langchain_community.callbacks import get_openai_callback
1615

1716
from dataherald.config import Component, System
17+
from dataherald.db_scanner.models.types import TableDescription
1818
from dataherald.model.chat_model import ChatModel
1919
from dataherald.repositories.sql_generations import (
2020
SQLGenerationRepository,
@@ -62,6 +62,21 @@ def remove_markdown(self, query: str) -> str:
6262
return matches[0].strip()
6363
return query
6464

65+
@staticmethod
66+
def get_table_schema(table_name: str, db_scan: List[TableDescription]) -> str:
67+
for table in db_scan:
68+
if table.table_name == table_name:
69+
return table.schema_name
70+
return ""
71+
72+
@staticmethod
73+
def filter_tables_by_schema(
74+
db_scan: List[TableDescription], prompt: Prompt
75+
) -> List[TableDescription]:
76+
if prompt.schemas:
77+
return [table for table in db_scan if table.schema_name in prompt.schemas]
78+
return db_scan
79+
6580
def format_sql_query_intermediate_steps(self, step: str) -> str:
6681
pattern = r"```sql(.*?)```"
6782

Diff for: dataherald/sql_generator/dataherald_finetuning_agent.py

+48-18
Original file line numberDiff line numberDiff line change
@@ -190,12 +190,20 @@ def similart_tables_based_on_few_shot_examples(self, df: pd.DataFrame) -> List[s
190190
tables = Parser(example["sql"]).tables
191191
except Exception as e:
192192
logger.error(f"Error parsing SQL: {str(e)}")
193-
most_similar_tables.update(tables)
194-
df.drop(df[df.table_name.isin(most_similar_tables)].index, inplace=True)
193+
for table in tables:
194+
found_tables = df[df.table_name == table]
195+
for _, row in found_tables.iterrows():
196+
most_similar_tables.add((row["schema_name"], row["table_name"]))
197+
df.drop(
198+
df[
199+
df.table_name.isin([table[1] for table in most_similar_tables])
200+
].index,
201+
inplace=True,
202+
)
195203
return most_similar_tables
196204

197205
@catch_exceptions()
198-
def _run(
206+
def _run( # noqa: PLR0912
199207
self,
200208
user_question: str,
201209
run_manager: CallbackManagerForToolRun | None = None, # noqa: ARG002
@@ -214,9 +222,12 @@ def _run(
214222
table_rep = f"Table {table.table_name} contain columns: [{col_rep}], this tables has: {table.description}"
215223
else:
216224
table_rep = f"Table {table.table_name} contain columns: [{col_rep}]"
217-
table_representations.append([table.table_name, table_rep])
225+
table_representations.append(
226+
[table.schema_name, table.table_name, table_rep]
227+
)
218228
df = pd.DataFrame(
219-
table_representations, columns=["table_name", "table_representation"]
229+
table_representations,
230+
columns=["schema_name", "table_name", "table_representation"],
220231
)
221232
df["table_embedding"] = self.get_docs_embedding(df.table_representation)
222233
df["similarities"] = df.table_embedding.apply(
@@ -227,12 +238,20 @@ def _run(
227238
most_similar_tables = self.similart_tables_based_on_few_shot_examples(df)
228239
table_relevance = ""
229240
for _, row in df.iterrows():
230-
table_relevance += f'Table: `{row["table_name"]}`, relevance score: {row["similarities"]}\n'
241+
if row["schema_name"] is not None:
242+
table_name = row["schema_name"] + "." + row["table_name"]
243+
else:
244+
table_name = row["table_name"]
245+
table_relevance += (
246+
f'Table: `{table_name}`, relevance score: {row["similarities"]}\n'
247+
)
231248
if len(most_similar_tables) > 0:
232249
for table in most_similar_tables:
233-
table_relevance += (
234-
f"Table: `{table}`, relevance score: {max(df['similarities'])}\n"
235-
)
250+
if table[0] is not None:
251+
table_name = table[0] + "." + table[1]
252+
else:
253+
table_name = table[1]
254+
table_relevance += f"Table: `{table_name}`, relevance score: {max(df['similarities'])}\n"
236255
return table_relevance
237256

238257
async def _arun(
@@ -358,27 +377,32 @@ class SchemaSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
358377
db_scan: List[TableDescription]
359378

360379
@catch_exceptions()
361-
def _run(
380+
def _run( # noqa: C901
362381
self,
363382
table_names: str,
364383
run_manager: CallbackManagerForToolRun | None = None, # noqa: ARG002
365384
) -> str:
366385
"""Get the schema for tables in a comma-separated list."""
367386
table_names_list = table_names.split(", ")
368-
table_names_list = [
369-
replace_unprocessable_characters(table_name)
370-
for table_name in table_names_list
371-
]
387+
processed_table_names = []
388+
for table in table_names_list:
389+
formatted_table = replace_unprocessable_characters(table)
390+
if "." in formatted_table:
391+
processed_table_names.append(formatted_table.split(".")[1])
392+
else:
393+
processed_table_names.append(formatted_table)
372394
tables_schema = ""
373395
for table in self.db_scan:
374-
if table.table_name in table_names_list:
396+
if table.table_name in processed_table_names:
375397
tables_schema += "```sql\n"
376398
tables_schema += table.table_schema + "\n"
377399
descriptions = []
378400
if table.description is not None:
379-
descriptions.append(
380-
f"Table `{table.table_name}`: {table.description}\n"
381-
)
401+
if table.schema_name:
402+
table_name = f"{table.schema_name}.{table.table_name}"
403+
else:
404+
table_name = table.table_name
405+
descriptions.append(f"Table `{table_name}`: {table.description}\n")
382406
for column in table.columns:
383407
if column.description is not None:
384408
descriptions.append(
@@ -555,6 +579,9 @@ def generate_response(
555579
)
556580
if not db_scan:
557581
raise ValueError("No scanned tables found for database")
582+
db_scan = SQLGenerator.filter_tables_by_schema(
583+
db_scan=db_scan, prompt=user_prompt
584+
)
558585
few_shot_examples, instructions = context_store.retrieve_context_for_question(
559586
user_prompt, number_of_samples=5
560587
)
@@ -658,6 +685,9 @@ def stream_response(
658685
)
659686
if not db_scan:
660687
raise ValueError("No scanned tables found for database")
688+
db_scan = SQLGenerator.filter_tables_by_schema(
689+
db_scan=db_scan, prompt=user_prompt
690+
)
661691
_, instructions = context_store.retrieve_context_for_question(
662692
user_prompt, number_of_samples=1
663693
)

0 commit comments

Comments
 (0)