Skip to content

Commit

Permalink
Merge branch 'main' into feat/acolyt-integration
Browse files Browse the repository at this point in the history
  • Loading branch information
TxCorpi0x authored Feb 5, 2025
2 parents 6f34c5f + 66a0402 commit f6e8fef
Showing 1 changed file with 45 additions and 47 deletions.
92 changes: 45 additions & 47 deletions models/db_mig.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,52 +7,51 @@
from sqlmodel import SQLModel


async def add_column_if_not_exists(engine, table_name: str, column: Column) -> None:
async def add_column_if_not_exists(
conn, dialect, table_name: str, column: Column
) -> None:
"""Add a column to a table if it doesn't exist.
Args:
engine: SQLAlchemy engine
conn: SQLAlchemy conn
table_name: Name of the table
column: Column to add
"""
async with engine.connect() as conn:
# Use run_sync to perform inspection on the connection
def _get_columns(connection):
inspector = inspect(connection)
return [c["name"] for c in inspector.get_columns(table_name)]

columns = await conn.run_sync(_get_columns)

if column.name not in columns:
async with conn.begin():
# Build column definition
column_def = f"{column.name} {column.type.compile(engine.dialect)}"

# Add DEFAULT if specified
if column.default is not None:
if hasattr(column.default, "arg"):
default_value = column.default.arg
if not isinstance(default_value, Callable):
if isinstance(default_value, bool):
default_value = str(default_value).lower()
elif isinstance(default_value, str):
default_value = f"'{default_value}'"
elif isinstance(default_value, (list, dict)):
default_value = "'{}'"
column_def += f" DEFAULT {default_value}"

# Execute ALTER TABLE
await conn.execute(
text(f"ALTER TABLE {table_name} ADD COLUMN {column_def}")
)
logging.info(f"Added column {column.name} to table {table_name}")


async def update_table_schema(engine, model: Type[SQLModel]) -> None:

# Use run_sync to perform inspection on the connection
def _get_columns(connection):
inspector = inspect(connection)
return [c["name"] for c in inspector.get_columns(table_name)]

columns = await conn.run_sync(_get_columns)

if column.name not in columns:
# Build column definition
column_def = f"{column.name} {column.type.compile(dialect)}"

# Add DEFAULT if specified
if column.default is not None:
if hasattr(column.default, "arg"):
default_value = column.default.arg
if not isinstance(default_value, Callable):
if isinstance(default_value, bool):
default_value = str(default_value).lower()
elif isinstance(default_value, str):
default_value = f"'{default_value}'"
elif isinstance(default_value, (list, dict)):
default_value = "'{}'"
column_def += f" DEFAULT {default_value}"

# Execute ALTER TABLE
await conn.execute(text(f"ALTER TABLE {table_name} ADD COLUMN {column_def}"))
logging.info(f"Added column {column.name} to table {table_name}")


async def update_table_schema(conn, dialect, model: Type[SQLModel]) -> None:
"""Update table schema by adding missing columns from the model.
Args:
engine: SQLAlchemy engine
conn: SQLAlchemy conn
model: SQLModel class to check for new columns
"""
if not hasattr(model, "__table__"):
Expand All @@ -61,7 +60,7 @@ async def update_table_schema(engine, model: Type[SQLModel]) -> None:
table_name = model.__tablename__
for name, column in model.__table__.columns.items():
if name != "id": # Skip primary key
await add_column_if_not_exists(engine, table_name, column)
await add_column_if_not_exists(conn, dialect, table_name, column)


async def safe_migrate(engine) -> None:
Expand All @@ -75,17 +74,16 @@ async def safe_migrate(engine) -> None:
async with engine.begin() as conn:
await conn.run_sync(SQLModel.metadata.create_all)

# Get existing table metadata
metadata = MetaData()
async with engine.begin() as conn:
# Get existing table metadata
metadata = MetaData()
await conn.run_sync(metadata.reflect)

# Update schema for all SQLModel classes
for model in SQLModel.__subclasses__():
if hasattr(model, "__tablename__"):
table_name = model.__tablename__
if table_name in metadata.tables:
await update_table_schema(engine, model)
# Update schema for all SQLModel classes
for model in SQLModel.__subclasses__():
if hasattr(model, "__tablename__"):
table_name = model.__tablename__
if table_name in metadata.tables:
await update_table_schema(conn, engine.dialect, model)

logging.info("Database schema updated successfully")
except Exception as e:
Expand Down

0 comments on commit f6e8fef

Please sign in to comment.