Skip to content

Commit 3fc65aa

Browse files
techousemackuba
andauthoredOct 26, 2024
🐛 fix typo in --mysql-insert-method parameter (#130)
* fixed mysql insert method parameter * 🐛 fix MariaDB INSERT ON DUPLICATE KEY UPDATE --------- Co-authored-by: Kuba Suder <jakub.suder@gmail.com>
1 parent d5e5626 commit 3fc65aa

File tree

4 files changed

+127
-33
lines changed

4 files changed

+127
-33
lines changed
 

‎src/sqlite3_to_mysql/mysql_utils.py

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from mysql.connector import CharacterSet
77
from mysql.connector.charsets import MYSQL_CHARACTER_SETS
88
from packaging import version
9-
from packaging.version import Version
109

1110

1211
# Shamelessly copied from SQLAlchemy's dialects/mysql/__init__.py
@@ -112,30 +111,32 @@ def get_mysql_version(version_string: str) -> version.Version:
112111

113112
def check_mysql_json_support(version_string: str) -> bool:
114113
"""Check for MySQL JSON support."""
115-
mysql_version: Version = get_mysql_version(version_string)
116-
if version_string.lower().endswith("-mariadb"):
117-
if mysql_version.major >= 10 and mysql_version.minor >= 2 and mysql_version.micro >= 7:
118-
return True
119-
else:
120-
if mysql_version.major >= 8:
121-
return True
122-
if mysql_version.minor >= 7 and mysql_version.micro >= 8:
123-
return True
124-
return False
114+
mysql_version: version.Version = get_mysql_version(version_string)
115+
if "-mariadb" in version_string.lower():
116+
return mysql_version >= version.parse("10.2.7")
117+
return mysql_version >= version.parse("5.7.8")
118+
119+
120+
def check_mysql_values_alias_support(version_string: str) -> bool:
121+
"""Check for VALUES alias support.
122+
123+
Returns:
124+
bool: True if VALUES alias is supported (MySQL 8.0.19+), False for MariaDB
125+
or older MySQL versions.
126+
"""
127+
mysql_version: version.Version = get_mysql_version(version_string)
128+
if "-mariadb" in version_string.lower():
129+
return False
130+
# Only MySQL 8.0.19 and later support VALUES alias
131+
return mysql_version >= version.parse("8.0.19")
125132

126133

127134
def check_mysql_fulltext_support(version_string: str) -> bool:
128135
"""Check for FULLTEXT indexing support."""
129-
mysql_version: Version = get_mysql_version(version_string)
130-
if version_string.lower().endswith("-mariadb"):
131-
if mysql_version.major >= 10 and mysql_version.minor >= 0 and mysql_version.micro >= 5:
132-
return True
133-
else:
134-
if mysql_version.major >= 8:
135-
return True
136-
if mysql_version.minor >= 6:
137-
return True
138-
return False
136+
mysql_version: version.Version = get_mysql_version(version_string)
137+
if "-mariadb" in version_string.lower():
138+
return mysql_version >= version.parse("10.0.5")
139+
return mysql_version >= version.parse("5.6.0")
139140

140141

141142
def safe_identifier_length(identifier_name: str, max_length: int = 64) -> str:

‎src/sqlite3_to_mysql/transporter.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
MYSQL_TEXT_COLUMN_TYPES_WITH_JSON,
4040
check_mysql_fulltext_support,
4141
check_mysql_json_support,
42+
check_mysql_values_alias_support,
4243
safe_identifier_length,
4344
)
4445
from .types import SQLite3toMySQLAttributes, SQLite3toMySQLParams
@@ -88,7 +89,7 @@ def __init__(self, **kwargs: tx.Unpack[SQLite3toMySQLParams]):
8889

8990
self._mysql_database = kwargs.get("mysql_database", "transfer") or "transfer"
9091

91-
self._mysql_insert_method = str(kwargs.get("mysql_integer_type", "IGNORE")).upper()
92+
self._mysql_insert_method = str(kwargs.get("mysql_insert_method", "IGNORE")).upper()
9293
if self._mysql_insert_method not in MYSQL_INSERT_METHOD:
9394
self._mysql_insert_method = "IGNORE"
9495

