Skip to content

Commit 41a4c5d

Browse files
committedOct 26, 2024
🐛 fix MariaDB INSERT ON DUPLICATE KEY UPDATE
1 parent 3894f4e commit 41a4c5d

File tree

2 files changed

+51
-17
lines changed

2 files changed

+51
-17
lines changed
 

‎src/sqlite3_to_mysql/mysql_utils.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from packaging import version
99
from packaging.version import Version
1010

11-
1211
# Shamelessly copied from SQLAlchemy's dialects/mysql/__init__.py
1312
MYSQL_COLUMN_TYPES: t.Tuple[str, ...] = (
1413
"BIGINT",
@@ -113,7 +112,7 @@ def get_mysql_version(version_string: str) -> version.Version:
113112
def check_mysql_json_support(version_string: str) -> bool:
114113
"""Check for MySQL JSON support."""
115114
mysql_version: Version = get_mysql_version(version_string)
116-
if version_string.lower().endswith("-mariadb"):
115+
if "mariadb" in version_string.lower():
117116
if mysql_version.major >= 10 and mysql_version.minor >= 2 and mysql_version.micro >= 7:
118117
return True
119118
else:
@@ -124,6 +123,23 @@ def check_mysql_json_support(version_string: str) -> bool:
124123
return False
125124

126125

126+
def check_mysql_values_alias_support(version_string: str) -> bool:
127+
"""Check for VALUES alias support."""
128+
mysql_version: Version = get_mysql_version(version_string)
129+
if "mariadb" in version_string.lower():
130+
return False
131+
## Only MySQL 8.0.19 and later support VALUES alias
132+
if mysql_version.major >= 8:
133+
if mysql_version.major > 8:
134+
return True
135+
if mysql_version.minor > 0:
136+
return True
137+
if mysql_version.micro >= 19:
138+
return True
139+
return False
140+
return False
141+
142+
127143
def check_mysql_fulltext_support(version_string: str) -> bool:
128144
"""Check for FULLTEXT indexing support."""
129145
mysql_version: Version = get_mysql_version(version_string)

‎src/sqlite3_to_mysql/transporter.py

Lines changed: 33 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
convert_timedelta,
3030
unicase_compare,
3131
)
32-
3332
from .mysql_utils import (
3433
MYSQL_BLOB_COLUMN_TYPES,
3534
MYSQL_COLUMN_TYPES,
@@ -40,6 +39,7 @@
4039
check_mysql_fulltext_support,
4140
check_mysql_json_support,
4241
safe_identifier_length,
42+
check_mysql_values_alias_support,
4343
)
4444
from .types import SQLite3toMySQLAttributes, SQLite3toMySQLParams
4545

@@ -723,21 +723,39 @@ def transfer(self) -> None:
723723
safe_identifier_length(column[0]) for column in self._sqlite_cur.description
724724
]
725725
if self._mysql_insert_method.upper() == "UPDATE":
726-
sql: str = (
727-
"""
728-
INSERT
729-
INTO `{table}` ({fields})
730-
VALUES ({placeholders}) AS `__new__`
731-
ON DUPLICATE KEY UPDATE {field_updates}
732-
""".format(
733-
table=safe_identifier_length(table["name"]),
734-
fields=("`{}`, " * len(columns)).rstrip(" ,").format(*columns),
735-
placeholders=("%s, " * len(columns)).rstrip(" ,"),
736-
field_updates=("`{}`=`__new__`.`{}`, " * len(columns))
737-
.rstrip(" ,")
738-
.format(*list(chain.from_iterable((column, column) for column in columns))),
726+
if check_mysql_values_alias_support(self._mysql_version):
727+
sql: str = (
728+
"""
729+
INSERT
730+
INTO `{table}` ({fields})
731+
VALUES ({placeholders}) AS `__new__`
732+
ON DUPLICATE KEY UPDATE {field_updates}
733+
""".format(
734+
table=safe_identifier_length(table["name"]),
735+
fields=("`{}`, " * len(columns)).rstrip(" ,").format(*columns),
736+
placeholders=("%s, " * len(columns)).rstrip(" ,"),
737+
field_updates=("`{}`=`__new__`.`{}`, " * len(columns))
738+
.rstrip(" ,")
739+
.format(*list(chain.from_iterable((column, column) for column in columns))),
740+
)
739741
)
740-
)
742+
else:
743+
sql: str = (
744+
"""
745+
INSERT
746+
INTO `{table}` ({fields})
747+
VALUES ({placeholders})
748+
ON DUPLICATE KEY UPDATE {field_updates}
749+
""".format(
750+
table=safe_identifier_length(table["name"]),
751+
fields=("`{}`, " * len(columns)).rstrip(" ,").format(*columns),
752+
placeholders=("%s, " * len(columns)).rstrip(" ,"),
753+
field_updates=("`{}`=`{}`, " * len(columns))
754+
.rstrip(" ,")
755+
.format(*list(chain.from_iterable((column, column) for column in columns))),
756+
)
757+
)
758+
print(sql)
741759
else:
742760
sql = """
743761
INSERT {ignore}

0 commit comments

Comments
 (0)