Skip to content

Commit d71eb98

Browse files
committed
enable new dask 0.18 scheduler selection idioms
- someone should check with dask; it seems a bit brittle - fix tests maybe - update documentation. Fixes #17
1 parent 8686767 commit d71eb98

File tree

7 files changed

+65
-45
lines changed

7 files changed

+65
-45
lines changed

conftest.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
# Released under the GNU Public Licence, v2 or any higher version
1010

1111
from dask import distributed
12+
import dask
1213
import pytest
1314

1415

@@ -24,9 +25,11 @@ def client(tmpdir_factory, request):
2425
lc.close()
2526

2627

27-
@pytest.fixture(scope='session', params=('distributed', 'multiprocessing'))
28+
@pytest.fixture(scope='session', params=('distributed', 'multiprocessing', 'single-threaded'))
2829
def scheduler(request, client):
2930
if request.param == 'distributed':
30-
return client
31+
arg = client
3132
else:
32-
return request.param
33+
arg = request.param
34+
with dask.config.set(scheduler=arg):
35+
yield

docs/userguide/parallelization.rst

+3-3
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ Internally, this uses the multiprocessing `scheduler`_ of dask. If you
1919
want to make use of more advanced scheduler features or scale your
2020
analysis to multiple nodes, e.g., in an HPC (high performance
2121
computing) environment, then use the :mod:`distributed` scheduler, as
22-
described next.
22+
described next. If ``n_jobs == 1``, a single-threaded scheduler is used.
2323

2424
.. _`scheduler`:
2525
https://dask.pydata.org/en/latest/scheduler-overview.html
@@ -58,7 +58,7 @@ use the :ref:`RMSD example<example-parallel-rmsd>`):
5858

5959
.. code:: python
6060
61-
rmsd_ana = rms.RMSD(u.atoms, ref.atoms).run(scheduler=client)
61+
rmsd_ana = rms.RMSD(u.atoms, ref.atoms).run()
6262
6363
Because the local cluster contains 8 workers, the RMSD trajectory
6464
analysis will be parallelized over 8 trajectory segments.
@@ -78,7 +78,7 @@ analysis :meth:`~pmda.parallel.ParallelAnalysisBase.run` method:
7878
7979
import distributed
8080
client = distributed.Client('192.168.0.1:8786')
81-
rmsd_ana = rms.RMSD(u.atoms, ref.atoms).run(scheduler=client)
81+
rmsd_ana = rms.RMSD(u.atoms, ref.atoms).run()
8282
8383
In this way one can spread an analysis task over many different nodes.
8484

docs/userguide/pmda_classes.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ are provided as keyword arguments:
1717

1818
set up the parallel analysis
1919

20-
.. method:: run(n_jobs=-1, scheduler=None)
20+
.. method:: run(n_jobs=-1)
2121

2222
perform parallel analysis; see :ref:`parallelization`
2323
for explanation of the arguments

pmda/leaflet.py

+30-13
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,6 @@ def run(self,
231231
start=None,
232232
stop=None,
233233
step=None,
234-
scheduler=None,
235234
n_jobs=-1,
236235
cutoff=15.0):
237236
"""Perform the calculation
@@ -244,35 +243,53 @@ def run(self,
244243
stop frame of analysis
245244
step : int, optional
246245
number of frames to skip between each analysed frame
247-
scheduler : dask scheduler, optional
248-
Use dask scheduler, defaults to multiprocessing. This can be used
249-
to spread work to a distributed scheduler
250246
n_jobs : int, optional
251247
number of tasks to start, if `-1` use number of logical cpu cores.
252248
This argument will be ignored when the distributed scheduler is
253249
used
254250
255251
"""
256-
if scheduler is None:
252+
# are we using a distributed scheduler or should we use multiprocessing?
253+
scheduler = dask.config.get('scheduler', None)
254+
if scheduler is None and client is None:
257255
scheduler = 'multiprocessing'
256+
elif scheduler is None:
257+
# maybe we can grab a global worker
258+
try:
259+
from dask import distributed
260+
scheduler = distributed.worker.get_client()
261+
except ValueError:
262+
pass
263+
except ImportError:
264+
pass
258265

259266
if n_jobs == -1:
267+
n_jobs = cpu_count()
268+
269+
# we could not find a global scheduler to use and we ask for a single
270+
# job. Therefore we run this on the single threaded scheduler for
271+
# debugging.
272+
if scheduler is None and n_jobs == 1:
273+
scheduler = 'single-threaded'
274+
275+
if n_blocks is None:
260276
if scheduler == 'multiprocessing':
261-
n_jobs = cpu_count()
277+
n_blocks = n_jobs
262278
elif isinstance(scheduler, distributed.Client):
263-
n_jobs = len(scheduler.ncores())
279+
n_blocks = len(scheduler.ncores())
264280
else:
265-
raise ValueError(
266-
"Couldn't guess ideal number of jobs from scheduler."
267-
"Please provide `n_jobs` in call to method.")
268-
269-
with timeit() as b_universe:
270-
universe = mda.Universe(self._top, self._traj)
281+
n_blocks = 1
282+
warnings.warn(
283+
"Couldn't guess ideal number of blocks from scheduler. Set n_blocks=1"
284+
"Please provide `n_blocks` in call to method.")
271285

