@@ -98,23 +98,12 @@ def _default_staging_path(self):
98
98
)
99
99
100
100
def __init__ (
101
- self ,
102
- root_dir : str ,
103
- environments_manager : Type [EnvironmentManager ],
104
- config = None ,
105
- ** kwargs ,
101
+ self , root_dir : str , environments_manager : Type [EnvironmentManager ], config = None , ** kwargs
106
102
):
107
103
super ().__init__ (config = config , ** kwargs )
108
104
self .root_dir = root_dir
109
105
self .environments_manager = environments_manager
110
106
111
- loop = asyncio .get_event_loop ()
112
- self .dask_client_future : Awaitable [DaskClient ] = loop .create_task (self ._get_dask_client ())
113
-
114
- async def _get_dask_client (self ):
115
- """Creates and configures a Dask client."""
116
- return DaskClient (processes = False , asynchronous = True )
117
-
118
107
def create_job (self , model : CreateJob ) -> str :
119
108
"""Creates a new job record, may trigger execution of the job.
120
109
In case a task runner is actually handling execution of the jobs,
@@ -394,6 +383,12 @@ def get_local_output_path(
394
383
else :
395
384
return os .path .join (self .root_dir , self .output_directory , output_dir_name )
396
385
386
+ async def stop_extension (self ):
387
+ """
388
+ Placeholder method for a cleanup code to run when the server is stopping.
389
+ """
390
+ pass
391
+
397
392
398
393
class Scheduler (BaseScheduler ):
399
394
_db_session = None
@@ -427,6 +422,13 @@ def __init__(
427
422
if self .task_runner_class :
428
423
self .task_runner = self .task_runner_class (scheduler = self , config = config )
429
424
425
+ loop = asyncio .get_event_loop ()
426
+ self .dask_client_future : Awaitable [DaskClient ] = loop .create_task (self ._get_dask_client ())
427
+
428
+ async def _get_dask_client (self ):
429
+ """Creates and configures a Dask client."""
430
+ return DaskClient (processes = False , asynchronous = True )
431
+
430
432
@property
431
433
def db_session (self ):
432
434
if not self ._db_session :
@@ -775,6 +777,14 @@ def get_staging_paths(self, model: Union[DescribeJob, DescribeJobDefinition]) ->
775
777
776
778
return staging_paths
777
779
780
+ async def stop_extension (self ):
781
+ """
782
+ Cleanup code to run when the server is stopping.
783
+ """
784
+ if self .dask_client_future :
785
+ dask_client : DaskClient = await self .dask_client_future
786
+ await dask_client .close ()
787
+
778
788
779
789
class ArchivingScheduler (Scheduler ):
780
790
"""Scheduler that captures all files in output directory in an archive."""
0 commit comments