Skip to content

Commit 93db8f5

Browse files
authoredOct 26, 2024
Merge branch 'master' into socket
2 parents 25a9cd7 + 58b5f91 commit 93db8f5

File tree

4 files changed

+128
-34
lines changed

4 files changed

+128
-34
lines changed
 

‎src/sqlite3_to_mysql/mysql_utils.py

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from mysql.connector import CharacterSet
77
from mysql.connector.charsets import MYSQL_CHARACTER_SETS
88
from packaging import version
9-
from packaging.version import Version
109

1110

1211
# Shamelessly copied from SQLAlchemy's dialects/mysql/__init__.py
@@ -112,30 +111,32 @@ def get_mysql_version(version_string: str) -> version.Version:
112111

113112
def check_mysql_json_support(version_string: str) -> bool:
114113
"""Check for MySQL JSON support."""
115-
mysql_version: Version = get_mysql_version(version_string)
116-
if version_string.lower().endswith("-mariadb"):
117-
if mysql_version.major >= 10 and mysql_version.minor >= 2 and mysql_version.micro >= 7:
118-
return True
119-
else:
120-
if mysql_version.major >= 8:
121-
return True
122-
if mysql_version.minor >= 7 and mysql_version.micro >= 8:
123-
return True
124-
return False
114+
mysql_version: version.Version = get_mysql_version(version_string)
115+
if "-mariadb" in version_string.lower():
116+
return mysql_version >= version.parse("10.2.7")
117+
return mysql_version >= version.parse("5.7.8")
118+
119+
120+
def check_mysql_values_alias_support(version_string: str) -> bool:
121+
"""Check for VALUES alias support.
122+
123+
Returns:
124+
bool: True if VALUES alias is supported (MySQL 8.0.19+), False for MariaDB
125+
or older MySQL versions.
126+
"""
127+
mysql_version: version.Version = get_mysql_version(version_string)
128+
if "-mariadb" in version_string.lower():
129+
return False
130+
# Only MySQL 8.0.19 and later support VALUES alias
131+
return mysql_version >= version.parse("8.0.19")
125132

126133

127134
def check_mysql_fulltext_support(version_string: str) -> bool:
128135
"""Check for FULLTEXT indexing support."""
129-
mysql_version: Version = get_mysql_version(version_string)
130-
if version_string.lower().endswith("-mariadb"):
131-
if mysql_version.major >= 10 and mysql_version.minor >= 0 and mysql_version.micro >= 5:
132-
return True
133-
else:
134-
if mysql_version.major >= 8:
135-
return True
136-
if mysql_version.minor >= 6:
137-
return True
138-
return False
136+
mysql_version: version.Version = get_mysql_version(version_string)
137+
if "-mariadb" in version_string.lower():
138+
return mysql_version >= version.parse("10.0.5")
139+
return mysql_version >= version.parse("5.6.0")
139140

140141

141142
def safe_identifier_length(identifier_name: str, max_length: int = 64) -> str:

‎src/sqlite3_to_mysql/transporter.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
MYSQL_TEXT_COLUMN_TYPES_WITH_JSON,
4040
check_mysql_fulltext_support,
4141
check_mysql_json_support,
42+
check_mysql_values_alias_support,
4243
safe_identifier_length,
4344
)
4445
from .types import SQLite3toMySQLAttributes, SQLite3toMySQLParams
@@ -68,7 +69,7 @@ def __init__(self, **kwargs: tx.Unpack[SQLite3toMySQLParams]):
6869
else:
6970
raise ValueError("Please provide a MySQL user")
7071

71-
self._mysql_password = str(kwargs.get("mysql_password")) or None
72+
self._mysql_password = str(kwargs.get("mysql_password")) if kwargs.get("mysql_password") else None
7273

7374
self._mysql_host = str(kwargs.get("mysql_host", "localhost"))
7475

@@ -90,7 +91,7 @@ def __init__(self, **kwargs: tx.Unpack[SQLite3toMySQLParams]):
9091

9192
self._mysql_database = kwargs.get("mysql_database", "transfer") or "transfer"
9293

93-
self._mysql_insert_method = str(kwargs.get("mysql_integer_type", "IGNORE")).upper()
94+
self._mysql_insert_method = str(kwargs.get("mysql_insert_method", "IGNORE")).upper()
9495
if self._mysql_insert_method not in MYSQL_INSERT_METHOD:
9596
self._mysql_insert_method = "IGNORE"
9697

