Skip to content

Commit c5ac7ba

Browse files
Update opensearch.py
Attempt to fix the following bug: "Enable scrolling in search API endpoint" (google#3334)
1 parent 16940e4 commit c5ac7ba

File tree

1 file changed

+74
-71
lines changed

1 file changed

+74
-71
lines changed

timesketch/lib/datastores/opensearch.py

Lines changed: 74 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -633,77 +633,80 @@ def search(
633633

634634
# pylint: disable=too-many-arguments
635635

636-
def search_stream(
637-
self,
638-
sketch_id: Optional[int] = None,
639-
query_string: Optional[str] = None,
640-
query_filter: Optional[Dict] = None,
641-
query_dsl: Optional[Dict] = None,
642-
indices: Optional[list] = None,
643-
return_fields: Optional[list] = None,
644-
enable_scroll: bool = True,
645-
timeline_ids: Optional[list] = None,
646-
):
647-
"""Search OpenSearch. This will take a query string from the UI
648-
together with a filter definition. Based on this it will execute the
649-
search request on OpenSearch and get result back.
650-
651-
Args :
652-
sketch_id: Integer of sketch primary key
653-
query_string: Query string
654-
query_filter: Dictionary containing filters to apply
655-
query_dsl: Dictionary containing OpenSearch DSL query
656-
indices: List of indices to query
657-
return_fields: List of fields to return
658-
enable_scroll: Boolean determining whether scrolling is enabled.
659-
timeline_ids: Optional list of IDs of Timeline objects that should
660-
be queried as part of the search.
661-
662-
Yields:
663-
Generator of event documents in JSON format
664-
"""
665-
# Make sure that the list of index names is unique.
# NOTE(review): indices defaults to None and set(None) raises TypeError,
# so this line requires the caller to pass a list.
666-
indices = list(set(indices))
667-
668-
METRICS["search_requests"].labels(type="stream").inc()
669-
670-
# NOTE(review): query_filter defaults to None but is dereferenced here.
# The defaults below are written into the dictionary the caller passed in.
if not query_filter.get("size"):
671-
query_filter["size"] = self.DEFAULT_STREAM_LIMIT
672-
673-
if not query_filter.get("terminate_after"):
674-
query_filter["terminate_after"] = self.DEFAULT_STREAM_LIMIT
675-
676-
# Run the initial search; its hits are the first batch that is yielded.
result = self.search(
677-
sketch_id=sketch_id,
678-
query_string=query_string,
679-
query_dsl=query_dsl,
680-
query_filter=query_filter,
681-
indices=indices,
682-
return_fields=return_fields,
683-
enable_scroll=enable_scroll,
684-
timeline_ids=timeline_ids,
685-
)
686-
687-
# Without scrolling, scroll_size is forced to 0 so the scroll loop at
# the end never runs and only the initial batch is yielded.
if enable_scroll:
688-
scroll_id = result["_scroll_id"]
689-
scroll_size = result["hits"]["total"]
690-
else:
691-
scroll_id = None
692-
scroll_size = 0
693-
694-
# Elasticsearch version 7.x returns total hits as a dictionary.
695-
# TODO: Refactor when version 6.x has been deprecated.
696-
if isinstance(scroll_size, dict):
697-
scroll_size = scroll_size.get("value", 0)
698-
699-
yield from result["hits"]["hits"]
700-
701-
# Keep requesting pages until one comes back with no hits: scroll_size
# is reset to the length of each page that is fetched.
while scroll_size > 0:
702-
# pylint: disable=unexpected-keyword-arg
703-
result = self.client.scroll(scroll_id=scroll_id, scroll="5m")
704-
scroll_id = result["_scroll_id"]
705-
scroll_size = len(result["hits"]["hits"])
706-
yield from result["hits"]["hits"]
636+
def search_stream(
    self,
    sketch_id: Optional[int] = None,
    query_string: Optional[str] = None,
    query_filter: Optional[Dict] = None,
    query_dsl: Optional[Dict] = None,
    indices: Optional[list] = None,
    return_fields: Optional[list] = None,
    enable_scroll: bool = True,
    timeline_ids: Optional[list] = None,
):
    """Search OpenSearch and stream the matching events.

    This will take a query string from the UI together with a filter
    definition. Based on this it will execute the search request on
    OpenSearch and yield the hits of the initial response. When scrolling
    is enabled it then keeps fetching pages through the scroll cursor until
    a page comes back empty.

    Args:
        sketch_id: Integer of sketch primary key
        query_string: Query string
        query_filter: Dictionary containing filters to apply. Missing or
            falsy "size" and "terminate_after" entries are set to
            DEFAULT_STREAM_LIMIT on the dictionary that was passed in.
        query_dsl: Dictionary containing OpenSearch DSL query
        indices: List of indices to query
        return_fields: List of fields to return
        enable_scroll: Boolean determining whether scrolling is enabled.
            When False only the initial batch of results is yielded.
        timeline_ids: Optional list of IDs of Timeline objects that should
            be queried as part of the search.

    Yields:
        Event documents taken from the "hits" list of each response.
    """
    # Make sure that the list of index names is unique. The parameter
    # defaults to None, which set() cannot iterate.
    indices = list(set(indices)) if indices else []

    METRICS["search_requests"].labels(type="stream").inc()

    # The parameter defaults to None; start from an empty filter in that
    # case so the stream defaults below can still be applied.
    if query_filter is None:
        query_filter = {}

    if not query_filter.get("size"):
        query_filter["size"] = self.DEFAULT_STREAM_LIMIT

    if not query_filter.get("terminate_after"):
        query_filter["terminate_after"] = self.DEFAULT_STREAM_LIMIT

    # Perform the initial search.
    result = self.search(
        sketch_id=sketch_id,
        query_string=query_string,
        query_dsl=query_dsl,
        query_filter=query_filter,
        indices=indices,
        return_fields=return_fields,
        enable_scroll=enable_scroll,
        timeline_ids=timeline_ids,
    )

    # Only follow the scroll cursor when scrolling was requested AND the
    # response actually carries one. Otherwise scroll_size stays 0 so the
    # loop below never calls the scroll API with a missing scroll id.
    scroll_id = result.get("_scroll_id") if enable_scroll else None
    scroll_size = result["hits"]["total"] if scroll_id is not None else 0

    # Elasticsearch version 7.x returns total hits as a dictionary.
    if isinstance(scroll_size, dict):
        scroll_size = scroll_size.get("value", 0)

    # Yield the initial batch of results.
    yield from result["hits"]["hits"]

    # Continue scrolling until a page comes back without any hits.
    while scroll_size > 0:
        # pylint: disable=unexpected-keyword-arg
        result = self.client.scroll(scroll_id=scroll_id, scroll="5m")
        scroll_id = result["_scroll_id"]
        scroll_size = len(result["hits"]["hits"])
        yield from result["hits"]["hits"]
707710

708711
def get_filter_labels(self, sketch_id: int, indices: list):
709712
"""Aggregate labels for a sketch.

0 commit comments

Comments
 (0)