Skip to content

Commit 3264823

Browse files
dianakramerjkppr
andauthored
Filtered back-ticks and other trailing characters from the resulting query (#3304)
* Filtered back-ticks and other trailing characters from the resulting query. --------- Co-authored-by: Janosch <99879757+jkppr@users.noreply.github.com>
1 parent d866890 commit 3264823

File tree

2 files changed

+82
-1
lines changed

2 files changed

+82
-1
lines changed

timesketch/api/v1/resources/nl2q.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -222,5 +222,5 @@ def post(self, sketch_id):
222222
)
223223
return jsonify(result_schema)
224224

225-
result_schema["query_string"] = prediction.strip("```")
225+
result_schema["query_string"] = prediction.strip("`\n\r\t ")
226226
return jsonify(result_schema)

timesketch/api/v1/resources_test.py

+81
Original file line numberDiff line numberDiff line change
@@ -1408,6 +1408,87 @@ def test_nl2q_llm_error(self, mock_aggregator, mock_create_provider):
14081408
data = json.loads(response.get_data(as_text=True))
14091409
self.assertIsNotNone(data.get("error"))
14101410

1411+
@mock.patch("timesketch.lib.llms.providers.manager.LLMManager.create_provider")
1412+
@mock.patch("timesketch.api.v1.utils.run_aggregator")
1413+
@mock.patch("timesketch.api.v1.resources.OpenSearchDataStore", MockDataStore)
1414+
def test_nl2q_strip_back_ticks(self, mock_aggregator, mock_create_provider):
1415+
"""Test the result does not have any back tick."""
1416+
1417+
self.login()
1418+
data = dict(question="Question for LLM?")
1419+
mock_AggregationResult = mock.MagicMock()
1420+
mock_AggregationResult.values = [
1421+
{"data_type": "test:data_type:1"},
1422+
{"data_type": "test:data_type:2"},
1423+
]
1424+
mock_aggregator.return_value = (mock_AggregationResult, {})
1425+
expected_input = (
1426+
"Examples:\n"
1427+
"example 1\n"
1428+
"\n"
1429+
"example 2\n"
1430+
"Types:\n"
1431+
'* "test:data_type:1" -> "field_test_1", "field_test_2"\n'
1432+
'* "test:data_type:2" -> "field_test_3", "field_test_4"\n'
1433+
"Question:\n"
1434+
"Question for LLM?"
1435+
)
1436+
1437+
mock_llm_1 = mock.Mock()
1438+
mock_llm_1.generate.return_value = " \t`LLM generated query`\n "
1439+
mock_create_provider.return_value = mock_llm_1
1440+
response = self.client.post(
1441+
self.resource_url,
1442+
data=json.dumps(data),
1443+
content_type="application/json",
1444+
)
1445+
mock_llm_1.generate.assert_called_once_with(expected_input)
1446+
self.assertEqual(response.status_code, HTTP_STATUS_CODE_OK)
1447+
self.assertDictEqual(
1448+
response.json,
1449+
{
1450+
"name": "AI generated search query",
1451+
"query_string": "LLM generated query",
1452+
"error": None,
1453+
},
1454+
)
1455+
mock_llm_2 = mock.Mock()
1456+
mock_llm_2.generate.return_value = "```LLM generated query``"
1457+
mock_create_provider.return_value = mock_llm_2
1458+
response = self.client.post(
1459+
self.resource_url,
1460+
data=json.dumps(data),
1461+
content_type="application/json",
1462+
)
1463+
mock_llm_2.generate.assert_called_once_with(expected_input)
1464+
self.assertEqual(response.status_code, HTTP_STATUS_CODE_OK)
1465+
self.assertDictEqual(
1466+
response.json,
1467+
{
1468+
"name": "AI generated search query",
1469+
"query_string": "LLM generated query",
1470+
"error": None,
1471+
},
1472+
)
1473+
mock_llm_3 = mock.Mock()
1474+
mock_llm_3.generate.return_value = " \t```LLM generated query```\n "
1475+
mock_create_provider.return_value = mock_llm_3
1476+
response = self.client.post(
1477+
self.resource_url,
1478+
data=json.dumps(data),
1479+
content_type="application/json",
1480+
)
1481+
mock_llm_3.generate.assert_called_once_with(expected_input)
1482+
self.assertEqual(response.status_code, HTTP_STATUS_CODE_OK)
1483+
self.assertDictEqual(
1484+
response.json,
1485+
{
1486+
"name": "AI generated search query",
1487+
"query_string": "LLM generated query",
1488+
"error": None,
1489+
},
1490+
)
1491+
14111492

14121493
class SystemSettingsResourceTest(BaseTest):
14131494
"""Test system settings resource."""

0 commit comments

Comments
 (0)