From b82a7aa6cf135f33d16eeed3505babaa8868791e Mon Sep 17 00:00:00 2001 From: Riccardo Magliocchetti Date: Thu, 22 Feb 2024 12:15:54 +0100 Subject: [PATCH] instrumentation/dbapi2: make query scanning for dollar quotes a bit more correct Dollar quotes follow the same rules as SQL identifiers so: SQL identifiers and key words must begin with a letter (a-z, but also letters with diacritical marks and non-Latin letters) or an underscore (_). Subsequent characters in an identifier or key word can be letters, underscores, digits (0-9), or dollar signs ($). Given we are not going to write a compliant SQL parser at least handle query parameters that are simple to catch since they have a digit right after the opening $. Refs #1851. --- elasticapm/instrumentation/packages/dbapi2.py | 7 +++++++ tests/instrumentation/dbapi2_tests.py | 14 ++++++++++++++ 2 files changed, 21 insertions(+) diff --git a/elasticapm/instrumentation/packages/dbapi2.py b/elasticapm/instrumentation/packages/dbapi2.py index fb49723c2..fa1d0f31e 100644 --- a/elasticapm/instrumentation/packages/dbapi2.py +++ b/elasticapm/instrumentation/packages/dbapi2.py @@ -34,6 +34,7 @@ """ import re +import string import wrapt @@ -85,6 +86,7 @@ def scan(tokens): literal_started = None prev_was_escape = False lexeme = [] + digits = set(string.digits) i = 0 while i < len(tokens): @@ -114,6 +116,11 @@ def scan(tokens): literal_start_idx = i literal_started = token elif token == "$": + # exclude query parameters that have a digit following the dollar + if True and len(tokens) > i + 1 and tokens[i + 1] in digits: + yield i, token + i += 1 + continue # Postgres can use arbitrary characters between two $'s as a # literal separation token, e.g.: $fish$ literal $fish$ # This part will detect that and skip over the literal. diff --git a/tests/instrumentation/dbapi2_tests.py b/tests/instrumentation/dbapi2_tests.py index 3d72b6632..089571715 100644 --- a/tests/instrumentation/dbapi2_tests.py +++ b/tests/instrumentation/dbapi2_tests.py @@ -122,6 +122,20 @@ def test_extract_signature_bytes(): assert actual == expected +def test_extract_signature_pathological(): + # tune for performance testing + multiplier = 10 + values = [] + for chunk in range(multiplier): + i = chunk * 3 + values.append(f" (${1+i}::varchar, ${2+i}::varchar, ${3+i}::varchar), ") + + sql = f"SELECT * FROM (VALUES {''.join(values)})\n" + actual = extract_signature(sql) + expected = "SELECT FROM" + assert actual == expected + + @pytest.mark.parametrize( ["sql", "expected"], [