From 85a8427656895a2d80dc28ca64432289af6802ac Mon Sep 17 00:00:00 2001 From: Riccardo Magliocchetti Date: Thu, 22 Feb 2024 12:15:54 +0100 Subject: [PATCH] instrumentation: make query scanning for dollar quotes a bit more correct Dollar quotes follow the same rules as SQL identifiers so: SQL identifiers and key words must begin with a letter (a-z, but also letters with diacritical marks and non-Latin letters) or an underscore (_). Subsequent characters in an identifier or key word can be letters, underscores, digits (0-9), or dollar signs ($). Given we are not going to write a compliant SQL parser at least handle query parameters that are simple to catch since they have a digit right after the opening $. Refs #1851. --- elasticapm/instrumentation/packages/dbapi2.py | 9 +++++++++ tests/instrumentation/dbapi2_tests.py | 14 ++++++++++++++ 2 files changed, 23 insertions(+) diff --git a/elasticapm/instrumentation/packages/dbapi2.py b/elasticapm/instrumentation/packages/dbapi2.py index fb49723c28..4f3f031799 100644 --- a/elasticapm/instrumentation/packages/dbapi2.py +++ b/elasticapm/instrumentation/packages/dbapi2.py @@ -34,6 +34,7 @@ """ import re +import string import wrapt @@ -65,6 +66,7 @@ def look_for_table(sql, keyword): def _scan_for_table_with_tokens(tokens, keyword): seen_keyword = False for idx, lexeme in scan(tokens): + print(idx, lexeme) if seen_keyword: if lexeme == "(": return _scan_for_table_with_tokens(tokens[idx:], keyword) @@ -85,6 +87,7 @@ def scan(tokens): literal_started = None prev_was_escape = False lexeme = [] + digits = set(string.digits) i = 0 while i < len(tokens): @@ -114,6 +117,11 @@ def scan(tokens): literal_start_idx = i literal_started = token elif token == "$": + # exclude query parameters that have a digit following the dollar + if True and len(tokens) > i + 1 and tokens[i + 1] in digits: + yield i, token + i += 1 + continue # Postgres can use arbitrary characters between two $'s as a # literal separation token, e.g.: $fish$ literal $fish$ # This part will detect that and skip over the literal. @@ -125,6 +133,7 @@ def scan(tokens): pass else: quote = tokens[i : closing_dollar_idx + 1] + length = len(quote) # Opening dollar of the closing quote, # i.e. the first $ in the second $fish$ diff --git a/tests/instrumentation/dbapi2_tests.py b/tests/instrumentation/dbapi2_tests.py index 3d72b66327..2faaf462d6 100644 --- a/tests/instrumentation/dbapi2_tests.py +++ b/tests/instrumentation/dbapi2_tests.py @@ -122,6 +122,20 @@ def test_extract_signature_bytes(): assert actual == expected +def test_extract_signature_pathological(): + # tune for performance testing + multiplier = 10 + values = [] + for chunk in range(multiplier): + i = chunk * 3 + values.append(f" (${1+i}::varchar, ${2+i}::varchar, ${3+i}::varchar), ") + + sql = f"SELECT * FROM (VALUES {''.join(values)})\n" + actual = extract_signature(sql) + expected = "SELECT FROM" + assert actual == expected + + @pytest.mark.parametrize( ["sql", "expected"], [