Commit d4d6f4e

[DH-5733] Support schemas column to add a db connection (#466)
* [DH-5733] Support schemas column to add a db connection
* DBs without schema should store None
* Add ids in sync-schemas endpoint
* Support multi-schemas for refresh endpoint
* Add schema not support error exception
* Add documentation for multi-schemas
* Fix sync schema method
* Sync-schemas endpoint let adding ids from different db connection
* Fix refresh endpoint
* Fix table-description storage
* Fix schema_name filter in table-description repository
* DH-5735/add support for multiple schemas for agents
* DH-5766/adding the validation to raise exception for queries without schema in multiple schema setting
* DH-5765/add support multiple schema for finetuning

Co-authored-by: mohammadrezapourreza <m1378.prz@gmail.com>
1 parent fbd96ea commit d4d6f4e

32 files changed (+594 -198 lines)

Diff for: README.md (+22 -3)
````diff
@@ -178,6 +178,24 @@ curl -X 'POST' \
   }'
 ```
 
+##### Connecting multi-schemas
+You can connect many schemas using one db connection if you want to create SQL joins between schemas.
+Currently only `BigQuery`, `Snowflake`, `Databricks` and `Postgres` support this feature.
+To use multi-schemas, instead of sending the `schema` in the `connection_uri`, set it in the `schemas` param, like this:
+
+```
+curl -X 'POST' \
+  '<host>/api/v1/database-connections' \
+  -H 'accept: application/json' \
+  -H 'Content-Type: application/json' \
+  -d '{
+    "alias": "my_db_alias",
+    "use_ssh": false,
+    "connection_uri": "snowflake://<user>:<password>@<organization>-<account-name>/<database>",
+    "schemas": ["schema_1", "schema_2", ...]
+  }'
+```
+
 ##### Connecting to supported Data warehouses and using SSH
 You can find the details on how to connect to the supported data warehouses in the [docs](https://dataherald.readthedocs.io/en/latest/api.create_database_connection.html)
````

```diff
@@ -194,7 +212,8 @@ While only the Database scan part is required to start generating SQL, adding ve
 #### Scanning the Database
 The database scan is used to gather information about the database including table and column names and identifying low cardinality columns and their values to be stored in the context store and used in the prompts to the LLM.
 In addition, it retrieves logs, which consist of historical queries associated with each database table. These records are then stored within the query_history collection. The historical queries retrieved encompass data from the past three months and are grouped based on query and user.
-db_connection_id is the id of the database connection you want to scan, which is returned when you create a database connection.
+The db_connection_id param is the id of the database connection you want to scan, which is returned when you create a database connection.
+The ids param lists the table_description_id values you want to scan.
 You can trigger a scan of a database from the `POST /api/v1/table-descriptions/sync-schemas` endpoint. Example below
```

````diff
@@ -205,11 +224,11 @@ curl -X 'POST' \
   -H 'Content-Type: application/json' \
   -d '{
     "db_connection_id": "db_connection_id",
-    "table_names": ["table_name"]
+    "ids": ["<table_description_id_1>", "<table_description_id_2>", ...]
   }'
 ```
 
-Since the endpoint identifies low cardinality columns (and their values) it can take time to complete. Therefore while it is possible to trigger a scan on the entire DB by not specifying the `table_names`, we recommend against it for large databases.
+Since the endpoint identifies low cardinality columns (and their values) it can take time to complete.
 
 #### Get logs per db connection
 Once a database was scanned you can use this endpoint to retrieve the tables logs
````

Diff for: dataherald/api/fastapi.py (+67 -65)
```diff
@@ -70,6 +70,9 @@
     SQLInjectionError,
 )
 from dataherald.sql_database.models.types import DatabaseConnection
+from dataherald.sql_database.services.database_connection import (
+    DatabaseConnectionService,
+)
 from dataherald.types import (
     BaseLLM,
     CancelFineTuningRequest,
```
```diff
@@ -88,17 +91,20 @@
 )
 from dataherald.utils.encrypt import FernetEncrypt
 from dataherald.utils.error_codes import error_response, stream_error_response
+from dataherald.utils.sql_utils import (
+    filter_golden_records_based_on_schema,
+    validate_finetuning_schema,
+)
 
 logger = logging.getLogger(__name__)
 
 MAX_ROWS_TO_CREATE_CSV_FILE = 50
 
 
-def async_scanning(scanner, database, scanner_request, storage):
+def async_scanning(scanner, database, table_descriptions, storage):
     scanner.scan(
         database,
-        scanner_request.db_connection_id,
-        scanner_request.table_names,
+        table_descriptions,
         TableDescriptionRepository(storage),
         QueryHistoryRepository(storage),
     )
```
```diff
@@ -130,70 +136,52 @@ def scan_db(
         self, scanner_request: ScannerRequest, background_tasks: BackgroundTasks
     ) -> list[TableDescriptionResponse]:
         """Takes a db_connection_id and scan all the tables columns"""
-        try:
-            db_connection_repository = DatabaseConnectionRepository(self.storage)
-
-            db_connection = db_connection_repository.find_by_id(
-                scanner_request.db_connection_id
-            )
+        scanner_repository = TableDescriptionRepository(self.storage)
+        data = {}
+        for id in scanner_request.ids:
+            table_description = scanner_repository.find_by_id(id)
+            if not table_description:
+                raise Exception("Table description not found")
+            if table_description.db_connection_id not in data.keys():
+                data[table_description.db_connection_id] = {}
+            if (
+                table_description.schema_name
+                not in data[table_description.db_connection_id].keys()
+            ):
+                data[table_description.db_connection_id][
+                    table_description.schema_name
+                ] = []
+            data[table_description.db_connection_id][
+                table_description.schema_name
+            ].append(table_description)
 
-            if not db_connection:
-                raise DatabaseConnectionNotFoundError(
-                    f"Database connection {scanner_request.db_connection_id} not found"
+        db_connection_repository = DatabaseConnectionRepository(self.storage)
+        scanner = self.system.instance(Scanner)
+        rows = scanner.synchronizing(
+            scanner_request,
+            TableDescriptionRepository(self.storage),
+        )
+        database_connection_service = DatabaseConnectionService(scanner, self.storage)
+        for db_connection_id, schemas_and_table_descriptions in data.items():
+            for schema, table_descriptions in schemas_and_table_descriptions.items():
+                db_connection = db_connection_repository.find_by_id(db_connection_id)
+                database = database_connection_service.get_sql_database(
+                    db_connection, schema
                 )
 
-            database = SQLDatabase.get_sql_engine(db_connection, True)
-            all_tables = database.get_tables_and_views()
-
-            if scanner_request.table_names:
-                for table in scanner_request.table_names:
-                    if table not in all_tables:
-                        raise HTTPException(
-                            status_code=404,
-                            detail=f"Table named: {table} doesn't exist",
-                        )  # noqa: B904
-            else:
-                scanner_request.table_names = all_tables
-
-            scanner = self.system.instance(Scanner)
-            rows = scanner.synchronizing(
-                scanner_request,
-                TableDescriptionRepository(self.storage),
-            )
-        except Exception as e:
-            return error_response(e, scanner_request.dict(), "invalid_database_sync")
-
-        background_tasks.add_task(
-            async_scanning, scanner, database, scanner_request, self.storage
-        )
+                background_tasks.add_task(
+                    async_scanning, scanner, database, table_descriptions, self.storage
+                )
         return [TableDescriptionResponse(**row.dict()) for row in rows]
 
     @override
     def create_database_connection(
         self, database_connection_request: DatabaseConnectionRequest
     ) -> DatabaseConnectionResponse:
         try:
-            db_connection = DatabaseConnection(
-                alias=database_connection_request.alias,
-                connection_uri=database_connection_request.connection_uri.strip(),
-                path_to_credentials_file=database_connection_request.path_to_credentials_file,
-                llm_api_key=database_connection_request.llm_api_key,
-                use_ssh=database_connection_request.use_ssh,
-                ssh_settings=database_connection_request.ssh_settings,
-                file_storage=database_connection_request.file_storage,
-                metadata=database_connection_request.metadata,
-            )
-            sql_database = SQLDatabase.get_sql_engine(db_connection, True)
-
-            # Get tables and views and create table-descriptions as NOT_SCANNED
-            db_connection_repository = DatabaseConnectionRepository(self.storage)
-
-            scanner_repository = TableDescriptionRepository(self.storage)
             scanner = self.system.instance(Scanner)
-
-            tables = sql_database.get_tables_and_views()
-            db_connection = db_connection_repository.insert(db_connection)
-            scanner.create_tables(tables, str(db_connection.id), scanner_repository)
+            db_connection_service = DatabaseConnectionService(scanner, self.storage)
+            db_connection = db_connection_service.create(database_connection_request)
         except Exception as e:
             # Encrypt sensible values
             fernet_encrypt = FernetEncrypt()
```
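The new `scan_db` flow above first groups the requested table descriptions by connection id and then by schema before dispatching one background scan per (connection, schema) pair. A minimal sketch of that grouping, using `SimpleNamespace` as a stand-in for the real `TableDescription` model:

```python
from collections import defaultdict
from types import SimpleNamespace


def group_table_descriptions(table_descriptions):
    # Nest by connection id first, then by schema name, mirroring the
    # nested-dict construction in scan_db
    data = defaultdict(lambda: defaultdict(list))
    for td in table_descriptions:
        data[td.db_connection_id][td.schema_name].append(td)
    return {conn: dict(schemas) for conn, schemas in data.items()}


# SimpleNamespace stands in for the real pydantic TableDescription model
tds = [
    SimpleNamespace(db_connection_id="c1", schema_name="s1", table_name="t1"),
    SimpleNamespace(db_connection_id="c1", schema_name="s2", table_name="t2"),
    SimpleNamespace(db_connection_id="c1", schema_name="s1", table_name="t3"),
]
grouped = group_table_descriptions(tds)
```

Each inner list then becomes one `async_scanning` background task, so tables in different schemas are scanned through separate engine connections.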
```diff
@@ -209,7 +197,6 @@ def create_database_connection(
             return error_response(
                 e, database_connection_request.dict(), "invalid_database_connection"
             )
-
         return DatabaseConnectionResponse(**db_connection.dict())
 
     @override
```
```diff
@@ -220,18 +207,30 @@ def refresh_table_description(
         db_connection = db_connection_repository.find_by_id(
             refresh_table_description.db_connection_id
         )
-
+        scanner = self.system.instance(Scanner)
+        database_connection_service = DatabaseConnectionService(scanner, self.storage)
         try:
-            sql_database = SQLDatabase.get_sql_engine(db_connection, True)
-            tables = sql_database.get_tables_and_views()
+            data = {}
+            if db_connection.schemas:
+                for schema in db_connection.schemas:
+                    sql_database = database_connection_service.get_sql_database(
+                        db_connection, schema
+                    )
+                    if schema not in data.keys():
+                        data[schema] = []
+                    data[schema] = sql_database.get_tables_and_views()
+            else:
+                sql_database = database_connection_service.get_sql_database(
+                    db_connection
+                )
+                data[None] = sql_database.get_tables_and_views()
 
-            # Get tables and views and create missing table-descriptions as NOT_SCANNED and update DEPRECATED
             scanner_repository = TableDescriptionRepository(self.storage)
-            scanner = self.system.instance(Scanner)
+
             return [
                 TableDescriptionResponse(**record.dict())
                 for record in scanner.refresh_tables(
-                    tables, str(db_connection.id), scanner_repository
+                    data, str(db_connection.id), scanner_repository
                )
            ]
        except Exception as e:
```
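The refresh logic above builds a `{schema: [tables]}` dict, falling back to a single `None` key for schema-less connections. The shape can be sketched in isolation, with a plain lookup function standing in for the live SQL engine:

```python
def build_schema_table_map(schemas, list_tables):
    # Multi-schema connections get one entry per schema; schema-less
    # connections collapse to a single None key, matching the branch in
    # refresh_table_description
    if schemas:
        return {schema: list_tables(schema) for schema in schemas}
    return {None: list_tables(None)}


# A stand-in catalog instead of sql_database.get_tables_and_views()
catalog = {"s1": ["a", "b"], "s2": ["c"], None: ["d"]}
```

`refresh_tables` then receives this dict and can create missing table-descriptions per schema.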
```diff
@@ -569,15 +568,14 @@ def create_finetuning_job(
     ) -> Finetuning:
         try:
             db_connection_repository = DatabaseConnectionRepository(self.storage)
-
             db_connection = db_connection_repository.find_by_id(
                 fine_tuning_request.db_connection_id
             )
             if not db_connection:
                 raise DatabaseConnectionNotFoundError(
                     f"Database connection not found, {fine_tuning_request.db_connection_id}"
                 )
-
+            validate_finetuning_schema(fine_tuning_request, db_connection)
             golden_sqls_repository = GoldenSQLRepository(self.storage)
             golden_sqls = []
             if fine_tuning_request.golden_sqls:
```
```diff
@@ -598,6 +596,9 @@ def create_finetuning_job(
                 raise GoldenSQLNotFoundError(
                     f"No golden sqls found for db_connection: {fine_tuning_request.db_connection_id}"
                 )
+            golden_sqls = filter_golden_records_based_on_schema(
+                golden_sqls, fine_tuning_request.schemas
+            )
             default_base_llm = BaseLLM(
                 model_provider="openai",
                 model_name="gpt-3.5-turbo-1106",
```
```diff
@@ -611,6 +612,7 @@ def create_finetuning_job(
             model = model_repository.insert(
                 Finetuning(
                     db_connection_id=fine_tuning_request.db_connection_id,
+                    schemas=fine_tuning_request.schemas,
                     alias=fine_tuning_request.alias
                     if fine_tuning_request.alias
                     else f"{db_connection.alias}_{datetime.datetime.now().strftime('%Y%m%d%H')}",
```

Diff for: dataherald/api/types/requests.py (+1)

```diff
@@ -7,6 +7,7 @@
 class PromptRequest(BaseModel):
     text: str
     db_connection_id: str
+    schemas: list[str] | None
     metadata: dict | None
```

Diff for: dataherald/api/types/responses.py (+1)

```diff
@@ -25,6 +25,7 @@ def created_at_as_string(cls, v):
 class PromptResponse(BaseResponse):
     text: str
     db_connection_id: str
+    schemas: list[str] | None
 
 
 class SQLGenerationResponse(BaseResponse):
```

Diff for: dataherald/context_store/default.py (+13)

```diff
@@ -13,6 +13,7 @@
 from dataherald.repositories.golden_sqls import GoldenSQLRepository
 from dataherald.repositories.instructions import InstructionRepository
 from dataherald.types import GoldenSQL, GoldenSQLRequest, Prompt
+from dataherald.utils.sql_utils import extract_the_schemas_from_sql
 
 logger = logging.getLogger(__name__)
 
@@ -86,6 +87,18 @@ def add_golden_sqls(self, golden_sqls: List[GoldenSQLRequest]) -> List[GoldenSQL
                     f"Database connection not found, {record.db_connection_id}"
                 )
 
+            if db_connection.schemas:
+                schema_not_found = True
+                used_schemas = extract_the_schemas_from_sql(record.sql)
+                for schema in db_connection.schemas:
+                    if schema in used_schemas:
+                        schema_not_found = False
+                        break
+                if schema_not_found:
+                    raise MalformedGoldenSQLError(
+                        f"SQL {record.sql} does not contain any of the schemas {db_connection.schemas}"
+                    )
+
             prompt_text = record.prompt_text
             golden_sql = GoldenSQL(
                 prompt_text=prompt_text,
```
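The check above rejects golden SQL that references none of the connection's configured schemas. The real `extract_the_schemas_from_sql` helper lives in `dataherald.utils.sql_utils` and is not shown in this diff; a rough regex-based approximation of the idea (the actual implementation may well differ):

```python
import re


def extract_schema_prefixes(sql: str) -> set[str]:
    # Collect the <schema> part of <schema>.<table> references; a crude
    # stand-in for extract_the_schemas_from_sql, shown only to illustrate
    # the validation
    return set(re.findall(r"\b([A-Za-z_]\w*)\.[A-Za-z_]", sql))


def uses_allowed_schema(sql: str, allowed: list[str]) -> bool:
    # Mirrors the add_golden_sqls check: at least one configured schema
    # must appear in the SQL, otherwise the record is rejected
    return bool(extract_schema_prefixes(sql) & set(allowed))
```

A real parser (e.g. one built on a SQL grammar) would also have to handle quoted identifiers and table aliases, which this sketch ignores.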

Diff for: dataherald/db_scanner/__init__.py (+3 -3)

```diff
@@ -14,8 +14,7 @@ class Scanner(Component, ABC):
     def scan(
         self,
         db_engine: SQLDatabase,
-        db_connection_id: str,
-        table_names: list[str] | None,
+        table_descriptions: list[TableDescription],
         repository: TableDescriptionRepository,
         query_history_repository: QueryHistoryRepository,
     ) -> None:
@@ -34,6 +33,7 @@ def create_tables(
         self,
         tables: list[str],
         db_connection_id: str,
+        schema: str,
         repository: TableDescriptionRepository,
         metadata: dict = None,
     ) -> None:
@@ -42,7 +42,7 @@
     @abstractmethod
     def refresh_tables(
         self,
-        tables: list[str],
+        schemas_and_tables: dict[str, list],
         db_connection_id: str,
         repository: TableDescriptionRepository,
         metadata: dict = None,
```

Diff for: dataherald/db_scanner/models/types.py (+1)

```diff
@@ -31,6 +31,7 @@ class TableDescriptionStatus(Enum):
 class TableDescription(BaseModel):
     id: str | None
     db_connection_id: str
+    schema_name: str | None
     table_name: str
     description: str | None
     table_schema: str | None
```
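Because the new `schema_name` field is optional (`str | None`), table-description documents created before this change still load without a schema. A simplified dataclass stand-in for the pydantic model (names other than the fields themselves are illustrative) shows the default behavior:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class TableDescriptionSketch:
    # Simplified stand-in for the pydantic TableDescription model:
    # schema_name defaults to None so pre-existing rows stay valid
    db_connection_id: str
    table_name: str
    schema_name: Optional[str] = None


old_row = TableDescriptionSketch(db_connection_id="c1", table_name="users")
new_row = TableDescriptionSketch("c1", "users", schema_name="s1")
```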

Diff for: dataherald/db_scanner/repository/base.py (+9 -4)

```diff
@@ -53,13 +53,18 @@ def save_table_info(self, table_info: TableDescription) -> TableDescription:
         table_info_dict = {
             k: v for k, v in table_info_dict.items() if v is not None and v != []
         }
+
+        query = {
+            "db_connection_id": table_info_dict["db_connection_id"],
+            "table_name": table_info_dict["table_name"],
+        }
+        if "schema_name" in table_info_dict:
+            query["schema_name"] = table_info_dict["schema_name"]
+
         table_info.id = str(
             self.storage.update_or_create(
                 DB_COLLECTION,
-                {
-                    "db_connection_id": table_info_dict["db_connection_id"],
-                    "table_name": table_info_dict["table_name"],
-                },
+                query,
                 table_info_dict,
             )
         )
```
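The repository change above adds `schema_name` to the upsert filter only when it is present, so schema-less connections keep their old `(db_connection_id, table_name)` identity while multi-schema connections can store the same table name once per schema. The filter construction in isolation:

```python
def build_lookup_query(table_info_dict: dict) -> dict:
    # Mirror of the filter built in save_table_info: schema_name joins the
    # lookup key only when present, so documents written before the schema
    # feature are still matched by the two original fields
    query = {
        "db_connection_id": table_info_dict["db_connection_id"],
        "table_name": table_info_dict["table_name"],
    }
    if "schema_name" in table_info_dict:
        query["schema_name"] = table_info_dict["schema_name"]
    return query
```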