272286
scheduler_kwargs = {'scheduler': scheduler}
273287
if scheduler == 'multiprocessing':
274288
scheduler_kwargs['num_workers'] = n_jobs
275289

290+
with timeit() as b_universe:
291+
universe = mda.Universe(self._top, self._traj)
292+
276293
start, stop, step = self._trajectory.check_slice_indices(
277294
start, stop, step)
278295
with timeit() as total:

pmda/parallel.py

+21-8
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
from six.moves import range
2222

2323
import MDAnalysis as mda
24-
from dask import distributed
2524
from dask.delayed import delayed
2625
from joblib import cpu_count
2726
import numpy as np
@@ -267,7 +266,6 @@ def run(self,
267266
start=None,
268267
stop=None,
269268
step=None,
270-
scheduler=None,
271269
n_jobs=1,
272270
n_blocks=None):
273271
"""Perform the calculation
@@ -280,9 +278,6 @@ def run(self,
280278
stop frame of analysis
281279
step : int, optional
282280
number of frames to skip between each analysed frame
283-
scheduler : dask scheduler, optional
284-
Use dask scheduler, defaults to multiprocessing. This can be used
285-
to spread work to a distributed scheduler
286281
n_jobs : int, optional
287282
number of jobs to start, if `-1` use number of logical cpu cores.
288283
This argument will be ignored when the distributed scheduler is
@@ -292,20 +287,38 @@ def run(self,
292287
to n_jobs or number of available workers in scheduler.
293288
294289
"""
295-
if scheduler is None:
290+
# are we using a distributed scheduler or should we use multiprocessing?
291+
scheduler = dask.config.get('scheduler', None)
292+
if scheduler is None and client is None:
296293
scheduler = 'multiprocessing'
294+
elif scheduler is None:
295+
# maybe we can grab a global worker
296+
try:
297+
from dask import distributed
298+
scheduler = distributed.worker.get_client()
299+
except ValueError:
300+
pass
301+
except ImportError:
302+
pass
297303

298304
if n_jobs == -1:
299305
n_jobs = cpu_count()
300306

307+
# we could not find a global scheduler to use and we ask for a single
308+
# job. Therefore we run this on the single threaded scheduler for
309+
# debugging.
310+
if scheduler is None and n_jobs == 1:
311+
scheduler = 'single-threaded'
312+
301313
if n_blocks is None:
302314
if scheduler == 'multiprocessing':
303315
n_blocks = n_jobs
304316
elif isinstance(scheduler, distributed.Client):
305317
n_blocks = len(scheduler.ncores())
306318
else:
307-
raise ValueError(
308-
"Couldn't guess ideal number of blocks from scheduler."
319+
n_blocks = 1
320+
warnings.warn(
321+
"Couldn't guess ideal number of blocks from scheduler. Set n_blocks=1"
309322
"Please provide `n_blocks` in call to method.")
310323

311324
scheduler_kwargs = {'scheduler': scheduler}

pmda/test/test_custom.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,13 @@ def test_AnalysisFromFunction(scheduler):
2727
u = mda.Universe(PSF, DCD)
2828
step = 2
2929
ana1 = custom.AnalysisFromFunction(custom_function, u, u.atoms).run(
30-
step=step, scheduler=scheduler
30+
step=step
3131
)
3232
ana2 = custom.AnalysisFromFunction(custom_function, u, u.atoms).run(
33-
step=step, scheduler=scheduler
33+
step=step
3434
)
3535
ana3 = custom.AnalysisFromFunction(custom_function, u, u.atoms).run(
36-
step=step, scheduler=scheduler
36+
step=step
3737
)
3838

3939
results = []

pmda/test/test_parallel.py

+1-14
Original file line numberDiff line numberDiff line change
@@ -60,11 +60,6 @@ def analysis():
6060
return ana
6161

6262

63-
def test_wrong_scheduler(analysis):
64-
with pytest.raises(ValueError):
65-
analysis.run(scheduler=2)
66-
67-
6863
@pytest.mark.parametrize('n_jobs', (1, 2))
6964
def test_all_frames(analysis, n_jobs):
7065
analysis.run(n_jobs=n_jobs)
@@ -91,16 +86,8 @@ def test_no_frames(analysis, n_jobs):
9186
assert analysis.timing.universe == 0
9287

9388

94-
@pytest.fixture(scope='session', params=('distributed', 'multiprocessing'))
95-
def scheduler(request, client):
96-
if request.param == 'distributed':
97-
return client
98-
else:
99-
return request.param
100-
101-
10289
def test_scheduler(analysis, scheduler):
103-
analysis.run(scheduler=scheduler)
90+
analysis.run()
10491

10592

10693
def test_nframes_less_nblocks_warning(analysis):

0 commit comments

Comments
 (0)