Skip to content

Commit 846010d

Browse files
committed
DBs without schema should store None
1 parent a381ce9 commit 846010d

File tree

2 files changed

+8
-3
lines changed

2 files changed

+8
-3
lines changed

dataherald/sql_database/models/types.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ class DatabaseConnection(BaseModel):
106106

107107
@classmethod
108108
def get_dialect(cls, input_string):
109-
pattern = r"([^:/]+):/+([^/]+)/?([^/]+)"
109+
pattern = r"([^:/]+)://"
110110
match = re.match(pattern, input_string)
111111
if not match:
112112
raise InvalidURIFormatError(f"Invalid URI format: {input_string}")

dataherald/sql_database/services/database_connection.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@ def __init__(self, scanner: Scanner, storage: DB):
1717
self.scanner = scanner
1818
self.storage = storage
1919

20-
def get_current_schema(self, database_connection: DatabaseConnection) -> list[str]:
20+
def get_current_schema(
21+
self, database_connection: DatabaseConnection
22+
) -> list[str] | None:
2123
sql_database = SQLDatabase.get_sql_engine(database_connection, True)
2224
inspector = inspect(sql_database.engine)
2325
if inspector.default_schema_name and database_connection.dialect not in [
@@ -37,7 +39,7 @@ def get_current_schema(self, database_connection: DatabaseConnection) -> list[st
3739
match = re.search(pattern, str(sql_database.engine.url))
3840
if match:
3941
return [match.group(1)]
40-
return ["default"]
42+
return None
4143

4244
def remove_schema_in_uri(self, connection_uri: str, dialect: str) -> str:
4345
if dialect in ["snowflake"]:
@@ -99,6 +101,9 @@ def create(
99101
)
100102
sql_database = SQLDatabase.get_sql_engine(database_connection, True)
101103
schemas_and_tables[schema] = sql_database.get_tables_and_views()
104+
else:
105+
sql_database = SQLDatabase.get_sql_engine(database_connection, True)
106+
schemas_and_tables[None] = sql_database.get_tables_and_views()
102107

103108
# Connect db
104109
db_connection_repository = DatabaseConnectionRepository(self.storage)

0 commit comments

Comments
 (0)