Skip to content

Commit 2b63aa8

Browse files
DH-5765/add support multiple schema for finetuning
1 parent b4f57c4 commit 2b63aa8

File tree

4 files changed

+55
-3
lines changed

4 files changed

+55
-3
lines changed

Diff for: dataherald/api/fastapi.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,10 @@
9191
)
9292
from dataherald.utils.encrypt import FernetEncrypt
9393
from dataherald.utils.error_codes import error_response, stream_error_response
94+
from dataherald.utils.sql_utils import (
95+
filter_golden_records_based_on_schema,
96+
validate_finetuning_schema,
97+
)
9498

9599
logger = logging.getLogger(__name__)
96100

@@ -564,15 +568,14 @@ def create_finetuning_job(
564568
) -> Finetuning:
565569
try:
566570
db_connection_repository = DatabaseConnectionRepository(self.storage)
567-
568571
db_connection = db_connection_repository.find_by_id(
569572
fine_tuning_request.db_connection_id
570573
)
571574
if not db_connection:
572575
raise DatabaseConnectionNotFoundError(
573576
f"Database connection not found, {fine_tuning_request.db_connection_id}"
574577
)
575-
578+
validate_finetuning_schema(fine_tuning_request, db_connection)
576579
golden_sqls_repository = GoldenSQLRepository(self.storage)
577580
golden_sqls = []
578581
if fine_tuning_request.golden_sqls:
@@ -593,6 +596,9 @@ def create_finetuning_job(
593596
raise GoldenSQLNotFoundError(
594597
f"No golden sqls found for db_connection: {fine_tuning_request.db_connection_id}"
595598
)
599+
golden_sqls = filter_golden_records_based_on_schema(
600+
golden_sqls, fine_tuning_request.schemas
601+
)
596602
default_base_llm = BaseLLM(
597603
model_provider="openai",
598604
model_name="gpt-3.5-turbo-1106",
@@ -606,6 +612,7 @@ def create_finetuning_job(
606612
model = model_repository.insert(
607613
Finetuning(
608614
db_connection_id=fine_tuning_request.db_connection_id,
615+
schemas=fine_tuning_request.schemas,
609616
alias=fine_tuning_request.alias
610617
if fine_tuning_request.alias
611618
else f"{db_connection.alias}_{datetime.datetime.now().strftime('%Y%m%d%H')}",

Diff for: dataherald/finetuning/openai_finetuning.py

+7
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,12 @@ def map_finetuning_status(status: str) -> str:
6969
return FineTuningStatus.QUEUED.value
7070
return mapped_statuses[status]
7171

72+
@staticmethod
def _filter_tables_by_schema(db_scan: List[TableDescription], schemas: List[str]):
    """Restrict scanned table descriptions to the requested schemas.

    Args:
        db_scan: Table descriptions to filter.
        schemas: Schema names to keep. When falsy (``None`` or empty) no
            filtering is applied.

    Returns:
        ``db_scan`` itself when ``schemas`` is falsy; otherwise a new list
        holding only the tables whose ``schema_name`` is in ``schemas``.
    """
    # No schema restriction requested: hand back the scan untouched.
    if not schemas:
        return db_scan
    return [
        table_description
        for table_description in db_scan
        if table_description.schema_name in schemas
    ]
77+
7278
def format_columns(
7379
self, table: TableDescription, top_k: int = CATEGORICAL_COLUMNS_THRESHOLD
7480
) -> str:
@@ -197,6 +203,7 @@ def create_fintuning_dataset(self):
197203
"status": TableDescriptionStatus.SCANNED.value,
198204
}
199205
)
206+
db_scan = self._filter_tables_by_schema(db_scan, self.fine_tuning_model.schemas)
200207
golden_sqls_repository = GoldenSQLRepository(self.storage)
201208
finetuning_dataset_path = f"tmp/{str(uuid.uuid4())}.jsonl"
202209
model_repository = FinetuningsRepository(self.storage)

Diff for: dataherald/types.py

+2
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ class Finetuning(BaseModel):
150150
id: str | None = None
151151
alias: str | None = None
152152
db_connection_id: str | None = None
153+
schemas: list[str] | None
153154
status: str = "QUEUED"
154155
error: str | None = None
155156
base_llm: BaseLLM | None = None
@@ -163,6 +164,7 @@ class Finetuning(BaseModel):
163164

164165
class FineTuningRequest(BaseModel):
165166
db_connection_id: str
167+
schemas: list[str] | None
166168
alias: str | None = None
167169
base_llm: BaseLLM | None = None
168170
golden_sqls: list[str] | None = None

Diff for: dataherald/utils/sql_utils.py

+37-1
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,47 @@
11
from sql_metadata import Parser
22

3+
from dataherald.sql_database.models.types import DatabaseConnection
4+
from dataherald.sql_database.services.database_connection import SchemaNotSupportedError
5+
from dataherald.types import FineTuningRequest, GoldenSQL
36

4-
def extract_the_schemas_from_sql(sql):
7+
8+
def extract_the_schemas_from_sql(sql: str) -> list[str]:
    """Return the schema prefix of each dotted table name found in *sql*.

    Table names are taken from ``Parser(sql).tables``. Only names containing
    a dot contribute: the text before the first dot is used as the schema,
    stripped of surrounding whitespace. Names without a dot are ignored.
    Duplicates are kept, in the order the parser reports the tables.

    NOTE(review): for names with more than two parts (e.g. ``a.b.c``) the
    first part is returned -- confirm that is the schema for every
    supported dialect.
    """
    return [
        qualified_name.split(".")[0].strip()
        for qualified_name in Parser(sql).tables
        if "." in qualified_name
    ]
16+
17+
18+
def filter_golden_records_based_on_schema(
    golden_sqls: list[GoldenSQL], schemas: list[str] | None
) -> list[GoldenSQL]:
    """Keep only the golden SQL records that touch one of the given schemas.

    Args:
        golden_sqls: Candidate golden SQL records; each record's ``sql``
            attribute is inspected.
        schemas: Schema names to filter by. ``None`` or an empty list
            disables filtering.

    Returns:
        ``golden_sqls`` unchanged when ``schemas`` is falsy; otherwise a new
        list, in the original order, of the records whose SQL references at
        least one of the requested schemas according to
        ``extract_the_schemas_from_sql``.
    """
    if not schemas:
        return golden_sqls
    # Build the lookup set once instead of re-scanning `schemas` per record.
    requested = set(schemas)
    return [
        record
        for record in golden_sqls
        if not requested.isdisjoint(extract_the_schemas_from_sql(record.sql))
    ]
31+
32+
33+
def validate_finetuning_schema(
    finetuning_request: FineTuningRequest, db_connection: DatabaseConnection
):
    """Check that every schema named in the request exists on the connection.

    Does nothing when the request carries no schemas.

    Args:
        finetuning_request: Request whose ``schemas`` attribute is validated.
        db_connection: Connection whose ``schemas`` attribute lists the
            accepted schema names.

    Raises:
        SchemaNotSupportedError: If schemas are requested but the connection
            has none, or if a requested schema is not among the connection's
            schemas (raised for the first such schema).
    """
    requested_schemas = finetuning_request.schemas
    if not requested_schemas:
        return

    available_schemas = db_connection.schemas
    if not available_schemas:
        raise SchemaNotSupportedError(
            "Schema not supported for this db",
            description=f"The {db_connection.id} db doesn't have schemas",
        )

    for schema in requested_schemas:
        if schema in available_schemas:
            continue
        raise SchemaNotSupportedError(
            f"Schema {schema} not supported for this db",
            description=f"The {db_connection.dialect} dialect doesn't support schema {schema}",
        )

0 commit comments

Comments
 (0)