Skip to content

Commit b4f57c4

Browse files
DH-5766/adding validation to raise an exception for queries without a schema in a multiple-schema setting
1 parent b24d9c8 commit b4f57c4

File tree

2 files changed

+24
-0
lines changed

2 files changed

+24
-0
lines changed

Diff for: dataherald/context_store/default.py

+13
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from dataherald.repositories.golden_sqls import GoldenSQLRepository
1414
from dataherald.repositories.instructions import InstructionRepository
1515
from dataherald.types import GoldenSQL, GoldenSQLRequest, Prompt
16+
from dataherald.utils.sql_utils import extract_the_schemas_from_sql
1617

1718
logger = logging.getLogger(__name__)
1819

@@ -86,6 +87,18 @@ def add_golden_sqls(self, golden_sqls: List[GoldenSQLRequest]) -> List[GoldenSQL
8687
f"Database connection not found, {record.db_connection_id}"
8788
)
8889

90+
if db_connection.schemas:
91+
schema_not_found = True
92+
used_schemas = extract_the_schemas_from_sql(record.sql)
93+
for schema in db_connection.schemas:
94+
if schema in used_schemas:
95+
schema_not_found = False
96+
break
97+
if schema_not_found:
98+
raise MalformedGoldenSQLError(
99+
f"SQL {record.sql} does not contain any of the schemas {db_connection.schemas}"
100+
)
101+
89102
prompt_text = record.prompt_text
90103
golden_sql = GoldenSQL(
91104
prompt_text=prompt_text,

Diff for: dataherald/utils/sql_utils.py

+11
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from sql_metadata import Parser
2+
3+
4+
def extract_the_schemas_from_sql(sql: str) -> list:
    """Return the schema qualifier of every schema-qualified table in *sql*.

    The table names are obtained from ``Parser(sql).tables``. Names without a
    ``.`` separator carry no schema and are skipped.

    Args:
        sql: The SQL statement to inspect.

    Returns:
        A list of schema names, one entry per qualified table reference
        (duplicates are kept, in the order the parser reports the tables).
    """
    schemas = []
    for table_name in Parser(sql).tables:
        parts = table_name.split(".")
        if len(parts) < 2:
            # Unqualified table name: there is no schema to report.
            continue
        # The schema is the qualifier directly in front of the table name.
        # Taking the first part instead would return the catalog/project for
        # three-part names such as ``catalog.schema.table``.
        schemas.append(parts[-2].strip())
    return schemas

0 commit comments

Comments
 (0)