-
Notifications
You must be signed in to change notification settings - Fork 29
Use Dask instead of multiprocessing module #530
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
4d85820
320be26
fb67a58
61c832d
a80f688
b02f850
f4d8f8a
74aaf9a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,13 @@ | ||
import multiprocessing as mp | ||
import asyncio | ||
import os | ||
import random | ||
import shutil | ||
from typing import Dict, List, Optional, Type, Union | ||
|
||
import fsspec | ||
import psutil | ||
from dask.distributed import Client as DaskClient | ||
from distributed import LocalCluster | ||
from jupyter_core.paths import jupyter_data_dir | ||
from jupyter_server.transutils import _i18n | ||
from jupyter_server.utils import to_os_path | ||
|
@@ -381,6 +383,12 @@ def get_local_output_path( | |
else: | ||
return os.path.join(self.root_dir, self.output_directory, output_dir_name) | ||
|
||
async def stop_extension(self): | ||
""" | ||
Placeholder method for a cleanup code to run when the server is stopping. | ||
""" | ||
pass | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we need to keep this method if it does nothing? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes. I implemented |
||
|
||
class Scheduler(BaseScheduler): | ||
_db_session = None | ||
|
@@ -395,6 +403,12 @@ class Scheduler(BaseScheduler): | |
), | ||
) | ||
|
||
dask_cluster_url = Unicode( | ||
allow_none=True, | ||
config=True, | ||
help="URL of the Dask cluster to connect to.", | ||
) | ||
|
||
db_url = Unicode(help=_i18n("Scheduler database url")) | ||
|
||
task_runner = Instance(allow_none=True, klass="jupyter_scheduler.task_runner.BaseTaskRunner") | ||
|
@@ -414,6 +428,15 @@ def __init__( | |
if self.task_runner_class: | ||
self.task_runner = self.task_runner_class(scheduler=self, config=config) | ||
|
||
self.dask_client: DaskClient = self._get_dask_client() | ||
|
||
def _get_dask_client(self): | ||
"""Creates and configures a Dask client.""" | ||
if self.dask_cluster_url: | ||
return DaskClient(self.dask_cluster_url) | ||
cluster = LocalCluster(processes=True) | ||
return DaskClient(cluster) | ||
|
||
@property | ||
def db_session(self): | ||
if not self._db_session: | ||
|
@@ -478,25 +501,16 @@ def create_job(self, model: CreateJob) -> str: | |
else: | ||
self.copy_input_file(model.input_uri, staging_paths["input"]) | ||
|
||
# The MP context forces new processes to not be forked on Linux. | ||
# This is necessary because `asyncio.get_event_loop()` is bugged in | ||
# forked processes in Python versions below 3.12. This method is | ||
# called by `jupyter_core` by `nbconvert` in the default executor. | ||
# | ||
# See: https://github.com/python/cpython/issues/66285 | ||
# See also: https://github.com/jupyter/jupyter_core/pull/362 | ||
mp_ctx = mp.get_context("spawn") | ||
p = mp_ctx.Process( | ||
target=self.execution_manager_class( | ||
future = self.dask_client.submit( | ||
self.execution_manager_class( | ||
job_id=job.job_id, | ||
staging_paths=staging_paths, | ||
root_dir=self.root_dir, | ||
db_url=self.db_url, | ||
).process | ||
) | ||
p.start() | ||
|
||
job.pid = p.pid | ||
job.pid = future.key | ||
session.commit() | ||
|
||
job_id = job.job_id | ||
|
@@ -777,6 +791,13 @@ def get_staging_paths(self, model: Union[DescribeJob, DescribeJobDefinition]) -> | |
|
||
return staging_paths | ||
|
||
async def stop_extension(self): | ||
""" | ||
Cleanup code to run when the server is stopping. | ||
""" | ||
if self.dask_client: | ||
await self.dask_client.close() | ||
|
||
|
||
class ArchivingScheduler(Scheduler): | ||
"""Scheduler that captures all files in output directory in an archive.""" | ||
|
Uh oh!
There was an error while loading. Please reload this page.