Skip to content

Commit 7757e69

Browse files
[FSTORE-1656][4.2] Fetching feature vectors in the predictor file of a deployment fails with event loop is already running error (logicalclocks#498)
* add thredpoolexecutor for running loops * removing prints and reverting get_running_loop changes * adding back set event loop * moving creation of thread pool to the init function * adding connection options for timeout * making the number of threads in the thread pool configurable * creating tasks using a event loop * spawning threads instead of a thread pool executor to not not block calls of get_feature_vector from differnt threads * creating a single thread to run all async tasks * updating comments * moving the connection pool variable into the Async thread --------- Co-authored-by: DhananjayMukhedkar <dhananjay1098@gmail.com>
1 parent 694fc17 commit 7757e69

File tree

2 files changed

+236
-40
lines changed

2 files changed

+236
-40
lines changed

python/hopsworks_common/util.py

+190
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,12 @@
1616

1717
from __future__ import annotations
1818

19+
import asyncio
1920
import inspect
2021
import itertools
2122
import json
2223
import os
24+
import queue
2325
import re
2426
import shutil
2527
import sys
@@ -713,3 +715,191 @@ def feature_view_to_json(obj):
713715

714716
return humps.camelize(json.loads(obj.json()))
715717
return None
718+
719+
720+
class AsyncTask:
721+
"""
722+
Generic class to represent an async task.
723+
724+
Args:
725+
func (Callable): The function to run asynchronously.
726+
requires_connection_pool (bool): Whether the task requires a connection pool.
727+
**kwargs: Key word arguments to be passed to the functions.
728+
729+
Properties:
730+
result (Any): The result of the async task.
731+
event (threading.Event): The event that will be set when the async task is finished.
732+
"""
733+
734+
def __init__(
735+
self,
736+
task_function: Callable,
737+
task_args: Tuple = (),
738+
requires_connection_pool=None,
739+
**kwargs,
740+
):
741+
self.task_function = task_function
742+
self.task_args = task_args
743+
self.task_kwargs = kwargs
744+
self._event: threading.Event = threading.Event()
745+
self._result: Any = None
746+
self._requires_connection_pool = requires_connection_pool
747+
748+
@property
749+
def result(self) -> Any:
750+
"""
751+
The result of the async task.
752+
"""
753+
return self._result
754+
755+
@result.setter
756+
def result(self, value) -> None:
757+
self._result = value
758+
759+
@property
760+
def event(self) -> threading.Event:
761+
"""
762+
The event that will be set when the async task is finished.
763+
"""
764+
return self._event
765+
766+
@event.setter
767+
def event(self, value) -> None:
768+
self._event = value
769+
770+
@property
771+
def requires_connection_pool(self) -> bool:
772+
"""
773+
Whether the task requires a connection pool.
774+
"""
775+
return self._requires_connection_pool
776+
777+
778+
class AsyncTaskThread(threading.Thread):
779+
"""
780+
Generic thread class that can be used to run async tasks in a separate thread.
781+
The thread will create its own event loop and run submitted tasks in that loop.
782+
783+
The thread also store and fetches a connection pool that can be used by the async tasks.
784+
785+
# Args:
786+
connection_pool_initializer (Callable): A function that initializes a connection pool.
787+
connection_pool_params (Tuple): The parameters to pass to the connection pool initializer.
788+
*thread_args: Arguments to be passed to the thread.
789+
**thread_kwargs: Key word arguments to be passed to the thread.
790+
791+
# Properties:
792+
event_loop (asyncio.AbstractEventLoop): The event loop used by the thread.
793+
task_queue (queue.Queue[AsyncTask]): The queue used to submit tasks to the thread.
794+
connection_pool: The connection pool used
795+
"""
796+
797+
def __init__(
798+
self,
799+
connection_pool_initializer: Callable = None,
800+
connection_pool_params: Tuple = (),
801+
*thread_args,
802+
**thread_kwargs,
803+
):
804+
super().__init__(*thread_args, **thread_kwargs)
805+
self._task_queue: queue.Queue[AsyncTask] = queue.Queue()
806+
self._event_loop: asyncio.AbstractEventLoop = asyncio.new_event_loop()
807+
self.stop_event = threading.Event()
808+
self._connection_pool_initializer: Callable = connection_pool_initializer
809+
self._connection_pool_params: Tuple = connection_pool_params
810+
self._connection_pool = None
811+
self.daemon = True # Setting the thread as a daemon thread by default, so it will be terminated when the main thread is terminated.
812+
813+
async def execute_task(self):
814+
"""
815+
Execute the async tasks for the queue.
816+
"""
817+
asyncio.set_event_loop(self._event_loop)
818+
819+
if self._connection_pool_initializer:
820+
self._connection_pool = await self._connection_pool_initializer(
821+
*self._connection_pool_params
822+
)
823+
824+
while not self.stop_event.is_set():
825+
# Fetch a task from the queue.
826+
task = self.task_queue.get()
827+
# Run the task in the event loop and get the result
828+
try:
829+
# Check if the task requires a connection pool and pass it to the function if it does.
830+
if task.requires_connection_pool:
831+
task.result = await task.task_function(
832+
*task.task_args,
833+
**task.task_kwargs,
834+
connection_pool=self.connection_pool,
835+
)
836+
else:
837+
task.result = await task.task_function(
838+
*task.task_args, **task.task_kwargs
839+
)
840+
841+
# Unblock the task, so the submit function can return the result.
842+
task.event.set()
843+
except Exception as e:
844+
task.result = e
845+
task.event.set()
846+
raise e
847+
848+
def stop(self):
849+
"""
850+
Stop the thread and close the event loop.
851+
"""
852+
self.stop_event.set()
853+
self._event_loop.stop()
854+
self._event_loop.close()
855+
856+
def run(self):
857+
"""
858+
Execute the async tasks for the queue.
859+
"""
860+
asyncio.set_event_loop(self._event_loop)
861+
self._event_loop.create_task(self.execute_task())
862+
try:
863+
self._event_loop.run_forever()
864+
except Exception as e:
865+
self._event_loop.stop()
866+
self._event_loop.close()
867+
raise e
868+
finally:
869+
self._event_loop.close()
870+
871+
def submit(self, task: AsyncTask):
872+
"""
873+
Submit a async task to the thread and block until the execution of the function is completed.
874+
"""
875+
# Submit a task to the queue.
876+
self.task_queue.put(task)
877+
# Block the execution until the task is finished.
878+
task.event.wait()
879+
880+
if isinstance(task.result, Exception):
881+
raise task.result
882+
else:
883+
# Return the result of the task.
884+
return task.result
885+
886+
@property
887+
def event_loop(self) -> asyncio.AbstractEventLoop:
888+
"""
889+
The event loop used by the thread.
890+
"""
891+
return self._event_loop
892+
893+
@property
894+
def task_queue(self) -> queue.Queue[AsyncTask]:
895+
"""
896+
The queue used to submit tasks to the thread.
897+
"""
898+
return self._task_queue
899+
900+
@property
901+
def connection_pool(self):
902+
"""
903+
The connection pool used by the thread.
904+
"""
905+
return self._connection_pool

python/hsfs/core/online_store_sql_engine.py

+46-40
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union
2323

2424
from hopsworks_common.core import variable_api
25+
from hopsworks_common.util import AsyncTask, AsyncTaskThread
2526
from hsfs import util
2627
from hsfs.core import (
2728
feature_view_api,
@@ -73,7 +74,6 @@ def __init__(
7374
self._prefix_by_serving_index = None
7475
self._pkname_by_serving_index = None
7576
self._serving_key_by_serving_index: Dict[str, ServingKey] = {}
76-
self._connection_pool = None
7777
self._serving_keys: Set[ServingKey] = set(serving_keys or [])
7878

7979
self._prepared_statements: Dict[str, List[ServingPreparedStatement]] = {}
@@ -89,6 +89,14 @@ def __init__(
8989
self._hostname = None
9090
self._connection_options = None
9191

92+
self._async_task_thread = None
93+
94+
def __del__(self):
95+
# Safely stop the async task thread.
96+
# The connection pool will be closed during garbage collection by aiomysql.
97+
if self._async_task_thread.is_alive():
98+
self._async_task_thread.stop()
99+
92100
def fetch_prepared_statements(
93101
self,
94102
entity: Union[feature_view.FeatureView, training_dataset.TrainingDataset],
@@ -245,13 +253,15 @@ def init_async_mysql_connection(self, options=None):
245253
else None
246254
)
247255

248-
if util.is_runtime_notebook():
249-
_logger.debug("Running in Jupyter notebook, applying nest_asyncio")
250-
import nest_asyncio
251-
252-
nest_asyncio.apply()
253-
else:
254-
_logger.debug("Running in python script. Not applying nest_asyncio")
256+
if not self._async_task_thread:
257+
# Create the async event thread if it is not already running and start it.
258+
self._async_task_thread = AsyncTaskThread(
259+
connection_pool_initializer=self._get_connection_pool,
260+
connection_pool_params=(
261+
len(self._prepared_statements[self.SINGLE_VECTOR_KEY]),
262+
),
263+
)
264+
self._async_task_thread.start()
255265

256266
def get_single_feature_vector(self, entry: Dict[str, Any]) -> Dict[str, Any]:
257267
"""Retrieve single vector with parallel queries using aiomysql engine."""
@@ -325,9 +335,15 @@ def _single_vector_result(
325335
_logger.debug(
326336
f"Executing prepared statements for serving vector with entries: {bind_entries}"
327337
)
328-
loop = self._get_or_create_event_loop()
329-
results_dict = loop.run_until_complete(
330-
self._execute_prep_statements(prepared_statement_execution, bind_entries)
338+
results_dict = self._async_task_thread.submit(
339+
AsyncTask(
340+
task_function=self._execute_prep_statements,
341+
task_args=(
342+
prepared_statement_execution,
343+
bind_entries,
344+
),
345+
requires_connection_pool=True,
346+
)
331347
)
332348
_logger.debug(f"Retrieved feature vectors: {results_dict}")
333349
_logger.debug("Constructing serving vector from results")
@@ -393,9 +409,12 @@ def _batch_vector_results(
393409
f"Executing prepared statements for batch vector with entries: {entry_values}"
394410
)
395411
# run all the prepared statements in parallel using aiomysql engine
396-
loop = self._get_or_create_event_loop()
397-
parallel_results = loop.run_until_complete(
398-
self._execute_prep_statements(prepared_stmts_to_execute, entry_values)
412+
parallel_results = self._async_task_thread.submit(
413+
AsyncTask(
414+
task_function=self._execute_prep_statements,
415+
task_args=(prepared_stmts_to_execute, entry_values),
416+
requires_connection_pool=True,
417+
)
399418
)
400419

401420
_logger.debug(f"Retrieved feature vectors: {parallel_results}, stitching them.")
@@ -441,20 +460,6 @@ def _batch_vector_results(
441460
)
442461
return batch_results, serving_keys_all_fg
443462

444-
def _get_or_create_event_loop(self):
445-
try:
446-
_logger.debug("Acquiring or starting event loop for async engine.")
447-
loop = asyncio.get_event_loop()
448-
asyncio.set_event_loop(loop)
449-
except RuntimeError as ex:
450-
if "There is no current event loop in thread" in str(ex):
451-
_logger.debug(
452-
"No existing running event loop. Creating new event loop."
453-
)
454-
loop = asyncio.new_event_loop()
455-
asyncio.set_event_loop(loop)
456-
return loop
457-
458463
def refresh_mysql_connection(self):
459464
_logger.debug("Refreshing MySQL connection.")
460465
try:
@@ -547,22 +552,24 @@ def get_prepared_statement_labels(
547552
]
548553

549554
async def _get_connection_pool(self, default_min_size: int) -> None:
550-
self._connection_pool = await util_sql.create_async_engine(
555+
connection_pool = await util_sql.create_async_engine(
551556
self._online_connector,
552557
self._external,
553558
default_min_size,
554559
options=self._connection_options,
555560
hostname=self._hostname,
556561
)
562+
return connection_pool
557563

558-
async def _query_async_sql(self, stmt, bind_params):
564+
async def _query_async_sql(
565+
self,
566+
stmt,
567+
bind_params,
568+
connection_pool: aiomysql.utils._ConnectionContextManager,
569+
):
559570
"""Query prepared statement together with bind params using aiomysql connection pool"""
560571
# create connection pool
561-
await self._get_connection_pool(
562-
len(self._prepared_statements[self.SINGLE_VECTOR_KEY])
563-
)
564-
565-
async with self._connection_pool.acquire() as conn:
572+
async with connection_pool.acquire() as conn:
566573
# Execute the prepared statement
567574
_logger.debug(
568575
f"Executing prepared statement: {stmt} with bind params: {bind_params}"
@@ -580,6 +587,7 @@ async def _execute_prep_statements(
580587
self,
581588
prepared_statements: Dict[int, str],
582589
entries: Union[List[Dict[str, Any]], Dict[str, Any]],
590+
connection_pool: aiomysql.utils._ConnectionContextManager, # The connection pool required is passed as a parameter from the AsyncTaskThread.
583591
):
584592
"""Iterate over prepared statements to create async tasks
585593
and gather all tasks results for a given list of entries."""
@@ -595,7 +603,9 @@ async def _execute_prep_statements(
595603
try:
596604
tasks = [
597605
asyncio.create_task(
598-
self._query_async_sql(prepared_statements[key], entries[key]),
606+
self._query_async_sql(
607+
prepared_statements[key], entries[key], connection_pool
608+
),
599609
name="query_prep_statement_key" + str(key),
600610
)
601611
for key in prepared_statements
@@ -749,7 +759,3 @@ def connection_options(self) -> Dict[str, Any]:
749759
@property
750760
def online_connector(self) -> storage_connector.StorageConnector:
751761
return self._online_connector
752-
753-
@property
754-
def connection_pool(self) -> aiomysql.utils._ConnectionContextManager:
755-
return self._connection_pool

0 commit comments

Comments
 (0)