From b856f54c3069789b3eaf5d49b0fafad1e7641901 Mon Sep 17 00:00:00 2001 From: itsmvd Date: Tue, 25 Feb 2025 13:17:01 +0000 Subject: [PATCH 01/22] add response_schema support to ollama.py --- timesketch/lib/llms/providers/ollama.py | 93 +++++++++++++++++++++++++ 1 file changed, 93 insertions(+) create mode 100644 timesketch/lib/llms/providers/ollama.py diff --git a/timesketch/lib/llms/providers/ollama.py b/timesketch/lib/llms/providers/ollama.py new file mode 100644 index 0000000000..75d83c112b --- /dev/null +++ b/timesketch/lib/llms/providers/ollama.py @@ -0,0 +1,93 @@ +# Copyright 2025 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""A LLM provider for the ollama server.""" +import json +import requests +from typing import Optional + +from timesketch.lib.llms.providers import interface +from timesketch.lib.llms.providers import manager + + +class Ollama(interface.LLMProvider): + """A LLM provider for the ollama server.""" + + NAME = "ollama" + + def _post(self, request_body: str) -> requests.Response: + """ + Make a POST request to the ollama server. + + Args: + request_body: The body of the request in JSON format. + + Returns: + The response from the server as a dictionary. + """ + api_resource = "/api/chat" + url = self.config.get("server_url") + api_resource + return requests.post(url, data=request_body) + + def generate(self, prompt: str, response_schema: Optional[dict] = None) -> str: + """ + Generate text using the ollama server, optionally with a JSON schema. + + Args: + prompt: The prompt to use for the generation. + response_schema: An optional JSON schema to define the expected + response format. + + Returns: + The generated text as a string (or parsed data if + response_schema is provided). 
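+
+        Example (an illustrative sketch; the schema and the `provider`
+        variable are assumptions, not part of this change):
+
+            schema = {
+                "type": "object",
+                "properties": {"summary": {"type": "string"}},
+                "required": ["summary"],
+            }
+            result = provider.generate(
+                "Summarize these events.", response_schema=schema
+            )
+            # With a schema the return value is parsed JSON (e.g. a dict);
+            # without one it is the raw response string.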
+ """ + request_body = { + "messages": [{"role": "user", "content": prompt}], + "model": self.config.get("model"), + "stream": False, # Force to false, streaming not available with /api/chat endpoint + "options": { + "temperature": self.config.get("temperature"), + "num_predict": self.config.get("max_output_tokens"), + "top_p": self.config.get("top_p"), + "top_k": self.config.get("top_k"), + }, + } + + if response_schema: + request_body["format"] = response_schema + + response = self._post(json.dumps(request_body)) + + if response.status_code != 200: + raise ValueError(f"Error generating text: {response.text}") + + try: + text_response = response.json().get("content", "").strip() + if response_schema: + return json.loads(text_response) + + return text_response + + except json.JSONDecodeError as error: + raise ValueError( + f"Error JSON parsing text: {text_response}: {error}" + ) from error + + except Exception as error: + raise ValueError( + f"An unexpected error occurred: {error}" + ) from error + + +manager.LLMManager.register_provider(Ollama) From 5debf0f999ed3038a07c3616456e06b17537f151 Mon Sep 17 00:00:00 2001 From: itsmvd Date: Tue, 25 Feb 2025 15:46:59 +0000 Subject: [PATCH 02/22] Create separate llm provider directory, add response_schema to ollama provider --- timesketch/api/v1/resources/llm_summarize.py | 7 +- timesketch/api/v1/resources/nl2q.py | 2 +- timesketch/api/v1/resources_test.py | 6 +- timesketch/lib/llms/ollama.py | 72 ------------------- .../lib/llms/{ => providers}/__init__.py | 8 +-- .../lib/llms/{ => providers}/aistudio.py | 6 +- .../lib/llms/{ => providers}/interface.py | 0 .../lib/llms/{ => providers}/manager.py | 3 +- .../lib/llms/{ => providers}/manager_test.py | 2 +- timesketch/lib/llms/providers/ollama.py | 48 ++++++------- .../lib/llms/{ => providers}/vertexai.py | 4 +- 11 files changed, 43 insertions(+), 115 deletions(-) delete mode 100644 timesketch/lib/llms/ollama.py rename timesketch/lib/llms/{ => providers}/__init__.py (76%) rename timesketch/lib/llms/{ => providers}/aistudio.py (95%) rename timesketch/lib/llms/{ => providers}/interface.py (100%) rename timesketch/lib/llms/{ => providers}/manager.py (98%) rename timesketch/lib/llms/{ => providers}/manager_test.py (99%) rename timesketch/lib/llms/{ => providers}/vertexai.py (96%) diff --git a/timesketch/api/v1/resources/llm_summarize.py b/timesketch/api/v1/resources/llm_summarize.py index 0c18441b56..5aa37657f1 100644 --- a/timesketch/api/v1/resources/llm_summarize.py +++ b/timesketch/api/v1/resources/llm_summarize.py @@ -28,7 +28,8 @@ from flask_restful import Resource from timesketch.api.v1 import resources, export -from timesketch.lib import definitions, llms, utils +from timesketch.lib import definitions, utils +from timesketch.lib.llms.providers import manager from timesketch.lib.definitions import METRICS_NAMESPACE from timesketch.models.sketch import Sketch @@ -304,8 +305,8 @@ def _get_content( configured LLM provider """ try: - feature_name = "llm_summarization" - llm = llms.manager.LLMManager.create_provider(feature_name=feature_name) + feature_name = "llm_summarize" + llm = manager.LLMManager.create_provider(feature_name=feature_name) except Exception as e: # pylint: disable=broad-except logger.error("Error LLM Provider: %s", e) abort( diff --git a/timesketch/api/v1/resources/nl2q.py b/timesketch/api/v1/resources/nl2q.py index d016a768f7..5ed533e956 100644 --- a/timesketch/api/v1/resources/nl2q.py +++ b/timesketch/api/v1/resources/nl2q.py @@ -26,7 +26,7 @@ import pandas as pd from 
timesketch.api.v1 import utils -from timesketch.lib.llms import manager +from timesketch.lib.llms.providers import manager from timesketch.lib.definitions import HTTP_STATUS_CODE_BAD_REQUEST from timesketch.lib.definitions import HTTP_STATUS_CODE_INTERNAL_SERVER_ERROR from timesketch.lib.definitions import HTTP_STATUS_CODE_NOT_FOUND diff --git a/timesketch/api/v1/resources_test.py b/timesketch/api/v1/resources_test.py index a964fad50a..7044bb2250 100644 --- a/timesketch/api/v1/resources_test.py +++ b/timesketch/api/v1/resources_test.py @@ -1198,7 +1198,7 @@ class TestNl2qResource(BaseTest): resource_url = "/api/v1/sketches/1/nl2q/" - @mock.patch("timesketch.lib.llms.manager.LLMManager.create_provider") + @mock.patch("timesketch.lib.llms.provider.manager.LLMManager.create_provider") @mock.patch("timesketch.api.v1.utils.run_aggregator") @mock.patch("timesketch.api.v1.resources.OpenSearchDataStore", MockDataStore) def test_nl2q_prompt(self, mock_aggregator, mock_create_provider): @@ -1380,7 +1380,7 @@ def test_nl2q_no_permission(self): ) self.assertEqual(response.status_code, HTTP_STATUS_CODE_FORBIDDEN) - @mock.patch("timesketch.lib.llms.manager.LLMManager.create_provider") + @mock.patch("timesketch.lib.llms.provider.manager.LLMManager.create_provider") @mock.patch("timesketch.api.v1.utils.run_aggregator") @mock.patch("timesketch.api.v1.resources.OpenSearchDataStore", MockDataStore) def test_nl2q_llm_error(self, mock_aggregator, mock_create_provider): @@ -1584,7 +1584,7 @@ def test_llm_summarize_no_events(self): ) @mock.patch("timesketch.api.v1.resources.OpenSearchDataStore", MockDataStore) - @mock.patch("timesketch.lib.llms.manager.LLMManager.create_provider") + @mock.patch("timesketch.lib.llms.provider.manager.LLMManager.create_provider") def test_llm_summarize_with_events(self, mock_create_provider): """Test LLM summarizer with events returned and mock LLM.""" self.login() diff --git a/timesketch/lib/llms/ollama.py b/timesketch/lib/llms/ollama.py deleted file mode 100644 index 365716b580..0000000000 --- a/timesketch/lib/llms/ollama.py +++ /dev/null @@ -1,72 +0,0 @@ -# Copyright 2024 Google Inc. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""A LLM provider for the ollama server.""" -import json -import requests - -from timesketch.lib.llms import interface -from timesketch.lib.llms import manager - - -class Ollama(interface.LLMProvider): - """A LLM provider for the ollama server.""" - - NAME = "ollama" - - def _post(self, request_body: str) -> requests.Response: - """ - Make a POST request to the ollama server. - - Args: - request_body: The body of the request in JSON format. - - Returns: - The response from the server as a dictionary. - """ - api_resource = "/api/generate/" - url = self.config.get("server_url") + api_resource - return requests.post(url, data=request_body) - - def generate(self, prompt: str) -> str: - """ - Generate text using the ollama server. - - Args: - prompt: The prompt to use for the generation. 
- temperature: The temperature to use for the generation. - stream: Whether to stream the generation or not. - - Raises: - ValueError: If the generation fails. - - Returns: - The generated text as a string. - """ - request_body = { - "prompt": prompt, - "model": self.config.get("model"), - "stream": self.config.get("stream"), - "options": { - "temperature": self.config.get("temperature"), - "num_predict": self.config.get("max_output_tokens"), - }, - } - response = self._post(json.dumps(request_body)) - if response.status_code != 200: - raise ValueError(f"Error generating text: {response.text}") - - return response.json().get("response", "").strip() - - -manager.LLMManager.register_provider(Ollama) diff --git a/timesketch/lib/llms/__init__.py b/timesketch/lib/llms/providers/__init__.py similarity index 76% rename from timesketch/lib/llms/__init__.py rename to timesketch/lib/llms/providers/__init__.py index bb52e18d42..f92027460b 100644 --- a/timesketch/lib/llms/__init__.py +++ b/timesketch/lib/llms/providers/__init__.py @@ -11,8 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""LLM module for Timesketch.""" +"""LLM providers for Timesketch.""" -from timesketch.lib.llms import ollama -from timesketch.lib.llms import vertexai -from timesketch.lib.llms import aistudio +from timesketch.lib.llms.providers import ollama +from timesketch.lib.llms.providers import vertexai +from timesketch.lib.llms.providers import aistudio diff --git a/timesketch/lib/llms/aistudio.py b/timesketch/lib/llms/providers/aistudio.py similarity index 95% rename from timesketch/lib/llms/aistudio.py rename to timesketch/lib/llms/providers/aistudio.py index 77b6502efa..df7d5ca1bb 100644 --- a/timesketch/lib/llms/aistudio.py +++ b/timesketch/lib/llms/providers/aistudio.py @@ -1,4 +1,4 @@ -# Copyright 2024 Google Inc. All rights reserved. +# Copyright 2025 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,8 +15,8 @@ import json from typing import Optional -from timesketch.lib.llms import interface -from timesketch.lib.llms import manager +from timesketch.lib.llms.providers import interface +from timesketch.lib.llms.providers import manager # Check if the required dependencies are installed. diff --git a/timesketch/lib/llms/interface.py b/timesketch/lib/llms/providers/interface.py similarity index 100% rename from timesketch/lib/llms/interface.py rename to timesketch/lib/llms/providers/interface.py diff --git a/timesketch/lib/llms/manager.py b/timesketch/lib/llms/providers/manager.py similarity index 98% rename from timesketch/lib/llms/manager.py rename to timesketch/lib/llms/providers/manager.py index 5412abcec6..6bb3757d1d 100644 --- a/timesketch/lib/llms/manager.py +++ b/timesketch/lib/llms/providers/manager.py @@ -14,7 +14,7 @@ """This file contains a class for managing Large Language Model (LLM) providers.""" from flask import current_app -from timesketch.lib.llms.interface import LLMProvider +from timesketch.lib.llms.providers.interface import LLMProvider class LLMManager: @@ -80,7 +80,6 @@ def create_provider(cls, feature_name: str = None, **kwargs) -> LLMProvider: raise ValueError( "Configuration for the feature must specify exactly one provider." 
) - provider_name = next(iter(config_mapping)) provider_config = config_mapping[provider_name] diff --git a/timesketch/lib/llms/manager_test.py b/timesketch/lib/llms/providers/manager_test.py similarity index 99% rename from timesketch/lib/llms/manager_test.py rename to timesketch/lib/llms/providers/manager_test.py index c850b6a75c..af5b5f4e95 100644 --- a/timesketch/lib/llms/manager_test.py +++ b/timesketch/lib/llms/providers/manager_test.py @@ -14,7 +14,7 @@ """Tests for LLM provider manager.""" from timesketch.lib.testlib import BaseTest -from timesketch.lib.llms import manager +from timesketch.lib.llms.providers import manager class MockAistudioProvider: diff --git a/timesketch/lib/llms/providers/ollama.py b/timesketch/lib/llms/providers/ollama.py index 75d83c112b..bbb6795887 100644 --- a/timesketch/lib/llms/providers/ollama.py +++ b/timesketch/lib/llms/providers/ollama.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""A LLM provider for the ollama server.""" +"""A LLM provider for the Ollama server.""" import json import requests from typing import Optional @@ -21,27 +21,29 @@ class Ollama(interface.LLMProvider): - """A LLM provider for the ollama server.""" + """A LLM provider for the Ollama server.""" NAME = "ollama" def _post(self, request_body: str) -> requests.Response: """ - Make a POST request to the ollama server. + Make a POST request to the Ollama server. Args: request_body: The body of the request in JSON format. Returns: - The response from the server as a dictionary. + The response from the server as a requests.Response object. """ api_resource = "/api/chat" url = self.config.get("server_url") + api_resource - return requests.post(url, data=request_body) + return requests.post( + url, data=request_body, headers={"Content-Type": "application/json"} + ) def generate(self, prompt: str, response_schema: Optional[dict] = None) -> str: """ - Generate text using the ollama server, optionally with a JSON schema. + Generate text using the Ollama server, optionally with a JSON schema. Args: prompt: The prompt to use for the generation. @@ -49,13 +51,15 @@ def generate(self, prompt: str, response_schema: Optional[dict] = None) -> str: response format. Returns: - The generated text as a string (or parsed data if - response_schema is provided). + The generated text as a string (or parsed data if response_schema is provided). + + Raises: + ValueError: If the request fails or JSON parsing fails. 
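+
+        Example request body sent to /api/chat (an illustrative sketch; the
+        values shown are assumptions and actually come from self.config):
+
+            {
+                "messages": [{"role": "user", "content": "<prompt>"}],
+                "model": "gemma:7b",
+                "stream": false,
+                "options": {"temperature": 0.1, "num_predict": 1024,
+                            "top_p": 1.0, "top_k": 40},
+                "format": {"type": "object", ...}  # only when response_schema is set
+            }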
""" request_body = { "messages": [{"role": "user", "content": prompt}], "model": self.config.get("model"), - "stream": False, # Force to false, streaming not available with /api/chat endpoint + "stream": self.config.get("stream"), "options": { "temperature": self.config.get("temperature"), "num_predict": self.config.get("max_output_tokens"), @@ -72,22 +76,18 @@ def generate(self, prompt: str, response_schema: Optional[dict] = None) -> str: if response.status_code != 200: raise ValueError(f"Error generating text: {response.text}") - try: - text_response = response.json().get("content", "").strip() - if response_schema: + response_data = response.json() + text_response = response_data.get("message", {}).get("content", "").strip() + + if response_schema: + try: return json.loads(text_response) - - return text_response - - except json.JSONDecodeError as error: - raise ValueError( - f"Error JSON parsing text: {text_response}: {error}" - ) from error - - except Exception as error: - raise ValueError( - f"An unexpected error occurred: {error}" - ) from error + except json.JSONDecodeError as error: + raise ValueError( + f"Error JSON parsing text: {text_response}: {error}" + ) from error + + return text_response manager.LLMManager.register_provider(Ollama) diff --git a/timesketch/lib/llms/vertexai.py b/timesketch/lib/llms/providers/vertexai.py similarity index 96% rename from timesketch/lib/llms/vertexai.py rename to timesketch/lib/llms/providers/vertexai.py index e4f25f7f7e..123bbdd39e 100644 --- a/timesketch/lib/llms/vertexai.py +++ b/timesketch/lib/llms/providers/vertexai.py @@ -16,8 +16,8 @@ import json from typing import Optional -from timesketch.lib.llms import interface -from timesketch.lib.llms import manager +from timesketch.lib.llms.providers import interface +from timesketch.lib.llms.providers import manager # Check if the required dependencies are installed. 
has_required_deps = True From 70d06991938b9f777c3b4632c7e4612820475e47 Mon Sep 17 00:00:00 2001 From: itsmvd Date: Tue, 25 Feb 2025 15:50:24 +0000 Subject: [PATCH 03/22] Update timesketch.conf --- data/timesketch.conf | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/data/timesketch.conf b/data/timesketch.conf index 0d9a47164c..1df853cc12 100644 --- a/data/timesketch.conf +++ b/data/timesketch.conf @@ -379,16 +379,16 @@ LLM_PROVIDER_CONFIGS = { 'project_id': '', }, }, - 'llm_summarization': { + 'llm_summarize': { 'aistudio': { 'model': 'gemini-2.0-flash-exp', 'project_id': '', }, }, 'default': { - 'aistudio': { - 'api_key': '', - 'model': 'gemini-2.0-flash-exp', + 'ollama': { + 'server_url': 'http://ollama:11434', + 'model': 'gemma:7b', }, } } From 59ce086c88c7eb4c23499edf194fa4b1d955e686 Mon Sep 17 00:00:00 2001 From: itsmvd Date: Tue, 25 Feb 2025 15:53:20 +0000 Subject: [PATCH 04/22] solve naming conflict --- timesketch/api/v1/resources/llm_summarize.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/timesketch/api/v1/resources/llm_summarize.py b/timesketch/api/v1/resources/llm_summarize.py index 5aa37657f1..a5ecebc3b6 100644 --- a/timesketch/api/v1/resources/llm_summarize.py +++ b/timesketch/api/v1/resources/llm_summarize.py @@ -29,7 +29,7 @@ from timesketch.api.v1 import resources, export from timesketch.lib import definitions, utils -from timesketch.lib.llms.providers import manager +from timesketch.lib.llms.providers import manager as provider_manager from timesketch.lib.definitions import METRICS_NAMESPACE from timesketch.models.sketch import Sketch @@ -306,7 +306,7 @@ def _get_content( """ try: feature_name = "llm_summarize" - llm = manager.LLMManager.create_provider(feature_name=feature_name) + llm = provider_manager.LLMManager.create_provider(feature_name=feature_name) except Exception as e: # pylint: disable=broad-except logger.error("Error LLM Provider: %s", e) abort( From 9e2c294a796c12a4978dc1ffe22e103647687ccf Mon Sep 17 00:00:00 2001 From: itsmvd Date: Tue, 25 Feb 2025 16:06:13 +0000 Subject: [PATCH 05/22] fix typo --- timesketch/api/v1/resources_test.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/timesketch/api/v1/resources_test.py b/timesketch/api/v1/resources_test.py index 7044bb2250..5396b010ab 100644 --- a/timesketch/api/v1/resources_test.py +++ b/timesketch/api/v1/resources_test.py @@ -1198,7 +1198,7 @@ class TestNl2qResource(BaseTest): resource_url = "/api/v1/sketches/1/nl2q/" - @mock.patch("timesketch.lib.llms.provider.manager.LLMManager.create_provider") + @mock.patch("timesketch.lib.llms.providers.manager.LLMManager.create_provider") @mock.patch("timesketch.api.v1.utils.run_aggregator") @mock.patch("timesketch.api.v1.resources.OpenSearchDataStore", MockDataStore) def test_nl2q_prompt(self, mock_aggregator, mock_create_provider): @@ -1380,7 +1380,7 @@ def test_nl2q_no_permission(self): ) self.assertEqual(response.status_code, HTTP_STATUS_CODE_FORBIDDEN) - @mock.patch("timesketch.lib.llms.provider.manager.LLMManager.create_provider") + @mock.patch("timesketch.lib.llms.providers.manager.LLMManager.create_provider") @mock.patch("timesketch.api.v1.utils.run_aggregator") @mock.patch("timesketch.api.v1.resources.OpenSearchDataStore", MockDataStore) def test_nl2q_llm_error(self, mock_aggregator, mock_create_provider): @@ -1584,7 +1584,7 @@ def test_llm_summarize_no_events(self): ) @mock.patch("timesketch.api.v1.resources.OpenSearchDataStore", MockDataStore) - 
@mock.patch("timesketch.lib.llms.provider.manager.LLMManager.create_provider") + @mock.patch("timesketch.lib.llms.providers.manager.LLMManager.create_provider") def test_llm_summarize_with_events(self, mock_create_provider): """Test LLM summarizer with events returned and mock LLM.""" self.login() From 5f252a94be82bc960c144eabcbbc4d7a74777a86 Mon Sep 17 00:00:00 2001 From: itsmvd Date: Tue, 25 Feb 2025 17:17:35 +0000 Subject: [PATCH 06/22] Add an __init__ file to the timsketch/lib/llms folder --- timesketch/lib/llms/__init__.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) create mode 100644 timesketch/lib/llms/__init__.py diff --git a/timesketch/lib/llms/__init__.py b/timesketch/lib/llms/__init__.py new file mode 100644 index 0000000000..0242820fb1 --- /dev/null +++ b/timesketch/lib/llms/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2025 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""LLM libraries for Timesketch.""" From c0401596592a12f937709623997b51f1aa7bd2ff Mon Sep 17 00:00:00 2001 From: itsmvd Date: Tue, 25 Feb 2025 17:24:27 +0000 Subject: [PATCH 07/22] lint fix ollama --- timesketch/lib/llms/providers/ollama.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/timesketch/lib/llms/providers/ollama.py b/timesketch/lib/llms/providers/ollama.py index bbb6795887..42481e7aee 100644 --- a/timesketch/lib/llms/providers/ollama.py +++ b/timesketch/lib/llms/providers/ollama.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. """A LLM provider for the Ollama server.""" +from typing import Optional import json import requests -from typing import Optional from timesketch.lib.llms.providers import interface from timesketch.lib.llms.providers import manager @@ -51,7 +51,8 @@ def generate(self, prompt: str, response_schema: Optional[dict] = None) -> str: response format. Returns: - The generated text as a string (or parsed data if response_schema is provided). + The generated text as a string (or parsed data if + response_schema is provided). Raises: ValueError: If the request fails or JSON parsing fails. From 9ab391ed42b744294d2fae940ba88282ead3ab58 Mon Sep 17 00:00:00 2001 From: itsmvd Date: Wed, 26 Feb 2025 09:48:28 +0000 Subject: [PATCH 08/22] Improve fallback mechanism for LLM configs --- timesketch/lib/llms/providers/manager.py | 20 +++++++++++++----- timesketch/lib/llms/providers/manager_test.py | 21 +++++++++++++++++++ 2 files changed, 36 insertions(+), 5 deletions(-) diff --git a/timesketch/lib/llms/providers/manager.py b/timesketch/lib/llms/providers/manager.py index 6bb3757d1d..3dfc4705aa 100644 --- a/timesketch/lib/llms/providers/manager.py +++ b/timesketch/lib/llms/providers/manager.py @@ -63,7 +63,7 @@ def create_provider(cls, feature_name: str = None, **kwargs) -> LLMProvider: """ Create an instance of the provider for the given feature. 
- If a configuration exists for the feature in + If a valid configuration exists for the feature in current_app.config["LLM_PROVIDER_CONFIGS"], use it; otherwise, fall back to the configuration under the "default" key. @@ -71,14 +71,24 @@ def create_provider(cls, feature_name: str = None, **kwargs) -> LLMProvider: the provider name. """ llm_configs = current_app.config.get("LLM_PROVIDER_CONFIGS", {}) + if feature_name and feature_name in llm_configs: config_mapping = llm_configs[feature_name] - else: - config_mapping = llm_configs.get("default") - + if config_mapping and len(config_mapping) == 1: + provider_name = next(iter(config_mapping)) + provider_config = config_mapping[provider_name] + provider_class = cls.get_provider(provider_name) + # Check that provider specifies required fields + try: + return provider_class(config=provider_config, **kwargs) + except ValueError: + pass # Fallback to default provider + + # Fallback to default config + config_mapping = llm_configs.get("default") if not config_mapping or len(config_mapping) != 1: raise ValueError( - "Configuration for the feature must specify exactly one provider." + "Default configuration must specify exactly one provider." ) provider_name = next(iter(config_mapping)) provider_config = config_mapping[provider_name] diff --git a/timesketch/lib/llms/providers/manager_test.py b/timesketch/lib/llms/providers/manager_test.py index af5b5f4e95..6db3f6b3ce 100644 --- a/timesketch/lib/llms/providers/manager_test.py +++ b/timesketch/lib/llms/providers/manager_test.py @@ -144,3 +144,24 @@ def test_create_provider_missing_config(self): self.app.config["LLM_PROVIDER_CONFIGS"] = {} with self.assertRaises(ValueError): manager.LLMManager.create_provider() + + def test_create_provider_empty_feature_fallback(self): + """Test that create_provider falls back to default when feature config is empty.""" + self.app.config["LLM_PROVIDER_CONFIGS"] = { + "llm_summarize": {}, # Empty feature config + "default": { + "aistudio": { + "api_key": "AIzaSyTestDefaultKey", + "model": "gemini-2.0-flash-exp", + } + }, + } + provider_instance = manager.LLMManager.create_provider(feature_name="llm_summarize") + self.assertIsInstance(provider_instance, MockAistudioProvider) + self.assertEqual( + provider_instance.config, + { + "api_key": "AIzaSyTestDefaultKey", + "model": "gemini-2.0-flash-exp", + }, + ) \ No newline at end of file From 5d4746a22e9832455100adc89b66782aeded08bd Mon Sep 17 00:00:00 2001 From: itsmvd Date: Wed, 26 Feb 2025 09:51:16 +0000 Subject: [PATCH 09/22] formatting --- timesketch/lib/llms/providers/manager.py | 6 ++---- timesketch/lib/llms/providers/manager_test.py | 6 ++++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/timesketch/lib/llms/providers/manager.py b/timesketch/lib/llms/providers/manager.py index 3dfc4705aa..7cfc0a7574 100644 --- a/timesketch/lib/llms/providers/manager.py +++ b/timesketch/lib/llms/providers/manager.py @@ -71,7 +71,7 @@ def create_provider(cls, feature_name: str = None, **kwargs) -> LLMProvider: the provider name. 
""" llm_configs = current_app.config.get("LLM_PROVIDER_CONFIGS", {}) - + if feature_name and feature_name in llm_configs: config_mapping = llm_configs[feature_name] if config_mapping and len(config_mapping) == 1: @@ -87,9 +87,7 @@ def create_provider(cls, feature_name: str = None, **kwargs) -> LLMProvider: # Fallback to default config config_mapping = llm_configs.get("default") if not config_mapping or len(config_mapping) != 1: - raise ValueError( - "Default configuration must specify exactly one provider." - ) + raise ValueError("Default configuration must specify exactly one provider.") provider_name = next(iter(config_mapping)) provider_config = config_mapping[provider_name] diff --git a/timesketch/lib/llms/providers/manager_test.py b/timesketch/lib/llms/providers/manager_test.py index 6db3f6b3ce..ceb7267100 100644 --- a/timesketch/lib/llms/providers/manager_test.py +++ b/timesketch/lib/llms/providers/manager_test.py @@ -156,7 +156,9 @@ def test_create_provider_empty_feature_fallback(self): } }, } - provider_instance = manager.LLMManager.create_provider(feature_name="llm_summarize") + provider_instance = manager.LLMManager.create_provider( + feature_name="llm_summarize" + ) self.assertIsInstance(provider_instance, MockAistudioProvider) self.assertEqual( provider_instance.config, @@ -164,4 +166,4 @@ def test_create_provider_empty_feature_fallback(self): "api_key": "AIzaSyTestDefaultKey", "model": "gemini-2.0-flash-exp", }, - ) \ No newline at end of file + ) From 390cd091ab5a60e7e61b49904c186938153c50dd Mon Sep 17 00:00:00 2001 From: itsmvd Date: Wed, 26 Feb 2025 09:57:13 +0000 Subject: [PATCH 10/22] format fix 2 --- timesketch/lib/llms/providers/manager_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timesketch/lib/llms/providers/manager_test.py b/timesketch/lib/llms/providers/manager_test.py index ceb7267100..09902faa0f 100644 --- a/timesketch/lib/llms/providers/manager_test.py +++ b/timesketch/lib/llms/providers/manager_test.py @@ -146,7 +146,7 @@ def test_create_provider_missing_config(self): manager.LLMManager.create_provider() def test_create_provider_empty_feature_fallback(self): - """Test that create_provider falls back to default when feature config is empty.""" + """Test that create_provider falls back to default when feature config empty.""" self.app.config["LLM_PROVIDER_CONFIGS"] = { "llm_summarize": {}, # Empty feature config "default": { From ad4d70b303a74c5cba57d9a14bbf287fa596468e Mon Sep 17 00:00:00 2001 From: itsmvd Date: Wed, 26 Feb 2025 13:13:06 +0000 Subject: [PATCH 11/22] Add LLM features manager and interface --- timesketch/lib/llms/features/__init__.py | 16 +++ timesketch/lib/llms/features/interface.py | 53 +++++++ timesketch/lib/llms/features/manager.py | 65 +++++++++ timesketch/lib/llms/features/manager_test.py | 142 +++++++++++++++++++ 4 files changed, 276 insertions(+) create mode 100644 timesketch/lib/llms/features/__init__.py create mode 100644 timesketch/lib/llms/features/interface.py create mode 100644 timesketch/lib/llms/features/manager.py create mode 100644 timesketch/lib/llms/features/manager_test.py diff --git a/timesketch/lib/llms/features/__init__.py b/timesketch/lib/llms/features/__init__.py new file mode 100644 index 0000000000..6a8e4caaf4 --- /dev/null +++ b/timesketch/lib/llms/features/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2025 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""LLM features for Timesketch."""
+
+from timesketch.lib.llms.features import manager
diff --git a/timesketch/lib/llms/features/interface.py b/timesketch/lib/llms/features/interface.py
new file mode 100644
index 0000000000..10317fe014
--- /dev/null
+++ b/timesketch/lib/llms/features/interface.py
@@ -0,0 +1,53 @@
+# Copyright 2025 Google Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Interface for LLM features."""
+
+from typing import Any, Optional
+from abc import ABC, abstractmethod
+from timesketch.models.sketch import Sketch
+
+
+class LLMFeatureInterface(ABC):
+    """Interface for LLM features."""
+
+    NAME: str = "llm_feature_interface"  # Must be overridden in subclasses
+    RESPONSE_SCHEMA: Optional[dict[str, Any]] = None
+
+    @abstractmethod
+    def generate_prompt(self, sketch: Sketch, **kwargs: Any) -> str:
+        """Generates a prompt for the LLM.
+
+        Args:
+            sketch: The sketch object.
+            kwargs: Feature-specific keyword arguments for prompt generation.
+
+        Returns:
+            The generated prompt string.
+        """
+        raise NotImplementedError()
+
+    @abstractmethod
+    def process_response(self, llm_response: str, **kwargs: Any) -> dict[str, Any]:
+        """Processes the raw LLM response.
+
+        Args:
+            llm_response: The raw string response from the LLM provider.
+            kwargs: Feature-specific arguments.
+
+        Returns:
+            A dictionary containing the processed response data, suitable for
+            returning from the API. Must include a "response" key with the
+            main result, and can optionally include other metadata.
+        """
+        raise NotImplementedError()
diff --git a/timesketch/lib/llms/features/manager.py b/timesketch/lib/llms/features/manager.py
new file mode 100644
index 0000000000..70bfac4836
--- /dev/null
+++ b/timesketch/lib/llms/features/manager.py
@@ -0,0 +1,65 @@
+# Copyright 2025 Google Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Manager for LLM features.""" + +import logging +from timesketch.lib.llms.features.interface import LLMFeatureInterface + +logger = logging.getLogger("timesketch.llm.manager") + + +class FeatureManager: + """The manager for LLM features.""" + + _feature_registry = {} + + @classmethod + def register_feature(cls, feature_class: type[LLMFeatureInterface]): + """Register an LLM feature class.""" + feature_name = feature_class.NAME.lower() + if feature_name in cls._feature_registry: + raise ValueError(f"LLM Feature {feature_class.NAME} already registered") + cls._feature_registry[feature_name] = feature_class + # Optional: Add logging here + + @classmethod + def get_feature(cls, feature_name: str) -> type[LLMFeatureInterface]: + """Get a feature class by name.""" + try: + return cls._feature_registry[feature_name.lower()] + except KeyError as no_such_feature: + raise KeyError( + f"No such LLM feature: {feature_name.lower()}" + ) from no_such_feature + + @classmethod + def get_features(cls): + """Get all registered features. + + Yields: + A tuple of (feature_name, feature_class) + """ + for feature_name, feature_class in cls._feature_registry.items(): + yield feature_name, feature_class + + @classmethod + def get_feature_instance(cls, feature_name: str) -> LLMFeatureInterface: + """Get an instance of a feature by name.""" + feature_class = cls.get_feature(feature_name) + return feature_class() + + @classmethod + def clear_registration(cls): + """Clear all registered features.""" + cls._feature_registry = {} diff --git a/timesketch/lib/llms/features/manager_test.py b/timesketch/lib/llms/features/manager_test.py new file mode 100644 index 0000000000..2e053e7199 --- /dev/null +++ b/timesketch/lib/llms/features/manager_test.py @@ -0,0 +1,142 @@ +# Copyright 2025 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for LLM feature manager.""" + +from typing import Any +from timesketch.lib.testlib import BaseTest +from timesketch.lib.llms.features import manager +from timesketch.models.sketch import Sketch + + +class MockSummarizeFeature: + """A mock LLM summarize feature.""" + + NAME = "llm_summarize" + + def generate_prompt(self, _sketch: Sketch, **_kwargs: Any) -> str: + """Mock implementation of generate_prompt.""" + return "Summarize these events." + + def process_response(self, llm_response: str, **kwargs: Any) -> dict[str, Any]: + """Mock implementation of process_response.""" + return {"response": f"Summary: {llm_response}"} + + +class MockNl2qFeature: + """A mock Natural Language to Query feature.""" + + NAME = "nl2q" + + def generate_prompt(self, _sketch: Sketch, **_kwargs: Any) -> str: + """Mock implementation of generate_prompt.""" + return "Convert this question to a query." 
+ + def process_response(self, llm_response: str, **_kwargs: Any) -> dict[str, Any]: + """Mock implementation of process_response.""" + return {"response": f"Query: {llm_response}"} + + +class TestFeatureManager(BaseTest): + """Tests for the functionality of the FeatureManager module.""" + + def setUp(self) -> None: + super().setUp() + manager.FeatureManager.clear_registration() + manager.FeatureManager.register_feature(MockSummarizeFeature) + manager.FeatureManager.register_feature(MockNl2qFeature) + + def tearDown(self) -> None: + manager.FeatureManager.clear_registration() + super().tearDown() + + def test_get_features(self): + """Test that get_features returns the registered features.""" + features = manager.FeatureManager.get_features() + feature_list = list(features) + self.assertIsInstance(feature_list, list) + + found_summarize = any( + feature_name == "llm_summarize" and feature_class == MockSummarizeFeature + for feature_name, feature_class in feature_list + ) + found_nl2q = any( + feature_name == "nl2q" and feature_class == MockNl2qFeature + for feature_name, feature_class in feature_list + ) + self.assertTrue(found_summarize, "LLM Summarize feature not found.") + self.assertTrue(found_nl2q, "NL2Q feature not found.") + + def test_get_feature(self): + """Test retrieval of a feature class from the registry.""" + feature_class = manager.FeatureManager.get_feature("llm_summarize") + self.assertEqual(feature_class, MockSummarizeFeature) + + feature_class = manager.FeatureManager.get_feature("LLM_SUMMARIZE") + self.assertEqual(feature_class, MockSummarizeFeature) + + self.assertRaises( + KeyError, manager.FeatureManager.get_feature, "no_such_feature" + ) + + def test_register_feature(self): + """Test that re-registering an already registered feature raises ValueError.""" + self.assertRaises( + ValueError, manager.FeatureManager.register_feature, MockSummarizeFeature + ) + + def test_get_feature_instance(self): + """Test get_feature_instance creates the correct feature instance.""" + feature_instance = manager.FeatureManager.get_feature_instance("llm_summarize") + self.assertIsInstance(feature_instance, MockSummarizeFeature) + + feature_instance = manager.FeatureManager.get_feature_instance("nl2q") + self.assertIsInstance(feature_instance, MockNl2qFeature) + + self.assertRaises( + KeyError, manager.FeatureManager.get_feature_instance, "no_such_feature" + ) + + def test_feature_methods(self): + """Test that feature methods work correctly.""" + summarize_instance = manager.FeatureManager.get_feature_instance( + "llm_summarize" + ) + nl2q_instance = manager.FeatureManager.get_feature_instance("nl2q") + + sketch = None + + self.assertEqual( + summarize_instance.generate_prompt(sketch), "Summarize these events." + ) + self.assertEqual( + nl2q_instance.generate_prompt(sketch), "Convert this question to a query." 
+ ) + + self.assertEqual( + summarize_instance.process_response("Test events"), + {"response": "Summary: Test events"}, + ) + self.assertEqual( + nl2q_instance.process_response("timestamp:*"), + {"response": "Query: timestamp:*"}, + ) + + def test_clear_registration(self): + """Test clear_registration removes all registered features.""" + self.assertEqual(len(list(manager.FeatureManager.get_features())), 2) + + manager.FeatureManager.clear_registration() + + self.assertEqual(len(list(manager.FeatureManager.get_features())), 0) + self.assertRaises(KeyError, manager.FeatureManager.get_feature, "llm_summarize") From aa267cc7390dbcea3863b93ba2540f6c8a2e24b0 Mon Sep 17 00:00:00 2001 From: itsmvd Date: Wed, 26 Feb 2025 14:51:02 +0000 Subject: [PATCH 12/22] linter fix --- timesketch/lib/llms/features/manager_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timesketch/lib/llms/features/manager_test.py b/timesketch/lib/llms/features/manager_test.py index 2e053e7199..83fc7b81b1 100644 --- a/timesketch/lib/llms/features/manager_test.py +++ b/timesketch/lib/llms/features/manager_test.py @@ -28,7 +28,7 @@ def generate_prompt(self, _sketch: Sketch, **_kwargs: Any) -> str: """Mock implementation of generate_prompt.""" return "Summarize these events." - def process_response(self, llm_response: str, **kwargs: Any) -> dict[str, Any]: + def process_response(self, llm_response: str, **_kwargs: Any) -> dict[str, Any]: """Mock implementation of process_response.""" return {"response": f"Summary: {llm_response}"} From bd8d6d20dbd24058952cd85bfddb0a73d08b22f4 Mon Sep 17 00:00:00 2001 From: itsmvd Date: Thu, 27 Feb 2025 10:24:33 +0000 Subject: [PATCH 13/22] Automatically load features, add better doc-strings to interface.py --- timesketch/lib/llms/features/__init__.py | 2 + timesketch/lib/llms/features/interface.py | 48 ++++++++++-- timesketch/lib/llms/features/manager.py | 36 ++++++++- timesketch/lib/llms/features/manager_test.py | 80 +++++++++++++++++--- 4 files changed, 147 insertions(+), 19 deletions(-) diff --git a/timesketch/lib/llms/features/__init__.py b/timesketch/lib/llms/features/__init__.py index 6a8e4caaf4..8346fe51c3 100644 --- a/timesketch/lib/llms/features/__init__.py +++ b/timesketch/lib/llms/features/__init__.py @@ -14,3 +14,5 @@ """LLM features for Timesketch.""" from timesketch.lib.llms.features import manager + +manager.FeatureManager.load_llm_features() diff --git a/timesketch/lib/llms/features/interface.py b/timesketch/lib/llms/features/interface.py index 10317fe014..ba3d1827b2 100644 --- a/timesketch/lib/llms/features/interface.py +++ b/timesketch/lib/llms/features/interface.py @@ -19,7 +19,27 @@ class LLMFeatureInterface(ABC): - """Interface for LLM features.""" + """Interface for LLM features. + + This abstract class defines the required methods and attributes for implementing + an LLM-powered feature in Timesketch. Features must override the NAME constant + and implement the abstract methods. + + Attributes: + NAME: String identifier for the feature. Must be overridden in subclasses. + RESPONSE_SCHEMA: Optional JSON schema that defines the expected format of + the LLM response. When defined, this schema will be passed to the LLM + provider to enforce structured outputs matching the defined format. + For example: + + { + "type": "object", + "properties": {"summary": {"type": "string"}}, + "required": ["summary"], + } + + If None, the LLM will return unstructured text. 
+ """ NAME: str = "llm_feature_interface" # Must be overridden in subclasses RESPONSE_SCHEMA: Optional[dict[str, Any]] = None @@ -39,15 +59,29 @@ def generate_prompt(self, sketch: Sketch, **kwargs: Any) -> str: @abstractmethod def process_response(self, llm_response: str, **kwargs: Any) -> dict[str, Any]: - """Processes the raw LLM response. + """Processes the LLM response and formats it for API consumption. + + This method takes the response from the LLM provider and transforms it into + a structured format to be returned to the frontend through the API. The + response handling varies depending on whether RESPONSE_SCHEMA is defined: + + - If RESPONSE_SCHEMA is None: Typically receives a string response + - If RESPONSE_SCHEMA is defined: Typically receives a structured dict + + The returned dictionary defines the data contract with the frontend, which will + use these fields to render the appropriate UI elements. Args: - llm_response: The raw string response from the LLM provider. - kwargs: Feature-specific arguments. + llm_response: The response from the LLM provider. This may be a + string or a structured dict depending on RESPONSE_SCHEMA. + **kwargs: Additional data needed for processing, which may include: + - sketch_id: The ID of the sketch + - sketch: The Sketch object Returns: - A dictionary containing the processed response data, suitable for - returning from the API. Must include a "response" key with the - main result, and can optionally include other metadata. + A dictionary that will be JSON-serialized and returned through the API. + This dictionary defines the data contract with the frontend and must include + all fields that the frontend expects to render. Example for NL2Q: + - {"name": "AI generated search query", "query_string": "...", "error": null} """ raise NotImplementedError() diff --git a/timesketch/lib/llms/features/manager.py b/timesketch/lib/llms/features/manager.py index 70bfac4836..67010a50cf 100644 --- a/timesketch/lib/llms/features/manager.py +++ b/timesketch/lib/llms/features/manager.py @@ -13,6 +13,10 @@ # limitations under the License. 
"""Manager for LLM features.""" +import os +import importlib +import inspect +import pkgutil import logging from timesketch.lib.llms.features.interface import LLMFeatureInterface @@ -24,6 +28,37 @@ class FeatureManager: _feature_registry = {} + @classmethod + def load_llm_features(cls): + """Dynamically load and register all LLM features.""" + features_path = os.path.dirname(os.path.abspath(__file__)) + cls.clear_registration() + + for _, module_name, _ in pkgutil.iter_modules([features_path]): + if module_name in ["interface", "manager"] or module_name.endswith("_test"): + continue + try: + module = importlib.import_module( + f"timesketch.lib.llms.features.{module_name}" + ) + for _, obj in inspect.getmembers(module): + if ( + inspect.isclass(obj) + and issubclass(obj, LLMFeatureInterface) + and obj != LLMFeatureInterface + ): + try: + cls.register_feature(obj) + except ValueError as e: + logger.debug("Failed to register feature: %s", str(e)) + + except (ImportError, AttributeError) as e: + logger.error( + "Error loading LLM feature module %s: %s", module_name, str(e) + ) + + logger.debug("Loaded %d LLM features", len(cls._feature_registry)) + @classmethod def register_feature(cls, feature_class: type[LLMFeatureInterface]): """Register an LLM feature class.""" @@ -31,7 +66,6 @@ def register_feature(cls, feature_class: type[LLMFeatureInterface]): if feature_name in cls._feature_registry: raise ValueError(f"LLM Feature {feature_class.NAME} already registered") cls._feature_registry[feature_name] = feature_class - # Optional: Add logging here @classmethod def get_feature(cls, feature_name: str) -> type[LLMFeatureInterface]: diff --git a/timesketch/lib/llms/features/manager_test.py b/timesketch/lib/llms/features/manager_test.py index 83fc7b81b1..7e5c0dd49b 100644 --- a/timesketch/lib/llms/features/manager_test.py +++ b/timesketch/lib/llms/features/manager_test.py @@ -13,10 +13,13 @@ # limitations under the License. """Tests for LLM feature manager.""" +import mock +import types from typing import Any from timesketch.lib.testlib import BaseTest from timesketch.lib.llms.features import manager from timesketch.models.sketch import Sketch +from timesketch.lib.llms.features.interface import LLMFeatureInterface class MockSummarizeFeature: @@ -25,28 +28,48 @@ class MockSummarizeFeature: NAME = "llm_summarize" def generate_prompt(self, _sketch: Sketch, **_kwargs: Any) -> str: - """Mock implementation of generate_prompt.""" + """Mocks implementation of generate_prompt.""" return "Summarize these events." def process_response(self, llm_response: str, **_kwargs: Any) -> dict[str, Any]: - """Mock implementation of process_response.""" + """Mocks implementation of process_response.""" return {"response": f"Summary: {llm_response}"} -class MockNl2qFeature: +class MockNl2qFeature(LLMFeatureInterface): """A mock Natural Language to Query feature.""" NAME = "nl2q" def generate_prompt(self, _sketch: Sketch, **_kwargs: Any) -> str: - """Mock implementation of generate_prompt.""" + """Mocks implementation of generate_prompt.""" return "Convert this question to a query." 
def process_response(self, llm_response: str, **_kwargs: Any) -> dict[str, Any]: - """Mock implementation of process_response.""" + """Mocks implementation of process_response.""" return {"response": f"Query: {llm_response}"} +class MockFeature(LLMFeatureInterface): + NAME = "some_feature" + + def generate_prompt(self, *args: Any, **kwargs: Any) -> str: + return "some prompt" + + def process_response(self, *args: Any, **kwargs: Any) -> dict: + return {"response": "some response"} + + +class DuplicateNl2qFeature(LLMFeatureInterface): + NAME = "nl2q" + + def generate_prompt(self, *args: Any, **kwargs: Any) -> str: + return "duplicate prompt" + + def process_response(self, *args: Any, **kwargs: Any) -> dict: + return {"response": "duplicate response"} + + class TestFeatureManager(BaseTest): """Tests for the functionality of the FeatureManager module.""" @@ -61,7 +84,7 @@ def tearDown(self) -> None: super().tearDown() def test_get_features(self): - """Test that get_features returns the registered features.""" + """Tests that get_features returns the registered features.""" features = manager.FeatureManager.get_features() feature_list = list(features) self.assertIsInstance(feature_list, list) @@ -78,7 +101,7 @@ def test_get_features(self): self.assertTrue(found_nl2q, "NL2Q feature not found.") def test_get_feature(self): - """Test retrieval of a feature class from the registry.""" + """Tests retrieval of a feature class from the registry.""" feature_class = manager.FeatureManager.get_feature("llm_summarize") self.assertEqual(feature_class, MockSummarizeFeature) @@ -90,13 +113,13 @@ def test_get_feature(self): ) def test_register_feature(self): - """Test that re-registering an already registered feature raises ValueError.""" + """Tests that re-registering an already registered feature raises ValueError.""" self.assertRaises( ValueError, manager.FeatureManager.register_feature, MockSummarizeFeature ) def test_get_feature_instance(self): - """Test get_feature_instance creates the correct feature instance.""" + """Tests that get_feature_instance creates the correct feature instance.""" feature_instance = manager.FeatureManager.get_feature_instance("llm_summarize") self.assertIsInstance(feature_instance, MockSummarizeFeature) @@ -108,7 +131,7 @@ def test_get_feature_instance(self): ) def test_feature_methods(self): - """Test that feature methods work correctly.""" + """Tests that feature methods work correctly.""" summarize_instance = manager.FeatureManager.get_feature_instance( "llm_summarize" ) @@ -133,10 +156,45 @@ def test_feature_methods(self): ) def test_clear_registration(self): - """Test clear_registration removes all registered features.""" + """Tests that clear_registration removes all registered features.""" self.assertEqual(len(list(manager.FeatureManager.get_features())), 2) manager.FeatureManager.clear_registration() self.assertEqual(len(list(manager.FeatureManager.get_features())), 0) self.assertRaises(KeyError, manager.FeatureManager.get_feature, "llm_summarize") + + @mock.patch("importlib.import_module") + @mock.patch("pkgutil.iter_modules", return_value=[(None, "nl2q", False)]) + def test_load_llm_feature(self, _, mock_import_module) -> None: + """Tests that load_llm_feature loads the expected features.""" + mock_module = types.ModuleType("mock_module") + setattr(mock_module, "MockNl2qFeature", MockNl2qFeature) + mock_import_module.return_value = mock_module + + manager.FeatureManager.load_llm_features() + features = list(manager.FeatureManager.get_features()) + 
+        self.assertEqual(len(features), 1)
+        registered_name, registered_class = features[0]
+        self.assertEqual(registered_name, "nl2q")
+        self.assertEqual(registered_class, MockNl2qFeature)
+        mock_import_module.assert_called_with("timesketch.lib.llms.features.nl2q")
+
+    @mock.patch("importlib.import_module")
+    @mock.patch("pkgutil.iter_modules", return_value=[(None, "nl2q", False)])
+    def test_load_llm_feature_duplicate(self, _, mock_import_module) -> None:
+        """Tests that load_llm_feature handles registration of duplicate features."""
+        dummy_module = types.ModuleType("dummy_module")
+        setattr(dummy_module, "MockNl2qFeature", MockNl2qFeature)
+        setattr(dummy_module, "DuplicateNl2qFeature", DuplicateNl2qFeature)
+        mock_import_module.return_value = dummy_module
+
+        with self.assertLogs("timesketch.llm.manager", level="WARNING") as log_cm:
+            manager.FeatureManager.load_llm_features()
+        features = list(manager.FeatureManager.get_features())
+        self.assertEqual(len(features), 1)
+        registered_name, _ = features[0]
+        self.assertEqual(registered_name, "nl2q")
+        self.assertTrue(
+            any("already registered" in message for message in log_cm.output)
+        )

From 290afc68ae51393420b29ac1e0a80df32f960c8f Mon Sep 17 00:00:00 2001
From: itsmvd
Date: Thu, 27 Feb 2025 10:27:53 +0000
Subject: [PATCH 14/22] linter fix

---
 timesketch/lib/llms/features/interface.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/timesketch/lib/llms/features/interface.py b/timesketch/lib/llms/features/interface.py
index ba3d1827b2..09b76e0ec3 100644
--- a/timesketch/lib/llms/features/interface.py
+++ b/timesketch/lib/llms/features/interface.py
@@ -82,6 +82,6 @@ def process_response(self, llm_response: str, **kwargs: Any) -> dict[str, Any]:
             A dictionary that will be JSON-serialized and returned through the API.
             This dictionary defines the data contract with the frontend and must include
             all fields that the frontend expects to render. Example for NL2Q:
-            - {"name": "AI generated search query", "query_string": "...", "error": null}
+            - {"name": "AI generated search query","query_string": "...","error":null}

From fb9b668783c5f280567bacc34dfac7826e7aec62 Mon Sep 17 00:00:00 2001
From: itsmvd
Date: Thu, 27 Feb 2025 10:31:48 +0000
Subject: [PATCH 15/22] linter fixes

---
 timesketch/lib/llms/features/manager_test.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/timesketch/lib/llms/features/manager_test.py b/timesketch/lib/llms/features/manager_test.py
index 7e5c0dd49b..46518ac328 100644
--- a/timesketch/lib/llms/features/manager_test.py
+++ b/timesketch/lib/llms/features/manager_test.py
@@ -13,9 +13,9 @@
 # limitations under the License.
"""Tests for LLM feature manager.""" -import mock import types from typing import Any +import mock from timesketch.lib.testlib import BaseTest from timesketch.lib.llms.features import manager from timesketch.models.sketch import Sketch @@ -53,20 +53,20 @@ def process_response(self, llm_response: str, **_kwargs: Any) -> dict[str, Any]: class MockFeature(LLMFeatureInterface): NAME = "some_feature" - def generate_prompt(self, *args: Any, **kwargs: Any) -> str: + def generate_prompt(self, *_args: Any, **_kwargs: Any) -> str: return "some prompt" - def process_response(self, *args: Any, **kwargs: Any) -> dict: + def process_response(self, *_args: Any, **_kwargs: Any) -> dict: return {"response": "some response"} class DuplicateNl2qFeature(LLMFeatureInterface): NAME = "nl2q" - def generate_prompt(self, *args: Any, **kwargs: Any) -> str: + def generate_prompt(self, *_args: Any, **_kwargs: Any) -> str: return "duplicate prompt" - def process_response(self, *args: Any, **kwargs: Any) -> dict: + def process_response(self, *_args: Any, **_kwargs: Any) -> dict: return {"response": "duplicate response"} From 0858b7f83cf5a039ad6ab3b448cec02534c83604 Mon Sep 17 00:00:00 2001 From: itsmvd Date: Thu, 27 Feb 2025 10:50:51 +0000 Subject: [PATCH 16/22] linter fixes --- timesketch/lib/llms/features/manager_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timesketch/lib/llms/features/manager_test.py b/timesketch/lib/llms/features/manager_test.py index 46518ac328..a74c46cef0 100644 --- a/timesketch/lib/llms/features/manager_test.py +++ b/timesketch/lib/llms/features/manager_test.py @@ -189,7 +189,7 @@ def test_load_llm_feature_duplicate(self, _, mock_import_module) -> None: setattr(dummy_module, "DuplicateNl2qFeature", DuplicateNl2qFeature) mock_import_module.return_value = dummy_module - with self.assertLogs("timesketch.llm.manager", level="WARNING") as log_cm: + with self.assertLogs("timesketch.llm.manager", level="DEBUG") as log_cm: manager.FeatureManager.load_llm_features() features = list(manager.FeatureManager.get_features()) self.assertEqual(len(features), 1) From 1bcd2b1889c344c47e295f4ef7acca6833cdb18d Mon Sep 17 00:00:00 2001 From: itsmvd Date: Thu, 27 Feb 2025 13:13:57 +0000 Subject: [PATCH 17/22] Introduce LLMResource API method, tests, and add it as a method for the frontend --- timesketch/api/v1/resources/llm.py | 239 ++++++++++++++++++ timesketch/api/v1/resources_test.py | 131 +++++++++- timesketch/api/v1/routes.py | 2 + .../frontend-ng/src/utils/RestApiClient.js | 6 + 4 files changed, 377 insertions(+), 1 deletion(-) create mode 100644 timesketch/api/v1/resources/llm.py diff --git a/timesketch/api/v1/resources/llm.py b/timesketch/api/v1/resources/llm.py new file mode 100644 index 0000000000..c01eab48dd --- /dev/null +++ b/timesketch/api/v1/resources/llm.py @@ -0,0 +1,239 @@ +# Copyright 2025 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Timesketch API endpoint for interacting with LLM features.""" +import logging +import multiprocessing +import multiprocessing.managers +import time + +import prometheus_client +from flask import request, abort, jsonify +from flask_login import login_required, current_user +from flask_restful import Resource + +from timesketch.api.v1 import resources +from timesketch.lib import definitions, utils +from timesketch.lib.definitions import METRICS_NAMESPACE +from timesketch.lib.llms.providers import manager as llm_manager +from timesketch.lib.llms.features import manager as feature_manager +from timesketch.models.sketch import Sketch + +logger = logging.getLogger("timesketch.api.llm") + + +class LLMResource(resources.ResourceMixin, Resource): + """Resource to interact with LLMs.""" + + METRICS = { + "llm_requests_total": prometheus_client.Counter( + "llm_requests_total", + "Total number of LLM requests received", + ["sketch_id", "feature"], + namespace=METRICS_NAMESPACE, + ), + "llm_errors_total": prometheus_client.Counter( + "llm_errors_total", + "Total number of errors during LLM processing", + ["sketch_id", "feature", "error_type"], + namespace=METRICS_NAMESPACE, + ), + "llm_duration_seconds": prometheus_client.Summary( + "llm_duration_seconds", + "Time taken to process an LLM request (in seconds)", + ["sketch_id", "feature"], + namespace=METRICS_NAMESPACE, + ), + } + + _LLM_TIMEOUT_WAIT_SECONDS = 30 + + @login_required + def post(self, sketch_id: int): + """Handles POST requests to the resource.""" + start_time = time.time() + sketch = self._validate_sketch(sketch_id) + form = self._validate_request_data() + feature = self._get_feature(form.get("feature")) + + self._increment_request_metric(sketch_id, feature.NAME) + + timeline_ids = self._validate_indices(sketch, form.get("filter", {})) + prompt = self._generate_prompt(feature, sketch, form, timeline_ids) + response = self._execute_llm_call(feature, prompt, sketch_id) + result = self._process_llm_response( + feature, response, sketch, form, timeline_ids + ) + + self._record_duration(sketch_id, feature.NAME, start_time) + return jsonify(result) + + def _validate_sketch(self, sketch_id: int) -> Sketch: + """Validates sketch existence and user permissions.""" + sketch = Sketch.get_with_acl(sketch_id) + if not sketch: + abort( + definitions.HTTP_STATUS_CODE_NOT_FOUND, "No sketch found with this ID." 
+ ) + if not sketch.has_permission(current_user, "read"): + abort( + definitions.HTTP_STATUS_CODE_FORBIDDEN, + "User does not have read access to the sketch.", + ) + return sketch + + def _validate_request_data(self) -> dict: + """Validates the presence of request JSON data.""" + form = request.json + if not form: + abort( + definitions.HTTP_STATUS_CODE_BAD_REQUEST, + "The POST request requires data", + ) + return form + + def _get_feature(self, feature_name: str) -> feature_manager.LLMFeatureInterface: + """Retrieves and validates the requested LLM feature.""" + if not feature_name: + abort( + definitions.HTTP_STATUS_CODE_BAD_REQUEST, + "The 'feature' parameter is required.", + ) + try: + return feature_manager.FeatureManager.get_feature_instance(feature_name) + except KeyError: + abort( + definitions.HTTP_STATUS_CODE_BAD_REQUEST, + f"Invalid LLM feature: {feature_name}", + ) + + def _validate_indices(self, sketch: Sketch, query_filter: dict) -> list: + """Extracts and validates timeline IDs from the query filter for a sketch.""" + all_indices = list({t.searchindex.index_name for t in sketch.timelines}) + indices = query_filter.get("indices", all_indices) + if "_all" in indices: + indices = all_indices + indices, timeline_ids = utils.get_validated_indices(indices, sketch) + if not indices: + abort( + definitions.HTTP_STATUS_CODE_BAD_REQUEST, + "No valid search indices were found.", + ) + return timeline_ids + + def _generate_prompt( + self, + feature: feature_manager.LLMFeatureInterface, + sketch: Sketch, + form: dict, + timeline_ids: list, + ) -> str: + """Generates the LLM prompt based on the feature and request data.""" + try: + return feature.generate_prompt( + sketch, form=form, datastore=self.datastore, timeline_ids=timeline_ids + ) + except ValueError as e: + abort(definitions.HTTP_STATUS_CODE_BAD_REQUEST, str(e)) + + def _execute_llm_call( + self, feature: feature_manager.LLMFeatureInterface, prompt: str, sketch_id: int + ) -> dict: + """Executes the LLM call with a timeout using multiprocessing.""" + with multiprocessing.Manager() as manager: + shared_response = manager.dict() + process = multiprocessing.Process( + target=self._get_content_with_timeout, + args=(feature, prompt, shared_response), + ) + process.start() + process.join(timeout=self._LLM_TIMEOUT_WAIT_SECONDS) + + if process.is_alive(): + logger.warning( + "LLM call timed out after %d seconds.", + self._LLM_TIMEOUT_WAIT_SECONDS, + ) + process.terminate() + process.join() + self.METRICS["llm_errors_total"].labels( + sketch_id=str(sketch_id), feature=feature.NAME, error_type="timeout" + ).inc() + abort(definitions.HTTP_STATUS_CODE_BAD_REQUEST, "LLM call timed out.") + + response = dict(shared_response) + if "error" in response: + self.METRICS["llm_errors_total"].labels( + sketch_id=str(sketch_id), + feature=feature.NAME, + error_type="llm_api_error", + ).inc() + abort( + definitions.HTTP_STATUS_CODE_INTERNAL_SERVER_ERROR, + "Error during LLM processing.", + ) + return response["response"] + + def _process_llm_response( + self, feature, response: dict, sketch: Sketch, form: dict, timeline_ids: list + ) -> dict: + """Processes the LLM response into the final result.""" + try: + return feature.process_response( + llm_response=response, + form=form, + sketch_id=sketch.id, + datastore=self.datastore, + sketch=sketch, + timeline_ids=timeline_ids, + ) + except ValueError as e: + self.METRICS["llm_errors_total"].labels( + sketch_id=str(sketch.id), + feature=feature.NAME, + error_type="response_processing", + ).inc() + 
abort(definitions.HTTP_STATUS_CODE_BAD_REQUEST, str(e)) + + def _increment_request_metric(self, sketch_id: int, feature_name: str) -> None: + """Increments the request counter metric.""" + self.METRICS["llm_requests_total"].labels( + sketch_id=str(sketch_id), feature=feature_name + ).inc() + + def _record_duration( + self, sketch_id: int, feature_name: str, start_time: float + ) -> None: + """Records the duration of the request.""" + duration = time.time() - start_time + self.METRICS["llm_duration_seconds"].labels( + sketch_id=str(sketch_id), feature=feature_name + ).observe(duration) + + def _get_content_with_timeout( + self, + feature: feature_manager.LLMFeatureInterface, + prompt: str, + shared_response: multiprocessing.managers.DictProxy, + ) -> None: + """Send a prompt to the LLM and get a response within a process.""" + try: + llm = llm_manager.LLMManager.create_provider(feature_name=feature.NAME) + response_schema = ( + feature.RESPONSE_SCHEMA if hasattr(feature, "RESPONSE_SCHEMA") else None + ) + response = llm.generate(prompt, response_schema=response_schema) + shared_response.update({"response": response}) + except Exception as e: + logger.error("Error in LLM call within process: %s", e, exc_info=True) + shared_response.update({"error": str(e)}) diff --git a/timesketch/api/v1/resources_test.py b/timesketch/api/v1/resources_test.py index 14dc3e8bfa..222641e45c 100644 --- a/timesketch/api/v1/resources_test.py +++ b/timesketch/api/v1/resources_test.py @@ -33,7 +33,6 @@ from timesketch.models.sketch import InvestigativeQuestion from timesketch.models.sketch import InvestigativeQuestionApproach from timesketch.models.sketch import Facet - from timesketch.api.v1.resources import ResourceMixin @@ -1692,3 +1691,133 @@ def test_llm_summarize_with_events(self, mock_create_provider): self.assertEqual(response.status_code, 200) response_data = json.loads(response.get_data(as_text=True)) self.assertEqual(response_data.get("summary"), "Mock summary from LLM") + + +@mock.patch("timesketch.api.v1.resources.OpenSearchDataStore", MockDataStore) +class LLMResourceTest(BaseTest): + """Test LLMResource.""" + + resource_url = "/api/v1/sketches/1/llm/" + + @mock.patch("timesketch.models.sketch.Sketch.get_with_acl") + @mock.patch( + "timesketch.lib.llms.features.manager.FeatureManager.get_feature_instance" + ) + @mock.patch("timesketch.lib.utils.get_validated_indices") + @mock.patch("timesketch.api.v1.resources.llm.LLMResource._execute_llm_call") + def test_post_success( + self, + mock_execute_llm, + mock_get_validated_indices, + mock_get_feature, + mock_get_with_acl, + ): + """Test a successful POST request to the LLM endpoint.""" + mock_sketch = mock.MagicMock() + mock_sketch.has_permission.return_value = True + mock_sketch.id = 1 + mock_get_with_acl.return_value = mock_sketch + + mock_feature = mock.MagicMock() + mock_feature.NAME = "test_feature" + mock_feature.generate_prompt.return_value = "test prompt" + mock_feature.process_response.return_value = {"result": "test result"} + mock_get_feature.return_value = mock_feature + + mock_get_validated_indices.return_value = (["index1"], [1]) + mock_execute_llm.return_value = {"response": "mock response"} + + self.login() + response = self.client.post( + self.resource_url, + data=json.dumps({"feature": "test_feature", "filter": {}}), + content_type="application/json", + ) + self.assertEqual(response.status_code, HTTP_STATUS_CODE_OK) + response_data = json.loads(response.get_data(as_text=True)) + self.assertEqual(response_data, {"result": "test result"}) + 
+ def test_post_missing_data(self): + """Test POST request with missing data.""" + self.login() + response = self.client.post( + self.resource_url, + data=json.dumps({"some_param": "some_value"}), + content_type="application/json", + ) + self.assertEqual(response.status_code, HTTP_STATUS_CODE_BAD_REQUEST) + response_data = json.loads(response.get_data(as_text=True)) + self.assertIn("The 'feature' parameter is required", response_data["message"]) + + @mock.patch("timesketch.models.sketch.Sketch.get_with_acl") + def test_post_missing_feature(self, mock_get_with_acl): + """Test POST request with no feature parameter.""" + mock_sketch = mock.MagicMock() + mock_sketch.has_permission.return_value = True + mock_get_with_acl.return_value = mock_sketch + + self.login() + response = self.client.post( + self.resource_url, + data=json.dumps({"filter": {}}), # No 'feature' key + content_type="application/json", + ) + self.assertEqual(response.status_code, HTTP_STATUS_CODE_BAD_REQUEST) + response_data = json.loads(response.get_data(as_text=True)) + self.assertIn("The 'feature' parameter is required", response_data["message"]) + + @mock.patch("timesketch.models.sketch.Sketch.get_with_acl") + def test_post_invalid_sketch(self, mock_get_with_acl): + """Test POST request with an invalid sketch ID.""" + mock_get_with_acl.return_value = None + + self.login() + response = self.client.post( + self.resource_url, + data=json.dumps({"feature": "test_feature", "filter": {}}), + content_type="application/json", + ) + self.assertEqual(response.status_code, HTTP_STATUS_CODE_NOT_FOUND) + response_data = json.loads(response.get_data(as_text=True)) + self.assertIn("No sketch found with this ID", response_data["message"]) + + @mock.patch("timesketch.models.sketch.Sketch.get_with_acl") + def test_post_no_permission(self, mock_get_with_acl): + """Test POST request when user lacks read permission.""" + mock_sketch = mock.MagicMock() + mock_sketch.has_permission.return_value = False + mock_get_with_acl.return_value = mock_sketch + + self.login() + response = self.client.post( + self.resource_url, + data=json.dumps({"feature": "test_feature", "filter": {}}), + content_type="application/json", + ) + self.assertEqual(response.status_code, HTTP_STATUS_CODE_FORBIDDEN) + response_data = json.loads(response.get_data(as_text=True)) + self.assertIn( + "User does not have read access to the sketch", response_data["message"] + ) + + @mock.patch("timesketch.models.sketch.Sketch.get_with_acl") + @mock.patch( + "timesketch.lib.llms.features.manager.FeatureManager.get_feature_instance" + ) + def test_post_invalid_feature(self, mock_get_feature, mock_get_with_acl): + """Test POST request with an invalid feature name.""" + mock_sketch = mock.MagicMock() + mock_sketch.has_permission.return_value = True + mock_get_with_acl.return_value = mock_sketch + + mock_get_feature.side_effect = KeyError("Invalid feature") + + self.login() + response = self.client.post( + self.resource_url, + data=json.dumps({"feature": "invalid_feature", "filter": {}}), + content_type="application/json", + ) + self.assertEqual(response.status_code, HTTP_STATUS_CODE_BAD_REQUEST) + response_data = json.loads(response.get_data(as_text=True)) + self.assertIn("Invalid LLM feature: invalid_feature", response_data["message"]) diff --git a/timesketch/api/v1/routes.py b/timesketch/api/v1/routes.py index 48ecf6f05a..5bc249ebc5 100644 --- a/timesketch/api/v1/routes.py +++ b/timesketch/api/v1/routes.py @@ -78,6 +78,7 @@ from .resources.unfurl import UnfurlResource from 
.resources.nl2q import Nl2qResource from .resources.llm_summarize import LLMSummarizeResource +from .resources.llm import LLMResource from .resources.settings import SystemSettingsResource from .resources.scenarios import ScenarioTemplateListResource @@ -204,6 +205,7 @@ (UnfurlResource, "/unfurl/"), (Nl2qResource, "/sketches//nl2q/"), (LLMSummarizeResource, "/sketches//events/summary/"), + (LLMResource, "/sketches//llm/"), (SystemSettingsResource, "/settings/"), # Scenario templates (ScenarioTemplateListResource, "/scenarios/"), diff --git a/timesketch/frontend-ng/src/utils/RestApiClient.js b/timesketch/frontend-ng/src/utils/RestApiClient.js index 36114aef1a..86416ebd33 100644 --- a/timesketch/frontend-ng/src/utils/RestApiClient.js +++ b/timesketch/frontend-ng/src/utils/RestApiClient.js @@ -528,4 +528,10 @@ export default { getEventSummary(sketchId, formData) { return RestApiClient.post('/sketches/' + sketchId + '/events/summary/', formData) }, + llmRequest(sketchId, featureName, formData) { + formData = formData || {} + formData.feature = featureName + + return RestApiClient.post(`/sketches/${sketchId}/llm/`, formData) + } } From f379b0e96b4fd2df6a2481ac1a55f385567fb232 Mon Sep 17 00:00:00 2001 From: itsmvd Date: Thu, 27 Feb 2025 13:26:13 +0000 Subject: [PATCH 18/22] linter fix --- timesketch/api/v1/resources/llm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timesketch/api/v1/resources/llm.py b/timesketch/api/v1/resources/llm.py index c01eab48dd..edaa3c9691 100644 --- a/timesketch/api/v1/resources/llm.py +++ b/timesketch/api/v1/resources/llm.py @@ -234,6 +234,6 @@ def _get_content_with_timeout( ) response = llm.generate(prompt, response_schema=response_schema) shared_response.update({"response": response}) - except Exception as e: + except Exception as e: # pylint: disable=broad-except logger.error("Error in LLM call within process: %s", e, exc_info=True) shared_response.update({"error": str(e)}) From 2e669d0619f893f54d04a94825a4c28992a1bfe9 Mon Sep 17 00:00:00 2001 From: itsmvd Date: Thu, 27 Feb 2025 13:28:10 +0000 Subject: [PATCH 19/22] linter fix --- timesketch/api/v1/resources/llm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timesketch/api/v1/resources/llm.py b/timesketch/api/v1/resources/llm.py index edaa3c9691..d9343144a7 100644 --- a/timesketch/api/v1/resources/llm.py +++ b/timesketch/api/v1/resources/llm.py @@ -234,6 +234,6 @@ def _get_content_with_timeout( ) response = llm.generate(prompt, response_schema=response_schema) shared_response.update({"response": response}) - except Exception as e: # pylint: disable=broad-except + except Exception as e: # pylint: disable=broad-except logger.error("Error in LLM call within process: %s", e, exc_info=True) shared_response.update({"error": str(e)}) From 1e58a282140dced2c56024340948f0203764024b Mon Sep 17 00:00:00 2001 From: itsmvd Date: Fri, 28 Feb 2025 10:47:38 +0000 Subject: [PATCH 20/22] Address comments from review --- timesketch/api/v1/resources/llm.py | 168 +++++++++++++++++++++++----- timesketch/api/v1/resources_test.py | 83 ++++++++++++++ 2 files changed, 226 insertions(+), 25 deletions(-) diff --git a/timesketch/api/v1/resources/llm.py b/timesketch/api/v1/resources/llm.py index d9343144a7..e776502aeb 100644 --- a/timesketch/api/v1/resources/llm.py +++ b/timesketch/api/v1/resources/llm.py @@ -16,12 +16,10 @@ import multiprocessing import multiprocessing.managers import time - import prometheus_client -from flask import request, abort, jsonify +from flask import request, abort, jsonify, 
Response from flask_login import login_required, current_user from flask_restful import Resource - from timesketch.api.v1 import resources from timesketch.lib import definitions, utils from timesketch.lib.definitions import METRICS_NAMESPACE @@ -33,7 +31,12 @@ class LLMResource(resources.ResourceMixin, Resource): - """Resource to interact with LLMs.""" + """Resource to interact with LLMs. + + This class provides an API endpoint for accessing and utilizing Large Language + Model features within Timesketch. It handles request validation, processing, + and response handling, while also monitoring performance metrics. + """ METRICS = { "llm_requests_total": prometheus_client.Counter( @@ -55,31 +58,52 @@ class LLMResource(resources.ResourceMixin, Resource): namespace=METRICS_NAMESPACE, ), } - + # TODO(itsmvd): Make this configurable _LLM_TIMEOUT_WAIT_SECONDS = 30 @login_required - def post(self, sketch_id: int): - """Handles POST requests to the resource.""" + def post(self, sketch_id: int) -> Response: + """Handles POST requests to the resource. + + Processes LLM requests, validates inputs, generates prompts, + executes LLM calls, and returns the processed results. + + Args: + sketch_id: The ID of the sketch to process. + + Returns: + A Flask JSON response containing the processed LLM result. + + Raises: + HTTP exceptions for various error conditions. + """ start_time = time.time() sketch = self._validate_sketch(sketch_id) form = self._validate_request_data() feature = self._get_feature(form.get("feature")) - self._increment_request_metric(sketch_id, feature.NAME) - timeline_ids = self._validate_indices(sketch, form.get("filter", {})) prompt = self._generate_prompt(feature, sketch, form, timeline_ids) response = self._execute_llm_call(feature, prompt, sketch_id) result = self._process_llm_response( feature, response, sketch, form, timeline_ids ) - self._record_duration(sketch_id, feature.NAME, start_time) return jsonify(result) def _validate_sketch(self, sketch_id: int) -> Sketch: - """Validates sketch existence and user permissions.""" + """Validates sketch existence and user permissions. + + Args: + sketch_id: The ID of the sketch to validate. + + Returns: + The validated Sketch object. + + Raises: + HTTP 404: If the sketch doesn't exist. + HTTP 403: If the user doesn't have read access to the sketch. + """ sketch = Sketch.get_with_acl(sketch_id) if not sketch: abort( @@ -93,7 +117,14 @@ def _validate_sketch(self, sketch_id: int) -> Sketch: return sketch def _validate_request_data(self) -> dict: - """Validates the presence of request JSON data.""" + """Validates the presence of request JSON data. + + Returns: + The validated request data as a dictionary. + + Raises: + HTTP 400: If no JSON data is provided in the request. + """ form = request.json if not form: abort( @@ -103,7 +134,17 @@ def _validate_request_data(self) -> dict: return form def _get_feature(self, feature_name: str) -> feature_manager.LLMFeatureInterface: - """Retrieves and validates the requested LLM feature.""" + """Retrieves and validates the requested LLM feature. + + Args: + feature_name: The name of the LLM feature to retrieve. + + Returns: + An instance of the requested LLM feature. + + Raises: + HTTP 400: If feature_name is not provided or is invalid. 
+ """ if not feature_name: abort( definitions.HTTP_STATUS_CODE_BAD_REQUEST, @@ -118,7 +159,18 @@ def _get_feature(self, feature_name: str) -> feature_manager.LLMFeatureInterface ) def _validate_indices(self, sketch: Sketch, query_filter: dict) -> list: - """Extracts and validates timeline IDs from the query filter for a sketch.""" + """Extracts and validates timeline IDs from the query filter for a sketch. + + Args: + sketch: The Sketch object to validate indices for. + query_filter: A dictionary containing filter parameters. + + Returns: + A list of validated timeline IDs. + + Raises: + HTTP 400: If no valid search indices are found. + """ all_indices = list({t.searchindex.index_name for t in sketch.timelines}) indices = query_filter.get("indices", all_indices) if "_all" in indices: @@ -138,7 +190,20 @@ def _generate_prompt( form: dict, timeline_ids: list, ) -> str: - """Generates the LLM prompt based on the feature and request data.""" + """Generates the LLM prompt based on the feature and request data. + + Args: + feature: The LLM feature instance to use. + sketch: The Sketch object. + form: The request form data. + timeline_ids: A list of validated timeline IDs. + + Returns: + The generated prompt string for the LLM. + + Raises: + HTTP 400: If prompt generation fails. + """ try: return feature.generate_prompt( sketch, form=form, datastore=self.datastore, timeline_ids=timeline_ids @@ -149,7 +214,20 @@ def _generate_prompt( def _execute_llm_call( self, feature: feature_manager.LLMFeatureInterface, prompt: str, sketch_id: int ) -> dict: - """Executes the LLM call with a timeout using multiprocessing.""" + """Executes the LLM call with a timeout using multiprocessing. + + Args: + feature: The LLM feature instance to use. + prompt: The generated prompt to send to the LLM. + sketch_id: The ID of the sketch being processed. + + Returns: + The LLM response as a dictionary. + + Raises: + HTTP 400: If the LLM call times out. + HTTP 500: If an error occurs during LLM processing. + """ with multiprocessing.Manager() as manager: shared_response = manager.dict() process = multiprocessing.Process( @@ -158,7 +236,6 @@ def _execute_llm_call( ) process.start() process.join(timeout=self._LLM_TIMEOUT_WAIT_SECONDS) - if process.is_alive(): logger.warning( "LLM call timed out after %d seconds.", @@ -169,8 +246,11 @@ def _execute_llm_call( self.METRICS["llm_errors_total"].labels( sketch_id=str(sketch_id), feature=feature.NAME, error_type="timeout" ).inc() - abort(definitions.HTTP_STATUS_CODE_BAD_REQUEST, "LLM call timed out.") - + abort( + definitions.HTTP_STATUS_CODE_BAD_REQUEST, + "LLM call timed out, please try again. " + "If this issue persists, contact your administrator.", + ) response = dict(shared_response) if "error" in response: self.METRICS["llm_errors_total"].labels( @@ -180,14 +260,33 @@ def _execute_llm_call( ).inc() abort( definitions.HTTP_STATUS_CODE_INTERNAL_SERVER_ERROR, - "Error during LLM processing.", + f"Error during LLM processing: {response['error']}", ) return response["response"] def _process_llm_response( - self, feature, response: dict, sketch: Sketch, form: dict, timeline_ids: list + self, + feature: feature_manager.LLMFeatureInterface, + response: dict, + sketch: Sketch, + form: dict, + timeline_ids: list, ) -> dict: - """Processes the LLM response into the final result.""" + """Processes the LLM response into the final result. + + Args: + feature: The LLM feature instance used. + response: The raw LLM response. + sketch: The Sketch object. + form: The request form data. 
+ timeline_ids: A list of validated timeline IDs. + + Returns: + The processed LLM response as a dictionary. + + Raises: + HTTP 400: If response processing fails. + """ try: return feature.process_response( llm_response=response, @@ -206,7 +305,12 @@ def _process_llm_response( abort(definitions.HTTP_STATUS_CODE_BAD_REQUEST, str(e)) def _increment_request_metric(self, sketch_id: int, feature_name: str) -> None: - """Increments the request counter metric.""" + """Increments the request counter metric. + + Args: + sketch_id: The ID of the sketch being processed. + feature_name: The name of the LLM feature being used. + """ self.METRICS["llm_requests_total"].labels( sketch_id=str(sketch_id), feature=feature_name ).inc() @@ -214,7 +318,13 @@ def _increment_request_metric(self, sketch_id: int, feature_name: str) -> None: def _record_duration( self, sketch_id: int, feature_name: str, start_time: float ) -> None: - """Records the duration of the request.""" + """Records the duration of the request. + + Args: + sketch_id: The ID of the sketch being processed. + feature_name: The name of the LLM feature being used. + start_time: The timestamp when the request started. + """ duration = time.time() - start_time self.METRICS["llm_duration_seconds"].labels( sketch_id=str(sketch_id), feature=feature_name @@ -226,7 +336,15 @@ def _get_content_with_timeout( prompt: str, shared_response: multiprocessing.managers.DictProxy, ) -> None: - """Send a prompt to the LLM and get a response within a process.""" + """Send a prompt to the LLM and get a response within a process. + + This method is executed in a separate process to allow for timeout control. + + Args: + feature: The LLM feature instance to use. + prompt: The generated prompt to send to the LLM. + shared_response: A managed dictionary to store the response or error. 
+ """ try: llm = llm_manager.LLMManager.create_provider(feature_name=feature.NAME) response_schema = ( diff --git a/timesketch/api/v1/resources_test.py b/timesketch/api/v1/resources_test.py index 222641e45c..64f3b8f045 100644 --- a/timesketch/api/v1/resources_test.py +++ b/timesketch/api/v1/resources_test.py @@ -1821,3 +1821,86 @@ def test_post_invalid_feature(self, mock_get_feature, mock_get_with_acl): self.assertEqual(response.status_code, HTTP_STATUS_CODE_BAD_REQUEST) response_data = json.loads(response.get_data(as_text=True)) self.assertIn("Invalid LLM feature: invalid_feature", response_data["message"]) + + @mock.patch("timesketch.models.sketch.Sketch.get_with_acl") + @mock.patch( + "timesketch.lib.llms.features.manager.FeatureManager.get_feature_instance" + ) + @mock.patch("timesketch.lib.utils.get_validated_indices") + def test_post_prompt_generation_error( + self, + mock_get_validated_indices, + mock_get_feature, + mock_get_with_acl, + ): + """Test handling of errors during prompt generation.""" + mock_sketch = mock.MagicMock() + mock_sketch.has_permission.return_value = True + mock_sketch.id = 1 + mock_get_with_acl.return_value = mock_sketch + + mock_feature = mock.MagicMock() + mock_feature.NAME = "test_feature" + mock_feature.generate_prompt.side_effect = ValueError( + "Prompt generation failed" + ) + mock_get_feature.return_value = mock_feature + + mock_get_validated_indices.return_value = (["index1"], [1]) + + self.login() + response = self.client.post( + self.resource_url, + data=json.dumps({"feature": "test_feature", "filter": {}}), + content_type="application/json", + ) + + self.assertEqual(response.status_code, HTTP_STATUS_CODE_BAD_REQUEST) + response_data = json.loads(response.get_data(as_text=True)) + self.assertIn("Prompt generation failed", response_data["message"]) + + mock_feature.generate_prompt.assert_called_once() + + @mock.patch("timesketch.models.sketch.Sketch.get_with_acl") + @mock.patch( + "timesketch.lib.llms.features.manager.FeatureManager.get_feature_instance" + ) + @mock.patch("timesketch.lib.utils.get_validated_indices") + @mock.patch("multiprocessing.Process") + def test_post_llm_execution_timeout( + self, + mock_process, + mock_get_validated_indices, + mock_get_feature, + mock_get_with_acl, + ): + """Test handling of LLM execution timeouts.""" + # Setup mocks + mock_sketch = mock.MagicMock() + mock_sketch.has_permission.return_value = True + mock_sketch.id = 1 + mock_get_with_acl.return_value = mock_sketch + + mock_feature = mock.MagicMock() + mock_feature.NAME = "test_feature" + mock_feature.generate_prompt.return_value = "test prompt" + mock_get_feature.return_value = mock_feature + + mock_get_validated_indices.return_value = (["index1"], [1]) + + process_instance = mock.MagicMock() + process_instance.is_alive.return_value = True + mock_process.return_value = process_instance + + self.login() + response = self.client.post( + self.resource_url, + data=json.dumps({"feature": "test_feature", "filter": {}}), + content_type="application/json", + ) + + self.assertEqual(response.status_code, HTTP_STATUS_CODE_BAD_REQUEST) + response_data = json.loads(response.get_data(as_text=True)) + self.assertIn("LLM call timed out", response_data["message"]) + + process_instance.terminate.assert_called_once() From 76146e2893f6bdb5d1f46d52ff52607f9652bb50 Mon Sep 17 00:00:00 2001 From: Maarten van Dantzig Date: Tue, 11 Mar 2025 13:40:06 +0100 Subject: [PATCH 21/22] add timeout option to snackbar --- timesketch/frontend-ng/src/mixins/snackBar.js | 59 +++++++++---------- 
timesketch/frontend-v3/src/mixins.js | 52 ++++++++-------- 2 files changed, 54 insertions(+), 57 deletions(-) diff --git a/timesketch/frontend-ng/src/mixins/snackBar.js b/timesketch/frontend-ng/src/mixins/snackBar.js index 97ac9bc918..211d2b0472 100644 --- a/timesketch/frontend-ng/src/mixins/snackBar.js +++ b/timesketch/frontend-ng/src/mixins/snackBar.js @@ -1,13 +1,9 @@ - /* Copyright 2022 Google Inc. All rights reserved. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -15,41 +11,42 @@ See the License for the specific language governing permissions and limitations under the License. */ import Vue from 'vue' - const defaultTimeout = 5000 -const defaultSnackBar = { - "message": "", - "color": "info", - "timeout": defaultTimeout -} -// These methids will be available to all components without any further imports. +// These methods will be available to all components without any further imports. Vue.mixin({ methods: { - successSnackBar(message) { - let snackbar = defaultSnackBar - snackbar.message = message - snackbar.color = "success" + successSnackBar(message, timeout = defaultTimeout) { + const snackbar = { + message: message, + color: "success", + timeout: timeout + } + this.$store.dispatch('setSnackBar', snackbar) + }, + errorSnackBar(message, timeout = defaultTimeout) { + const snackbar = { + message: message, + color: "error", + timeout: timeout + } this.$store.dispatch('setSnackBar', snackbar) }, - errorSnackBar(message) { - let snackbar = defaultSnackBar - snackbar.message = message - snackbar.color = "error" + warningSnackBar(message, timeout = defaultTimeout) { + const snackbar = { + message: message, + color: "warning", + timeout: timeout + } this.$store.dispatch('setSnackBar', snackbar) }, - warningSnackBar(message) { - let snackbar = defaultSnackBar - snackbar.message = message - snackbar.color = "warning" - this.$store.dispatch('setSnackBar', snackbar) + infoSnackBar(message, timeout = 2000) { + const snackbar = { + message: message, + color: "info", + timeout: timeout + } + this.$store.dispatch('setSnackBar', snackbar) }, - infoSnackBar(message) { - let snackbar = defaultSnackBar - snackbar.message = message - snackbar.color = "info" - snackbar.timeout = 2000 - this.$store.dispatch('setSnackBar', snackbar) - }, } }) diff --git a/timesketch/frontend-v3/src/mixins.js b/timesketch/frontend-v3/src/mixins.js index f1f0d31328..9fb8a074f0 100644 --- a/timesketch/frontend-v3/src/mixins.js +++ b/timesketch/frontend-v3/src/mixins.js @@ -1,39 +1,39 @@ - const defaultTimeout = 5000 -const defaultSnackBar = { - "message": "", - "color": "info", - "timeout": defaultTimeout -} -// These methods will be available to all components without any further imports. 
export const snackBarMixin = { methods: { - successSnackBar(message) { - let snackbar = defaultSnackBar - snackbar.message = message - snackbar.color = "success" + successSnackBar(message, timeout = defaultTimeout) { + const snackbar = { + message: message, + color: "success", + timeout: timeout + } console.log('success snack bar', message) this.appStore.setSnackBar(snackbar) }, - errorSnackBar(message) { - let snackbar = defaultSnackBar - snackbar.message = message - snackbar.color = "error" + errorSnackBar(message, timeout = defaultTimeout) { + const snackbar = { + message: message, + color: "error", + timeout: timeout + } this.appStore.setSnackBar(snackbar) }, - warningSnackBar(message) { - let snackbar = defaultSnackBar - snackbar.message = message - snackbar.color = "warning" - this.appStore.setSnackBar(snackbar) + warningSnackBar(message, timeout = defaultTimeout) { + const snackbar = { + message: message, + color: "warning", + timeout: timeout + } + this.appStore.setSnackBar(snackbar) }, - infoSnackBar(message) { - let snackbar = defaultSnackBar - snackbar.message = message - snackbar.color = "info" - snackbar.timeout = 2000 - this.appStore.setSnackBar(snackbar) + infoSnackBar(message, timeout = 2000) { + const snackbar = { + message: message, + color: "info", + timeout: timeout + } + this.appStore.setSnackBar(snackbar) }, } } From 3bc0e2c1e3fa4401de7499b0bd204b3f063632bd Mon Sep 17 00:00:00 2001 From: Maarten van Dantzig Date: Tue, 11 Mar 2025 13:52:14 +0100 Subject: [PATCH 22/22] add back doc-string in v3 snackbar --- timesketch/frontend-v3/src/mixins.js | 1 + 1 file changed, 1 insertion(+) diff --git a/timesketch/frontend-v3/src/mixins.js b/timesketch/frontend-v3/src/mixins.js index 9fb8a074f0..5a1c8b53ca 100644 --- a/timesketch/frontend-v3/src/mixins.js +++ b/timesketch/frontend-v3/src/mixins.js @@ -1,5 +1,6 @@ const defaultTimeout = 5000 +// These methods will be available to all components without any further imports. export const snackBarMixin = { methods: { successSnackBar(message, timeout = defaultTimeout) {
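
Note for reviewers: a sketch of how a client can exercise the new endpoint outside the
Vue frontend, mirroring the llmRequest helper added to RestApiClient.js. The function
name call_llm_feature, the host URL, and the sketch ID are hypothetical; the
requests.Session is assumed to already carry a valid login session (cookie and CSRF
handling omitted). The URL and payload shape follow routes.py and LLMResource above.

import requests


def call_llm_feature(
    session: requests.Session, host: str, sketch_id: int, feature_name: str
) -> dict:
    """POST to /api/v1/sketches/<sketch_id>/llm/ and return the JSON result."""
    payload = {
        "feature": feature_name,  # required, validated by _get_feature()
        "filter": {"indices": ["_all"]},  # "_all" expands to every timeline
    }
    response = session.post(
        f"{host}/api/v1/sketches/{sketch_id}/llm/", json=payload, timeout=60
    )
    response.raise_for_status()
    return response.json()


# Example (hypothetical host and sketch):
# result = call_llm_feature(session, "https://timesketch.example.com", 1, "nl2q")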
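
Likewise, the timeout handling introduced in LLMResource._execute_llm_call and
_get_content_with_timeout reduces to a small reusable pattern: run the blocking call in
a child process that writes into a Manager-backed dict, join with a timeout, and
terminate the child if it is still alive. A minimal self-contained sketch of that
pattern follows; the names _worker and call_with_timeout are illustrative, and
time.sleep stands in for the real llm.generate() call.

import multiprocessing
import time


def _worker(shared: dict) -> None:
    # Stand-in for the LLM call; the child records either a response or an
    # error in the Manager-backed dict so the parent can read it after join().
    try:
        time.sleep(2)  # pretend this is llm.generate(prompt, ...)
        shared.update({"response": "ok"})
    except Exception as error:  # pylint: disable=broad-except
        shared.update({"error": str(error)})


def call_with_timeout(timeout_seconds: float) -> dict:
    with multiprocessing.Manager() as manager:
        shared = manager.dict()
        process = multiprocessing.Process(target=_worker, args=(shared,))
        process.start()
        process.join(timeout=timeout_seconds)
        if process.is_alive():
            # Same recovery as the resource: kill the straggler, then report.
            process.terminate()
            process.join()
            raise TimeoutError(f"call exceeded {timeout_seconds} seconds")
        return dict(shared)


if __name__ == "__main__":
    print(call_with_timeout(5))  # {'response': 'ok'}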