
Commit efdb2ff

Introduce LLMResource API method, tests, and add it as a method for the frontend (#3310)
1 parent cb95144 commit efdb2ff

File tree

4 files changed: +578, -1 lines changed


timesketch/api/v1/resources/llm.py

+357
@@ -0,0 +1,357 @@
# Copyright 2025 Google Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Timesketch API endpoint for interacting with LLM features."""
import logging
import multiprocessing
import multiprocessing.managers
import time
import prometheus_client
from flask import request, abort, jsonify, Response
from flask_login import login_required, current_user
from flask_restful import Resource
from timesketch.api.v1 import resources
from timesketch.lib import definitions, utils
from timesketch.lib.definitions import METRICS_NAMESPACE
from timesketch.lib.llms.providers import manager as llm_manager
from timesketch.lib.llms.features import manager as feature_manager
from timesketch.models.sketch import Sketch

logger = logging.getLogger("timesketch.api.llm")


class LLMResource(resources.ResourceMixin, Resource):
    """Resource to interact with LLMs.

    This class provides an API endpoint for accessing and utilizing Large Language
    Model features within Timesketch. It handles request validation, processing,
    and response handling, while also monitoring performance metrics.
    """

    METRICS = {
        "llm_requests_total": prometheus_client.Counter(
            "llm_requests_total",
            "Total number of LLM requests received",
            ["sketch_id", "feature"],
            namespace=METRICS_NAMESPACE,
        ),
        "llm_errors_total": prometheus_client.Counter(
            "llm_errors_total",
            "Total number of errors during LLM processing",
            ["sketch_id", "feature", "error_type"],
            namespace=METRICS_NAMESPACE,
        ),
        "llm_duration_seconds": prometheus_client.Summary(
            "llm_duration_seconds",
            "Time taken to process an LLM request (in seconds)",
            ["sketch_id", "feature"],
            namespace=METRICS_NAMESPACE,
        ),
    }
    # TODO(itsmvd): Make this configurable
    _LLM_TIMEOUT_WAIT_SECONDS = 30

    @login_required
    def post(self, sketch_id: int) -> Response:
        """Handles POST requests to the resource.

        Processes LLM requests, validates inputs, generates prompts,
        executes LLM calls, and returns the processed results.

        Args:
            sketch_id: The ID of the sketch to process.

        Returns:
            A Flask JSON response containing the processed LLM result.

        Raises:
            HTTP exceptions for various error conditions.
        """
        start_time = time.time()
        sketch = self._validate_sketch(sketch_id)
        form = self._validate_request_data()
        feature = self._get_feature(form.get("feature"))
        self._increment_request_metric(sketch_id, feature.NAME)
        timeline_ids = self._validate_indices(sketch, form.get("filter", {}))
        prompt = self._generate_prompt(feature, sketch, form, timeline_ids)
        response = self._execute_llm_call(feature, prompt, sketch_id)
        result = self._process_llm_response(
            feature, response, sketch, form, timeline_ids
        )
        self._record_duration(sketch_id, feature.NAME, start_time)
        return jsonify(result)

    def _validate_sketch(self, sketch_id: int) -> Sketch:
        """Validates sketch existence and user permissions.

        Args:
            sketch_id: The ID of the sketch to validate.

        Returns:
            The validated Sketch object.

        Raises:
            HTTP 404: If the sketch doesn't exist.
            HTTP 403: If the user doesn't have read access to the sketch.
        """
        sketch = Sketch.get_with_acl(sketch_id)
        if not sketch:
            abort(
                definitions.HTTP_STATUS_CODE_NOT_FOUND, "No sketch found with this ID."
            )
        if not sketch.has_permission(current_user, "read"):
            abort(
                definitions.HTTP_STATUS_CODE_FORBIDDEN,
                "User does not have read access to the sketch.",
            )
        return sketch

    def _validate_request_data(self) -> dict:
        """Validates the presence of request JSON data.

        Returns:
            The validated request data as a dictionary.

        Raises:
            HTTP 400: If no JSON data is provided in the request.
        """
        form = request.json
        if not form:
            abort(
                definitions.HTTP_STATUS_CODE_BAD_REQUEST,
                "The POST request requires data",
            )
        return form

    def _get_feature(self, feature_name: str) -> feature_manager.LLMFeatureInterface:
        """Retrieves and validates the requested LLM feature.

        Args:
            feature_name: The name of the LLM feature to retrieve.

        Returns:
            An instance of the requested LLM feature.

        Raises:
            HTTP 400: If feature_name is not provided or is invalid.
        """
        if not feature_name:
            abort(
                definitions.HTTP_STATUS_CODE_BAD_REQUEST,
                "The 'feature' parameter is required.",
            )
        try:
            return feature_manager.FeatureManager.get_feature_instance(feature_name)
        except KeyError:
            abort(
                definitions.HTTP_STATUS_CODE_BAD_REQUEST,
                f"Invalid LLM feature: {feature_name}",
            )

    def _validate_indices(self, sketch: Sketch, query_filter: dict) -> list:
        """Extracts and validates timeline IDs from the query filter for a sketch.

        Args:
            sketch: The Sketch object to validate indices for.
            query_filter: A dictionary containing filter parameters.

        Returns:
            A list of validated timeline IDs.

        Raises:
            HTTP 400: If no valid search indices are found.
        """
        all_indices = list({t.searchindex.index_name for t in sketch.timelines})
        indices = query_filter.get("indices", all_indices)
        if "_all" in indices:
            indices = all_indices
        indices, timeline_ids = utils.get_validated_indices(indices, sketch)
        if not indices:
            abort(
                definitions.HTTP_STATUS_CODE_BAD_REQUEST,
                "No valid search indices were found.",
            )
        return timeline_ids

    def _generate_prompt(
        self,
        feature: feature_manager.LLMFeatureInterface,
        sketch: Sketch,
        form: dict,
        timeline_ids: list,
    ) -> str:
        """Generates the LLM prompt based on the feature and request data.

        Args:
            feature: The LLM feature instance to use.
            sketch: The Sketch object.
            form: The request form data.
            timeline_ids: A list of validated timeline IDs.

        Returns:
            The generated prompt string for the LLM.

        Raises:
            HTTP 400: If prompt generation fails.
        """
        try:
            return feature.generate_prompt(
                sketch, form=form, datastore=self.datastore, timeline_ids=timeline_ids
            )
        except ValueError as e:
            abort(definitions.HTTP_STATUS_CODE_BAD_REQUEST, str(e))

    def _execute_llm_call(
        self, feature: feature_manager.LLMFeatureInterface, prompt: str, sketch_id: int
    ) -> dict:
        """Executes the LLM call with a timeout using multiprocessing.

        Args:
            feature: The LLM feature instance to use.
            prompt: The generated prompt to send to the LLM.
            sketch_id: The ID of the sketch being processed.

        Returns:
            The LLM response as a dictionary.

        Raises:
            HTTP 400: If the LLM call times out.
            HTTP 500: If an error occurs during LLM processing.
        """
        with multiprocessing.Manager() as manager:
            shared_response = manager.dict()
            process = multiprocessing.Process(
                target=self._get_content_with_timeout,
                args=(feature, prompt, shared_response),
            )
            process.start()
            process.join(timeout=self._LLM_TIMEOUT_WAIT_SECONDS)
            if process.is_alive():
                logger.warning(
                    "LLM call timed out after %d seconds.",
                    self._LLM_TIMEOUT_WAIT_SECONDS,
                )
                process.terminate()
                process.join()
                self.METRICS["llm_errors_total"].labels(
                    sketch_id=str(sketch_id), feature=feature.NAME, error_type="timeout"
                ).inc()
                abort(
                    definitions.HTTP_STATUS_CODE_BAD_REQUEST,
                    "LLM call timed out, please try again. "
                    "If this issue persists, contact your administrator.",
                )
            response = dict(shared_response)
            if "error" in response:
                self.METRICS["llm_errors_total"].labels(
                    sketch_id=str(sketch_id),
                    feature=feature.NAME,
                    error_type="llm_api_error",
                ).inc()
                abort(
                    definitions.HTTP_STATUS_CODE_INTERNAL_SERVER_ERROR,
                    f"Error during LLM processing: {response['error']}",
                )
            return response["response"]

    def _process_llm_response(
        self,
        feature: feature_manager.LLMFeatureInterface,
        response: dict,
        sketch: Sketch,
        form: dict,
        timeline_ids: list,
    ) -> dict:
        """Processes the LLM response into the final result.

        Args:
            feature: The LLM feature instance used.
            response: The raw LLM response.
            sketch: The Sketch object.
            form: The request form data.
            timeline_ids: A list of validated timeline IDs.

        Returns:
            The processed LLM response as a dictionary.

        Raises:
            HTTP 400: If response processing fails.
        """
        try:
            return feature.process_response(
                llm_response=response,
                form=form,
                sketch_id=sketch.id,
                datastore=self.datastore,
                sketch=sketch,
                timeline_ids=timeline_ids,
            )
        except ValueError as e:
            self.METRICS["llm_errors_total"].labels(
                sketch_id=str(sketch.id),
                feature=feature.NAME,
                error_type="response_processing",
            ).inc()
            abort(definitions.HTTP_STATUS_CODE_BAD_REQUEST, str(e))

    def _increment_request_metric(self, sketch_id: int, feature_name: str) -> None:
        """Increments the request counter metric.

        Args:
            sketch_id: The ID of the sketch being processed.
            feature_name: The name of the LLM feature being used.
        """
        self.METRICS["llm_requests_total"].labels(
            sketch_id=str(sketch_id), feature=feature_name
        ).inc()

    def _record_duration(
        self, sketch_id: int, feature_name: str, start_time: float
    ) -> None:
        """Records the duration of the request.

        Args:
            sketch_id: The ID of the sketch being processed.
            feature_name: The name of the LLM feature being used.
            start_time: The timestamp when the request started.
        """
        duration = time.time() - start_time
        self.METRICS["llm_duration_seconds"].labels(
            sketch_id=str(sketch_id), feature=feature_name
        ).observe(duration)

    def _get_content_with_timeout(
        self,
        feature: feature_manager.LLMFeatureInterface,
        prompt: str,
        shared_response: multiprocessing.managers.DictProxy,
    ) -> None:
        """Send a prompt to the LLM and get a response within a process.

        This method is executed in a separate process to allow for timeout control.

        Args:
            feature: The LLM feature instance to use.
            prompt: The generated prompt to send to the LLM.
            shared_response: A managed dictionary to store the response or error.
        """
        try:
            llm = llm_manager.LLMManager.create_provider(feature_name=feature.NAME)
            response_schema = (
                feature.RESPONSE_SCHEMA if hasattr(feature, "RESPONSE_SCHEMA") else None
            )
            response = llm.generate(prompt, response_schema=response_schema)
            shared_response.update({"response": response})
        except Exception as e:  # pylint: disable=broad-except
            logger.error("Error in LLM call within process: %s", e, exc_info=True)
            shared_response.update({"error": str(e)})
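
For orientation, below is a minimal client-side sketch of how this new endpoint could be exercised. The route path (/api/v1/sketches/<sketch_id>/llm/), the feature name (llm_summarize), the server URL, and the authentication handling are all assumptions for illustration; none of them are shown in this diff.

import requests

# Hypothetical server URL and sketch ID; adjust for a real deployment.
SERVER = "https://timesketch.example.com"
SKETCH_ID = 1

# Assumes a requests.Session that is already authenticated against the
# Timesketch server (cookie/CSRF handling omitted for brevity).
session = requests.Session()

payload = {
    # Assumed feature name; it must be registered with the FeatureManager,
    # or the endpoint aborts with HTTP 400 ("Invalid LLM feature: ...").
    "feature": "llm_summarize",
    # "_all" expands to every timeline in the sketch (see _validate_indices).
    "filter": {"indices": ["_all"]},
}
resp = session.post(f"{SERVER}/api/v1/sketches/{SKETCH_ID}/llm/", json=payload)

if resp.status_code == 200:
    # The JSON body is whatever feature.process_response() returned.
    print(resp.json())
else:
    # 400 covers a missing/invalid "feature", no valid indices, and timeouts;
    # 500 surfaces provider errors raised inside the worker process.
    print(resp.status_code, resp.text)

A note on the design: the worker process in _execute_llm_call exists because a blocked provider call cannot be forcibly cancelled from a thread in Python, whereas process.terminate() can kill a child process, giving every request a hard upper bound of _LLM_TIMEOUT_WAIT_SECONDS seconds.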