@@ -722,21 +723,26 @@ def transfer(self) -> None:
722723
columns: t.List[str] = [
723724
safe_identifier_length(column[0]) for column in self._sqlite_cur.description
724725
]
726+
sql: str
725727
if self._mysql_insert_method.upper() == "UPDATE":
726-
sql: str = (
727-
"""
728+
sql = """
728729
INSERT
729730
INTO `{table}` ({fields})
730-
VALUES ({placeholders}) AS `__new__`
731+
{values_clause}
731732
ON DUPLICATE KEY UPDATE {field_updates}
732733
""".format(
733-
table=safe_identifier_length(table["name"]),
734-
fields=("`{}`, " * len(columns)).rstrip(" ,").format(*columns),
735-
placeholders=("%s, " * len(columns)).rstrip(" ,"),
736-
field_updates=("`{}`=`__new__`.`{}`, " * len(columns))
737-
.rstrip(" ,")
738-
.format(*list(chain.from_iterable((column, column) for column in columns))),
739-
)
734+
table=safe_identifier_length(table["name"]),
735+
fields=("`{}`, " * len(columns)).rstrip(" ,").format(*columns),
736+
values_clause=(
737+
"VALUES ({placeholders}) AS `__new__`"
738+
if check_mysql_values_alias_support(self._mysql_version)
739+
else "VALUES ({placeholders})"
740+
).format(placeholders=("%s, " * len(columns)).rstrip(" ,")),
741+
field_updates=(
742+
("`{}`=`__new__`.`{}`, " * len(columns)).rstrip(" ,")
743+
if check_mysql_values_alias_support(self._mysql_version)
744+
else ("`{}`=`{}`, " * len(columns)).rstrip(" ,")
745+
).format(*list(chain.from_iterable((column, column) for column in columns))),
740746
)
741747
else:
742748
sql = """

‎tests/func/sqlite3_to_mysql_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def test_valid_sqlite_file_and_valid_mysql_credentials(
5959
mysql_credentials: MySQLCredentials,
6060
helpers: Helpers,
6161
quiet: bool,
62-
):
62+
) -> None:
6363
with helpers.not_raises(FileNotFoundError):
6464
SQLite3toMySQL( # type: ignore
6565
sqlite_file=sqlite_database,

‎tests/unit/mysql_utils_test.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
import pytest
2+
from packaging.version import Version
3+
4+
from sqlite3_to_mysql.mysql_utils import (
5+
check_mysql_fulltext_support,
6+
check_mysql_json_support,
7+
check_mysql_values_alias_support,
8+
get_mysql_version,
9+
safe_identifier_length,
10+
)
11+
12+
13+
class TestMySQLUtils:
14+
@pytest.mark.parametrize(
15+
"version_string,expected",
16+
[
17+
("5.7.7", Version("5.7.7")),
18+
("5.7.8", Version("5.7.8")),
19+
("8.0.0", Version("8.0.0")),
20+
("9.0.0", Version("9.0.0")),
21+
("10.2.6-mariadb", Version("10.2.6")),
22+
("10.2.7-mariadb", Version("10.2.7")),
23+
("11.4.0-mariadb", Version("11.4.0")),
24+
],
25+
)
26+
def test_get_mysql_version(self, version_string: str, expected: Version) -> None:
27+
assert get_mysql_version(version_string) == expected
28+
29+
@pytest.mark.parametrize(
30+
"version_string,expected",
31+
[
32+
("5.7.7", False),
33+
("5.7.8", True),
34+
("8.0.0", True),
35+
("9.0.0", True),
36+
("10.2.6-mariadb", False),
37+
("10.2.7-mariadb", True),
38+
("11.4.0-mariadb", True),
39+
],
40+
)
41+
def test_check_mysql_json_support(self, version_string: str, expected: bool) -> None:
42+
assert check_mysql_json_support(version_string) == expected
43+
44+
@pytest.mark.parametrize(
45+
"version_string,expected",
46+
[
47+
("5.7.8", False),
48+
("8.0.0", False),
49+
("8.0.18", False),
50+
("8.0.19", True),
51+
("9.0.0", True),
52+
("10.2.6-mariadb", False),
53+
("10.2.7-mariadb", False),
54+
("11.4.0-mariadb", False),
55+
],
56+
)
57+
def test_check_mysql_values_alias_support(self, version_string: str, expected: bool) -> None:
58+
assert check_mysql_values_alias_support(version_string) == expected
59+
60+
@pytest.mark.parametrize(
61+
"version_string,expected",
62+
[
63+
("5.0.0", False),
64+
("5.5.0", False),
65+
("5.6.0", True),
66+
("8.0.0", True),
67+
("10.0.4-mariadb", False),
68+
("10.0.5-mariadb", True),
69+
("10.2.6-mariadb", True),
70+
("11.4.0-mariadb", True),
71+
],
72+
)
73+
def test_check_mysql_fulltext_support(self, version_string: str, expected: bool) -> None:
74+
assert check_mysql_fulltext_support(version_string) == expected
75+
76+
@pytest.mark.parametrize(
77+
"identifier,expected",
78+
[
79+
("a" * 67, "a" * 64),
80+
("a" * 66, "a" * 64),
81+
("a" * 65, "a" * 64),
82+
("a" * 64, "a" * 64),
83+
("a" * 63, "a" * 63),
84+
],
85+
)
86+
def test_safe_identifier_length(self, identifier: str, expected: str) -> None:
87+
assert safe_identifier_length(identifier) == expected

0 commit comments

Comments
 (0)