@@ -725,21 +726,26 @@ def transfer(self) -> None:
725726
columns: t.List[str] = [
726727
safe_identifier_length(column[0]) for column in self._sqlite_cur.description
727728
]
729+
sql: str
728730
if self._mysql_insert_method.upper() == "UPDATE":
729-
sql: str = (
730-
"""
731+
sql = """
731732
INSERT
732733
INTO `{table}` ({fields})
733-
VALUES ({placeholders}) AS `__new__`
734+
{values_clause}
734735
ON DUPLICATE KEY UPDATE {field_updates}
735736
""".format(
736-
table=safe_identifier_length(table["name"]),
737-
fields=("`{}`, " * len(columns)).rstrip(" ,").format(*columns),
738-
placeholders=("%s, " * len(columns)).rstrip(" ,"),
739-
field_updates=("`{}`=`__new__`.`{}`, " * len(columns))
740-
.rstrip(" ,")
741-
.format(*list(chain.from_iterable((column, column) for column in columns))),
742-
)
737+
table=safe_identifier_length(table["name"]),
738+
fields=("`{}`, " * len(columns)).rstrip(" ,").format(*columns),
739+
values_clause=(
740+
"VALUES ({placeholders}) AS `__new__`"
741+
if check_mysql_values_alias_support(self._mysql_version)
742+
else "VALUES ({placeholders})"
743+
).format(placeholders=("%s, " * len(columns)).rstrip(" ,")),
744+
field_updates=(
745+
("`{}`=`__new__`.`{}`, " * len(columns)).rstrip(" ,")
746+
if check_mysql_values_alias_support(self._mysql_version)
747+
else ("`{}`=`{}`, " * len(columns)).rstrip(" ,")
748+
).format(*list(chain.from_iterable((column, column) for column in columns))),
743749
)
744750
else:
745751
sql = """

‎tests/func/sqlite3_to_mysql_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def test_valid_sqlite_file_and_valid_mysql_credentials(
5959
mysql_credentials: MySQLCredentials,
6060
helpers: Helpers,
6161
quiet: bool,
62-
):
62+
) -> None:
6363
with helpers.not_raises(FileNotFoundError):
6464
SQLite3toMySQL( # type: ignore
6565
sqlite_file=sqlite_database,

‎tests/unit/mysql_utils_test.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
import pytest
2+
from packaging.version import Version
3+
4+
from sqlite3_to_mysql.mysql_utils import (
5+
check_mysql_fulltext_support,
6+
check_mysql_json_support,
7+
check_mysql_values_alias_support,
8+
get_mysql_version,
9+
safe_identifier_length,
10+
)
11+
12+
13+
class TestMySQLUtils:
14+
@pytest.mark.parametrize(
15+
"version_string,expected",
16+
[
17+
("5.7.7", Version("5.7.7")),
18+
("5.7.8", Version("5.7.8")),
19+
("8.0.0", Version("8.0.0")),
20+
("9.0.0", Version("9.0.0")),
21+
("10.2.6-mariadb", Version("10.2.6")),
22+
("10.2.7-mariadb", Version("10.2.7")),
23+
("11.4.0-mariadb", Version("11.4.0")),
24+
],
25+
)
26+
def test_get_mysql_version(self, version_string: str, expected: Version) -> None:
27+
assert get_mysql_version(version_string) == expected
28+
29+
@pytest.mark.parametrize(
30+
"version_string,expected",
31+
[
32+
("5.7.7", False),
33+
("5.7.8", True),
34+
("8.0.0", True),
35+
("9.0.0", True),
36+
("10.2.6-mariadb", False),
37+
("10.2.7-mariadb", True),
38+
("11.4.0-mariadb", True),
39+
],
40+
)
41+
def test_check_mysql_json_support(self, version_string: str, expected: bool) -> None:
42+
assert check_mysql_json_support(version_string) == expected
43+
44+
@pytest.mark.parametrize(
45+
"version_string,expected",
46+
[
47+
("5.7.8", False),
48+
("8.0.0", False),
49+
("8.0.18", False),
50+
("8.0.19", True),
51+
("9.0.0", True),
52+
("10.2.6-mariadb", False),
53+
("10.2.7-mariadb", False),
54+
("11.4.0-mariadb", False),
55+
],
56+
)
57+
def test_check_mysql_values_alias_support(self, version_string: str, expected: bool) -> None:
58+
assert check_mysql_values_alias_support(version_string) == expected
59+
60+
@pytest.mark.parametrize(
61+
"version_string,expected",
62+
[
63+
("5.0.0", False),
64+
("5.5.0", False),
65+
("5.6.0", True),
66+
("8.0.0", True),
67+
("10.0.4-mariadb", False),
68+
("10.0.5-mariadb", True),
69+
("10.2.6-mariadb", True),
70+
("11.4.0-mariadb", True),
71+
],
72+
)
73+
def test_check_mysql_fulltext_support(self, version_string: str, expected: bool) -> None:
74+
assert check_mysql_fulltext_support(version_string) == expected
75+
76+
@pytest.mark.parametrize(
77+
"identifier,expected",
78+
[
79+
("a" * 67, "a" * 64),
80+
("a" * 66, "a" * 64),
81+
("a" * 65, "a" * 64),
82+
("a" * 64, "a" * 64),
83+
("a" * 63, "a" * 63),
84+
],
85+
)
86+
def test_safe_identifier_length(self, identifier: str, expected: str) -> None:
87+
assert safe_identifier_length(identifier) == expected

0 commit comments

Comments
 (0)