Skip to content

Commit f6d617e

Browse files
committed
Add ids in sync-schemas endpoint
1 parent 846010d commit f6d617e

File tree

5 files changed

+59
-55
lines changed

5 files changed

+59
-55
lines changed

Diff for: dataherald/api/fastapi.py

+26-35
Original file line numberDiff line numberDiff line change
@@ -97,11 +97,10 @@
9797
MAX_ROWS_TO_CREATE_CSV_FILE = 50
9898

9999

100-
def async_scanning(scanner, database, scanner_request, storage):
100+
def async_scanning(scanner, database, table_descriptions, storage):
101101
scanner.scan(
102102
database,
103-
scanner_request.db_connection_id,
104-
scanner_request.table_names,
103+
table_descriptions,
105104
TableDescriptionRepository(storage),
106105
QueryHistoryRepository(storage),
107106
)
@@ -133,43 +132,35 @@ def scan_db(
133132
self, scanner_request: ScannerRequest, background_tasks: BackgroundTasks
134133
) -> list[TableDescriptionResponse]:
135134
"""Takes a db_connection_id and scan all the tables columns"""
136-
try:
137-
db_connection_repository = DatabaseConnectionRepository(self.storage)
135+
scanner_repository = TableDescriptionRepository(self.storage)
136+
data = {}
137+
for id in scanner_request.ids:
138+
table_description = scanner_repository.find_by_id(id)
139+
if not table_description:
140+
raise Exception("Table description not found")
141+
if table_description.schema_name not in data.keys():
142+
data[table_description.schema_name] = []
143+
data[table_description.schema_name].append(table_description)
138144

