22
22
from typing import TYPE_CHECKING , Any , Dict , List , Optional , Set , Tuple , Union
23
23
24
24
from hopsworks_common .core import variable_api
25
+ from hopsworks_common .util import AsyncTask , AsyncTaskThread
25
26
from hsfs import util
26
27
from hsfs .core import (
27
28
feature_view_api ,
@@ -73,7 +74,6 @@ def __init__(
73
74
self ._prefix_by_serving_index = None
74
75
self ._pkname_by_serving_index = None
75
76
self ._serving_key_by_serving_index : Dict [str , ServingKey ] = {}
76
- self ._connection_pool = None
77
77
self ._serving_keys : Set [ServingKey ] = set (serving_keys or [])
78
78
79
79
self ._prepared_statements : Dict [str , List [ServingPreparedStatement ]] = {}
@@ -89,6 +89,14 @@ def __init__(
89
89
self ._hostname = None
90
90
self ._connection_options = None
91
91
92
+ self ._async_task_thread = None
93
+
94
+ def __del__ (self ):
95
+ # Safely stop the async task thread.
96
+ # The connection pool will be closed during garbage collection by aiomysql.
97
+ if self ._async_task_thread .is_alive ():
98
+ self ._async_task_thread .stop ()
99
+
92
100
def fetch_prepared_statements (
93
101
self ,
94
102
entity : Union [feature_view .FeatureView , training_dataset .TrainingDataset ],
@@ -245,13 +253,15 @@ def init_async_mysql_connection(self, options=None):
245
253
else None
246
254
)
247
255
248
- if util .is_runtime_notebook ():
249
- _logger .debug ("Running in Jupyter notebook, applying nest_asyncio" )
250
- import nest_asyncio
251
-
252
- nest_asyncio .apply ()
253
- else :
254
- _logger .debug ("Running in python script. Not applying nest_asyncio" )
256
+ if not self ._async_task_thread :
257
+ # Create the async event thread if it is not already running and start it.
258
+ self ._async_task_thread = AsyncTaskThread (
259
+ connection_pool_initializer = self ._get_connection_pool ,
260
+ connection_pool_params = (
261
+ len (self ._prepared_statements [self .SINGLE_VECTOR_KEY ]),
262
+ ),
263
+ )
264
+ self ._async_task_thread .start ()
255
265
256
266
def get_single_feature_vector (self , entry : Dict [str , Any ]) -> Dict [str , Any ]:
257
267
"""Retrieve single vector with parallel queries using aiomysql engine."""
@@ -325,9 +335,15 @@ def _single_vector_result(
325
335
_logger .debug (
326
336
f"Executing prepared statements for serving vector with entries: { bind_entries } "
327
337
)
328
- loop = self ._get_or_create_event_loop ()
329
- results_dict = loop .run_until_complete (
330
- self ._execute_prep_statements (prepared_statement_execution , bind_entries )
338
+ results_dict = self ._async_task_thread .submit (
339
+ AsyncTask (
340
+ task_function = self ._execute_prep_statements ,
341
+ task_args = (
342
+ prepared_statement_execution ,
343
+ bind_entries ,
344
+ ),
345
+ requires_connection_pool = True ,
346
+ )
331
347
)
332
348
_logger .debug (f"Retrieved feature vectors: { results_dict } " )
333
349
_logger .debug ("Constructing serving vector from results" )
@@ -393,9 +409,12 @@ def _batch_vector_results(
393
409
f"Executing prepared statements for batch vector with entries: { entry_values } "
394
410
)
395
411
# run all the prepared statements in parallel using aiomysql engine
396
- loop = self ._get_or_create_event_loop ()
397
- parallel_results = loop .run_until_complete (
398
- self ._execute_prep_statements (prepared_stmts_to_execute , entry_values )
412
+ parallel_results = self ._async_task_thread .submit (
413
+ AsyncTask (
414
+ task_function = self ._execute_prep_statements ,
415
+ task_args = (prepared_stmts_to_execute , entry_values ),
416
+ requires_connection_pool = True ,
417
+ )
399
418
)
400
419
401
420
_logger .debug (f"Retrieved feature vectors: { parallel_results } , stitching them." )
@@ -441,20 +460,6 @@ def _batch_vector_results(
441
460
)
442
461
return batch_results , serving_keys_all_fg
443
462
444
- def _get_or_create_event_loop (self ):
445
- try :
446
- _logger .debug ("Acquiring or starting event loop for async engine." )
447
- loop = asyncio .get_event_loop ()
448
- asyncio .set_event_loop (loop )
449
- except RuntimeError as ex :
450
- if "There is no current event loop in thread" in str (ex ):
451
- _logger .debug (
452
- "No existing running event loop. Creating new event loop."
453
- )
454
- loop = asyncio .new_event_loop ()
455
- asyncio .set_event_loop (loop )
456
- return loop
457
-
458
463
def refresh_mysql_connection (self ):
459
464
_logger .debug ("Refreshing MySQL connection." )
460
465
try :
@@ -547,22 +552,24 @@ def get_prepared_statement_labels(
547
552
]
548
553
549
554
async def _get_connection_pool (self , default_min_size : int ) -> None :
550
- self . _connection_pool = await util_sql .create_async_engine (
555
+ connection_pool = await util_sql .create_async_engine (
551
556
self ._online_connector ,
552
557
self ._external ,
553
558
default_min_size ,
554
559
options = self ._connection_options ,
555
560
hostname = self ._hostname ,
556
561
)
562
+ return connection_pool
557
563
558
- async def _query_async_sql (self , stmt , bind_params ):
564
+ async def _query_async_sql (
565
+ self ,
566
+ stmt ,
567
+ bind_params ,
568
+ connection_pool : aiomysql .utils ._ConnectionContextManager ,
569
+ ):
559
570
"""Query prepared statement together with bind params using aiomysql connection pool"""
560
571
# create connection pool
561
- await self ._get_connection_pool (
562
- len (self ._prepared_statements [self .SINGLE_VECTOR_KEY ])
563
- )
564
-
565
- async with self ._connection_pool .acquire () as conn :
572
+ async with connection_pool .acquire () as conn :
566
573
# Execute the prepared statement
567
574
_logger .debug (
568
575
f"Executing prepared statement: { stmt } with bind params: { bind_params } "
@@ -580,6 +587,7 @@ async def _execute_prep_statements(
580
587
self ,
581
588
prepared_statements : Dict [int , str ],
582
589
entries : Union [List [Dict [str , Any ]], Dict [str , Any ]],
590
+ connection_pool : aiomysql .utils ._ConnectionContextManager , # The connection pool required is passed as a parameter from the AsyncTaskThread.
583
591
):
584
592
"""Iterate over prepared statements to create async tasks
585
593
and gather all tasks results for a given list of entries."""
@@ -595,7 +603,9 @@ async def _execute_prep_statements(
595
603
try :
596
604
tasks = [
597
605
asyncio .create_task (
598
- self ._query_async_sql (prepared_statements [key ], entries [key ]),
606
+ self ._query_async_sql (
607
+ prepared_statements [key ], entries [key ], connection_pool
608
+ ),
599
609
name = "query_prep_statement_key" + str (key ),
600
610
)
601
611
for key in prepared_statements
@@ -749,7 +759,3 @@ def connection_options(self) -> Dict[str, Any]:
749
759
@property
750
760
def online_connector (self ) -> storage_connector .StorageConnector :
751
761
return self ._online_connector
752
-
753
- @property
754
- def connection_pool (self ) -> aiomysql .utils ._ConnectionContextManager :
755
- return self ._connection_pool
0 commit comments