Skip to content

Commit 12993c1

Browse files
committed
add stop_extension logic, use it for stopping dask
1 parent 572faca commit 12993c1

File tree

2 files changed

+44
-12
lines changed

2 files changed

+44
-12
lines changed

jupyter_scheduler/extension.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,3 +92,25 @@ def initialize_settings(self):
9292
if scheduler.task_runner:
9393
loop = asyncio.get_event_loop()
9494
loop.create_task(scheduler.task_runner.start())
95+
96+
async def stop_extension(self):
97+
"""
98+
Public method called by Jupyter Server when the server is stopping.
99+
This calls the cleanup code defined in `self._stop_exception()` inside
100+
an exception handler, as the server halts if this method raises an
101+
exception.
102+
"""
103+
try:
104+
await self._stop_extension()
105+
except Exception as e:
106+
self.log.error("Jupyter Scheduler raised an exception while stopping:")
107+
self.log.exception(e)
108+
109+
async def _stop_extension(self):
110+
"""
111+
Private method that defines the cleanup code to run when the server is
112+
stopping.
113+
"""
114+
if "scheduler" in self.settings:
115+
scheduler: SchedulerApp = self.settings["scheduler"]
116+
await scheduler.stop_extension()

jupyter_scheduler/scheduler.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -98,23 +98,12 @@ def _default_staging_path(self):
9898
)
9999

100100
def __init__(
101-
self,
102-
root_dir: str,
103-
environments_manager: Type[EnvironmentManager],
104-
config=None,
105-
**kwargs,
101+
self, root_dir: str, environments_manager: Type[EnvironmentManager], config=None, **kwargs
106102
):
107103
super().__init__(config=config, **kwargs)
108104
self.root_dir = root_dir
109105
self.environments_manager = environments_manager
110106

111-
loop = asyncio.get_event_loop()
112-
self.dask_client_future: Awaitable[DaskClient] = loop.create_task(self._get_dask_client())
113-
114-
async def _get_dask_client(self):
115-
"""Creates and configures a Dask client."""
116-
return DaskClient(processes=False, asynchronous=True)
117-
118107
def create_job(self, model: CreateJob) -> str:
119108
"""Creates a new job record, may trigger execution of the job.
120109
In case a task runner is actually handling execution of the jobs,
@@ -394,6 +383,12 @@ def get_local_output_path(
394383
else:
395384
return os.path.join(self.root_dir, self.output_directory, output_dir_name)
396385

386+
async def stop_extension(self):
387+
"""
388+
Placeholder method for a cleanup code to run when the server is stopping.
389+
"""
390+
pass
391+
397392

398393
class Scheduler(BaseScheduler):
399394
_db_session = None
@@ -427,6 +422,13 @@ def __init__(
427422
if self.task_runner_class:
428423
self.task_runner = self.task_runner_class(scheduler=self, config=config)
429424

425+
loop = asyncio.get_event_loop()
426+
self.dask_client_future: Awaitable[DaskClient] = loop.create_task(self._get_dask_client())
427+
428+
async def _get_dask_client(self):
429+
"""Creates and configures a Dask client."""
430+
return DaskClient(processes=False, asynchronous=True)
431+
430432
@property
431433
def db_session(self):
432434
if not self._db_session:
@@ -775,6 +777,14 @@ def get_staging_paths(self, model: Union[DescribeJob, DescribeJobDefinition]) ->
775777

776778
return staging_paths
777779

780+
async def stop_extension(self):
781+
"""
782+
Cleanup code to run when the server is stopping.
783+
"""
784+
if self.dask_client_future:
785+
dask_client: DaskClient = await self.dask_client_future
786+
await dask_client.close()
787+
778788

779789
class ArchivingScheduler(Scheduler):
780790
"""Scheduler that captures all files in output directory in an archive."""

0 commit comments

Comments
 (0)