Skip to content

Commit e164a6f

Browse files
committed
[DH-5733] Support schemas column to add a db connection
1 parent 828c64d commit e164a6f

File tree

8 files changed

+130
-21
lines changed

8 files changed

+130
-21
lines changed

Diff for: dataherald/api/fastapi.py

+5-21
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,9 @@
7070
SQLInjectionError,
7171
)
7272
from dataherald.sql_database.models.types import DatabaseConnection
73+
from dataherald.sql_database.services.database_connection import (
74+
DatabaseConnectionService,
75+
)
7376
from dataherald.types import (
7477
BaseLLM,
7578
CancelFineTuningRequest,
@@ -173,27 +176,9 @@ def create_database_connection(
173176
self, database_connection_request: DatabaseConnectionRequest
174177
) -> DatabaseConnectionResponse:
175178
try:
176-
db_connection = DatabaseConnection(
177-
alias=database_connection_request.alias,
178-
connection_uri=database_connection_request.connection_uri.strip(),
179-
path_to_credentials_file=database_connection_request.path_to_credentials_file,
180-
llm_api_key=database_connection_request.llm_api_key,
181-
use_ssh=database_connection_request.use_ssh,
182-
ssh_settings=database_connection_request.ssh_settings,
183-
file_storage=database_connection_request.file_storage,
184-
metadata=database_connection_request.metadata,
185-
)
186-
sql_database = SQLDatabase.get_sql_engine(db_connection, True)
187-
188-
# Get tables and views and create table-descriptions as NOT_SCANNED
189-
db_connection_repository = DatabaseConnectionRepository(self.storage)
190-
191-
scanner_repository = TableDescriptionRepository(self.storage)
192179
scanner = self.system.instance(Scanner)
193-
194-
tables = sql_database.get_tables_and_views()
195-
db_connection = db_connection_repository.insert(db_connection)
196-
scanner.create_tables(tables, str(db_connection.id), scanner_repository)
180+
db_connection_service = DatabaseConnectionService(scanner, self.storage)
181+
db_connection = db_connection_service.create(database_connection_request)
197182
except Exception as e:
198183
# Encrypt sensible values
199184
fernet_encrypt = FernetEncrypt()
@@ -209,7 +194,6 @@ def create_database_connection(
209194
return error_response(
210195
e, database_connection_request.dict(), "invalid_database_connection"
211196
)
212-
213197
return DatabaseConnectionResponse(**db_connection.dict())
214198

215199
@override

Diff for: dataherald/db_scanner/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ def create_tables(
3434
self,
3535
tables: list[str],
3636
db_connection_id: str,
37+
schema: str,
3738
repository: TableDescriptionRepository,
3839
metadata: dict = None,
3940
) -> None:

Diff for: dataherald/db_scanner/models/types.py

+1
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ class TableDescriptionStatus(Enum):
3131
class TableDescription(BaseModel):
3232
id: str | None
3333
db_connection_id: str
34+
schema_name: str | None
3435
table_name: str
3536
description: str | None
3637
table_schema: str | None

Diff for: dataherald/db_scanner/sqlalchemy.py

+2
Original file line numberDiff line numberDiff line change
@@ -44,13 +44,15 @@ def create_tables(
4444
self,
4545
tables: list[str],
4646
db_connection_id: str,
47+
schema: str,
4748
repository: TableDescriptionRepository,
4849
metadata: dict = None,
4950
) -> None:
5051
for table in tables:
5152
repository.save_table_info(
5253
TableDescription(
5354
db_connection_id=db_connection_id,
55+
schema_name=schema,
5456
table_name=table,
5557
status=TableDescriptionStatus.NOT_SCANNED.value,
5658
metadata=metadata,

Diff for: dataherald/sql_database/models/types.py

+1
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ class DatabaseConnection(BaseModel):
9696
dialect: SupportedDialects | None
9797
use_ssh: bool = False
9898
connection_uri: str | None
99+
schemas: list[str] | None
99100
path_to_credentials_file: str | None
100101
llm_api_key: str | None = None
101102
ssh_settings: SSHSettings | None = None

Diff for: dataherald/sql_database/services/__init__.py

Whitespace-only changes.
+119
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
import re
2+
3+
from sqlalchemy import inspect
4+
5+
from dataherald.db import DB
6+
from dataherald.db_scanner import Scanner
7+
from dataherald.db_scanner.repository.base import TableDescriptionRepository
8+
from dataherald.repositories.database_connections import DatabaseConnectionRepository
9+
from dataherald.sql_database.base import SQLDatabase
10+
from dataherald.sql_database.models.types import DatabaseConnection
11+
from dataherald.types import DatabaseConnectionRequest
12+
from dataherald.utils.encrypt import FernetEncrypt
13+
14+
15+
class DatabaseConnectionService:
16+
def __init__(self, scanner: Scanner, storage: DB):
17+
self.scanner = scanner
18+
self.storage = storage
19+
20+
def get_current_schema(self, database_connection: DatabaseConnection) -> list[str]:
21+
sql_database = SQLDatabase.get_sql_engine(database_connection, True)
22+
inspector = inspect(sql_database.engine)
23+
if inspector.default_schema_name and database_connection.dialect not in [
24+
"mssql",
25+
"mysql",
26+
"clickhouse",
27+
"duckdb",
28+
]:
29+
return [inspector.default_schema_name]
30+
if database_connection.dialect == "bigquery":
31+
pattern = r"([^:/]+)://([^/]+)/([^/]+)(\?[^/]+)"
32+
match = re.match(pattern, str(sql_database.engine.url))
33+
if match:
34+
return [match.group(3)]
35+
elif database_connection.dialect == "databricks":
36+
pattern = r"&schema=([^&]*)"
37+
match = re.search(pattern, str(sql_database.engine.url))
38+
if match:
39+
return [match.group(1)]
40+
return ["default"]
41+
42+
def remove_schema_in_uri(self, connection_uri: str, dialect: str) -> str:
43+
if dialect in ["snowflake"]:
44+
pattern = r"([^:/]+)://([^:]+):([^@]+)@([^:/]+)(?::(\d+))?/([^/]+)"
45+
match = re.match(pattern, connection_uri)
46+
if match:
47+
return match.group(0)
48+
if dialect in ["bigquery"]:
49+
pattern = r"([^:/]+)://([^/]+)"
50+
match = re.match(pattern, connection_uri)
51+
if match:
52+
return match.group(0)
53+
elif dialect in ["databricks"]:
54+
pattern = r"&schema=[^&]*"
55+
return re.sub(pattern, "", connection_uri)
56+
elif dialect in ["postgresql"]:
57+
pattern = r"\?options=-csearch_path" r"=[^&]*"
58+
return re.sub(pattern, "", connection_uri)
59+
return connection_uri
60+
61+
def add_schema_in_uri(self, connection_uri: str, schema: str, dialect: str) -> str:
62+
connection_uri = self.remove_schema_in_uri(connection_uri, dialect)
63+
if dialect in ["snowflake", "bigquery"]:
64+
return f"{connection_uri}/{schema}"
65+
if dialect in ["databricks"]:
66+
return f"{connection_uri}&schema={schema}"
67+
if dialect in ["postgresql"]:
68+
return f"{connection_uri}?options=-csearch_path={schema}"
69+
return connection_uri
70+
71+
def create(
72+
self, database_connection_request: DatabaseConnectionRequest
73+
) -> DatabaseConnection:
74+
database_connection = DatabaseConnection(
75+
alias=database_connection_request.alias,
76+
connection_uri=database_connection_request.connection_uri.strip(),
77+
schemas=database_connection_request.schemas,
78+
path_to_credentials_file=database_connection_request.path_to_credentials_file,
79+
llm_api_key=database_connection_request.llm_api_key,
80+
use_ssh=database_connection_request.use_ssh,
81+
ssh_settings=database_connection_request.ssh_settings,
82+
file_storage=database_connection_request.file_storage,
83+
metadata=database_connection_request.metadata,
84+
)
85+
if not database_connection.schemas:
86+
database_connection.schemas = self.get_current_schema(database_connection)
87+
88+
schemas_and_tables = {}
89+
fernet_encrypt = FernetEncrypt()
90+
91+
if database_connection.schemas:
92+
for schema in database_connection.schemas:
93+
database_connection.connection_uri = fernet_encrypt.encrypt(
94+
self.add_schema_in_uri(
95+
database_connection_request.connection_uri.strip(),
96+
schema,
97+
str(database_connection.dialect),
98+
)
99+
)
100+
sql_database = SQLDatabase.get_sql_engine(database_connection, True)
101+
schemas_and_tables[schema] = sql_database.get_tables_and_views()
102+
103+
# Connect db
104+
db_connection_repository = DatabaseConnectionRepository(self.storage)
105+
database_connection.connection_uri = fernet_encrypt.encrypt(
106+
self.remove_schema_in_uri(
107+
database_connection_request.connection_uri.strip(),
108+
str(database_connection.dialect),
109+
)
110+
)
111+
db_connection = db_connection_repository.insert(database_connection)
112+
113+
scanner_repository = TableDescriptionRepository(self.storage)
114+
# Add created tables
115+
for schema, tables in schemas_and_tables.items():
116+
self.scanner.create_tables(
117+
tables, str(db_connection.id), schema, scanner_repository
118+
)
119+
return DatabaseConnection(**db_connection.dict())

Diff for: dataherald/types.py

+1
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ class DatabaseConnectionRequest(BaseModel):
8787
alias: str
8888
use_ssh: bool = False
8989
connection_uri: str
90+
schemas: list[str] | None
9091
path_to_credentials_file: str | None
9192
llm_api_key: str | None
9293
ssh_settings: SSHSettings | None

0 commit comments

Comments
 (0)