-
Notifications
You must be signed in to change notification settings - Fork 575
/
Copy pathdb_mig.py
91 lines (71 loc) · 3.08 KB
/
db_mig.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
"""Database migration utilities."""
import logging
from typing import Callable, Type
from sqlalchemy import Column, MetaData, inspect, text
from sqlmodel import SQLModel
async def add_column_if_not_exists(
    conn, dialect, table_name: str, column: Column
) -> None:
    """Add ``column`` to ``table_name`` if the table does not already have it.

    The existing column names are read through ``conn.run_sync`` with a
    SQLAlchemy inspector. If ``column.name`` is missing, an
    ``ALTER TABLE ... ADD COLUMN`` statement is executed.

    Only a non-callable Python-side default (``column.default.arg``) is
    rendered as a SQL ``DEFAULT`` clause:
    - booleans become ``true``/``false``;
    - strings become quoted literals with embedded quotes escaped;
    - lists and dicts become ``'{}'``;
    - anything else is rendered with ``str()``.

    Callable defaults produce no ``DEFAULT`` clause.

    Args:
        conn: SQLAlchemy async connection (``run_sync``/``execute`` are awaited).
        dialect: SQLAlchemy dialect used to compile the column type and to
            quote identifiers.
        table_name: Name of the table to alter.
        column: Column to add.
    """

    # Inspection is synchronous, so run it on the connection via run_sync.
    def _get_columns(connection):
        inspector = inspect(connection)
        return [c["name"] for c in inspector.get_columns(table_name)]

    columns = await conn.run_sync(_get_columns)
    if column.name in columns:
        return

    # Quote identifiers with the dialect's preparer (as SQLAlchemy's own DDL
    # does), so reserved words and case-sensitive names match the objects
    # created by metadata.create_all instead of breaking the statement.
    preparer = dialect.identifier_preparer
    column_def = f"{preparer.quote(column.name)} {column.type.compile(dialect)}"

    # Add DEFAULT if a static Python-side default is specified.
    default = column.default
    if default is not None and hasattr(default, "arg"):
        default_value = default.arg
        # Callable defaults have no SQL literal form; skip the DEFAULT clause.
        if not isinstance(default_value, Callable):
            if isinstance(default_value, bool):
                default_value = str(default_value).lower()
            elif isinstance(default_value, str):
                # Double embedded single quotes so the literal stays well-formed.
                escaped = default_value.replace("'", "''")
                default_value = f"'{escaped}'"
            elif isinstance(default_value, (list, dict)):
                # NOTE(review): every list/dict default (even non-empty) is
                # rendered as '{}' — an empty JSON object / Postgres empty
                # array. Confirm this matches the column types in use.
                default_value = "'{}'"
            column_def += f" DEFAULT {default_value}"

    quoted_table = preparer.quote(table_name)
    await conn.execute(text(f"ALTER TABLE {quoted_table} ADD COLUMN {column_def}"))
    logging.info("Added column %s to table %s", column.name, table_name)
async def update_table_schema(conn, dialect, model: Type[SQLModel]) -> None:
    """Bring the model's table up to date by adding any missing columns.

    Every column of ``model.__table__``, except the one keyed ``"id"``, is
    handed to ``add_column_if_not_exists``. That function only alters the
    table when the column is absent. Models that have no ``__table__``
    attribute are ignored.

    Args:
        conn: SQLAlchemy connection, forwarded to ``add_column_if_not_exists``.
        dialect: SQLAlchemy dialect, forwarded to ``add_column_if_not_exists``.
        model: SQLModel class whose declared columns should exist in the DB.
    """
    if not hasattr(model, "__table__"):
        return

    target_table = model.__tablename__
    # The column keyed "id" is skipped; it is treated as the primary key,
    # which already exists once the table has been created.
    pending = [
        col for key, col in model.__table__.columns.items() if key != "id"
    ]
    for col in pending:
        await add_column_if_not_exists(conn, dialect, target_table, col)
def _iter_sqlmodel_subclasses(base):
    """Yield every transitive subclass of ``base`` exactly once (depth-first).

    ``type.__subclasses__()`` lists only *direct* subclasses. Table models
    declared through an intermediate base class would be missed by a single
    call, so the hierarchy is walked fully here.
    """
    seen = set()
    pending = list(reversed(base.__subclasses__()))
    while pending:
        cls = pending.pop()
        if cls in seen:  # diamond inheritance can reach a class twice
            continue
        seen.add(cls)
        yield cls
        pending.extend(reversed(cls.__subclasses__()))


async def safe_migrate(engine) -> None:
    """Create missing tables, then add missing columns to existing ones.

    Inside a single ``engine.begin()`` transaction:
    1. Create any tables missing from ``SQLModel.metadata``.
    2. Reflect the live schema.
    3. For every SQLModel subclass, direct or indirect, whose
       ``__tablename__`` exists in the database, call
       ``update_table_schema``.

    Args:
        engine: SQLAlchemy async engine (``engine.begin()`` is used as an
            async context manager).

    Raises:
        Exception: Any error raised during migration is logged with its
            traceback and re-raised.
    """
    try:
        async with engine.begin() as conn:
            # Create tables if they don't exist
            await conn.run_sync(SQLModel.metadata.create_all)

            # Get existing table metadata
            metadata = MetaData()
            await conn.run_sync(metadata.reflect)

            # Snapshot the class hierarchy before awaiting inside the loop.
            models = list(_iter_sqlmodel_subclasses(SQLModel))
            for model in models:
                if not hasattr(model, "__tablename__"):
                    continue
                if model.__tablename__ in metadata.tables:
                    await update_table_schema(conn, engine.dialect, model)

        logging.info("Database schema updated successfully")
    except Exception as e:
        # logging.exception keeps the traceback, which logging.error dropped.
        logging.exception("Error updating database schema: %s", e)
        raise