145+
db_connection_repository = DatabaseConnectionRepository(self.storage)
146+
scanner = self.system.instance(Scanner)
147+
database_connection_service = DatabaseConnectionService(scanner, self.storage)
148+
for schema, table_descriptions in data.items():
139149
db_connection = db_connection_repository.find_by_id(
140-
scanner_request.db_connection_id
150+
table_descriptions[0].db_connection_id
141151
)
142-
143-
if not db_connection:
144-
raise DatabaseConnectionNotFoundError(
145-
f"Database connection {scanner_request.db_connection_id} not found"
146-
)
147-
148-
database = SQLDatabase.get_sql_engine(db_connection, True)
149-
all_tables = database.get_tables_and_views()
150-
151-
if scanner_request.table_names:
152-
for table in scanner_request.table_names:
153-
if table not in all_tables:
154-
raise HTTPException(
155-
status_code=404,
156-
detail=f"Table named: {table} doesn't exist",
157-
) # noqa: B904
158-
else:
159-
scanner_request.table_names = all_tables
160-
161-
scanner = self.system.instance(Scanner)
162-
rows = scanner.synchronizing(
163-
scanner_request,
164-
TableDescriptionRepository(self.storage),
152+
database = database_connection_service.get_sql_database(
153+
db_connection, schema
165154
)
166-
except Exception as e:
167-
return error_response(e, scanner_request.dict(), "invalid_database_sync")
168155

169-
background_tasks.add_task(
170-
async_scanning, scanner, database, scanner_request, self.storage
171-
)
172-
return [TableDescriptionResponse(**row.dict()) for row in rows]
156+
background_tasks.add_task(
157+
async_scanning, scanner, database, table_descriptions, self.storage
158+
)
159+
return [
160+
TableDescriptionResponse(**row.dict())
161+
for _, table_descriptions in data.items()
162+
for row in table_descriptions
163+
]
173164

174165
@override
175166
def create_database_connection(

Diff for: dataherald/db_scanner/__init__.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,7 @@ class Scanner(Component, ABC):
1414
def scan(
1515
self,
1616
db_engine: SQLDatabase,
17-
db_connection_id: str,
18-
table_names: list[str] | None,
17+
table_descriptions: list[TableDescription],
1918
repository: TableDescriptionRepository,
2019
query_history_repository: QueryHistoryRepository,
2120
) -> None:

Diff for: dataherald/db_scanner/sqlalchemy.py

+7-16
Original file line numberDiff line numberDiff line change
@@ -278,8 +278,7 @@ def scan_single_table(
278278
def scan(
279279
self,
280280
db_engine: SQLDatabase,
281-
db_connection_id: str,
282-
table_names: list[str] | None,
281+
table_descriptions: list[TableDescription],
283282
repository: TableDescriptionRepository,
284283
query_history_repository: QueryHistoryRepository,
285284
) -> None:
@@ -295,32 +294,24 @@ def scan(
295294
if db_engine.engine.dialect.name in services.keys():
296295
scanner_service = services[db_engine.engine.dialect.name]()
297296

298-
inspector = inspect(db_engine.engine)
297+
inspect(db_engine.engine)
299298
meta = MetaData(bind=db_engine.engine)
300299
MetaData.reflect(meta, views=True)
301-
tables = inspector.get_table_names() + inspector.get_view_names()
302-
if table_names:
303-
table_names = [table.lower() for table in table_names]
304-
tables = [
305-
table for table in tables if table and table.lower() in table_names
306-
]
307-
if len(tables) == 0:
308-
raise ValueError("No table found")
309300

310-
for table in tables:
301+
for table in table_descriptions:
311302
try:
312303
self.scan_single_table(
313304
meta=meta,
314-
table=table,
305+
table=table.table_name,
315306
db_engine=db_engine,
316-
db_connection_id=db_connection_id,
307+
db_connection_id=table.db_connection_id,
317308
repository=repository,
318309
scanner_service=scanner_service,
319310
)
320311
except Exception as e:
321312
repository.save_table_info(
322313
TableDescription(
323-
db_connection_id=db_connection_id,
314+
db_connection_id=table.db_connection_id,
324315
table_name=table,
325316
status=TableDescriptionStatus.FAILED.value,
326317
error_message=f"{e}",
@@ -329,7 +320,7 @@ def scan(
329320
try:
330321
logger.info(f"Get logs table: {table}")
331322
query_history = scanner_service.get_logs(
332-
table, db_engine, db_connection_id
323+
table.table_name, db_engine, table.db_connection_id
333324
)
334325
if len(query_history) > 0:
335326
for query in query_history:

Diff for: dataherald/sql_database/services/database_connection.py

+14
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,20 @@ def __init__(self, scanner: Scanner, storage: DB):
1717
self.scanner = scanner
1818
self.storage = storage
1919

20+
def get_sql_database(
21+
self, database_connection: DatabaseConnection, schema: str = None
22+
) -> SQLDatabase:
23+
fernet_encrypt = FernetEncrypt()
24+
if schema:
25+
database_connection.connection_uri = fernet_encrypt.encrypt(
26+
self.add_schema_in_uri(
27+
fernet_encrypt.decrypt(database_connection.connection_uri),
28+
schema,
29+
database_connection.dialect.value,
30+
)
31+
)
32+
return SQLDatabase.get_sql_engine(database_connection, True)
33+
2034
def get_current_schema(
2135
self, database_connection: DatabaseConnection
2236
) -> list[str] | None:

Diff for: dataherald/types.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -78,10 +78,19 @@ class SupportedDatabase(Enum):
7878
BIGQUERY = "BIGQUERY"
7979

8080

81-
class ScannerRequest(DBConnectionValidation):
82-
table_names: list[str] | None
81+
class ScannerRequest(BaseModel):
82+
ids: list[str] | None
8383
metadata: dict | None
8484

85+
@validator("ids")
86+
def ids_validation(cls, ids: list = None):
87+
try:
88+
for id in ids:
89+
ObjectId(id)
90+
except InvalidId:
91+
raise ValueError("Must be a valid ObjectId") # noqa: B904
92+
return ids
93+
8594

8695
class DatabaseConnectionRequest(BaseModel):
8796
alias: str

0 commit comments

Comments
 (0)