Skip to content

Commit 6aa6c61

Browse files
authored
Merge pull request #75 from MDAnalysis/issue-71-block-assign
balance block sizes (#71)
2 parents a718250 + 4a04788 commit 6aa6c61

9 files changed

+319
-16
lines changed

.travis.yml

+1-1
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,7 @@ env:
1515
# Set default python version to avoid repetition later
1616
- PYTHON_VERSION=3.6
1717
- MAIN_CMD="pytest"
18-
- SETUP_CMD="pmda --pep8 -v --cov pmda"
18+
- SETUP_CMD="pmda --pep8 --cov pmda"
1919
# mdanalysis develop from source (see below), which needs
2020
# minimal CONDA_MDANALYSIS_DEPENDENCIES
2121
#- CONDA_DEPENDENCIES="mdanalysis mdanalysistests dask joblib pytest-pep8 mock codecov cython hypothesis sphinx"

CHANGELOG

+5-1
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@ The rules for this file:
1313
* release numbers follow "Semantic Versioning" http://semver.org
1414

1515
------------------------------------------------------------------------------
16-
xx/xx/18 VOD555, richardjgowers, iparask
16+
xx/xx/18 VOD555, richardjgowers, iparask, orbeckst
1717

1818
* 0.2.0
1919

@@ -23,6 +23,10 @@ Enhancements
2323
* add readonly_attributes context manager to ParallelAnalysisBase
2424
* add parallel implementation of Leaflet Finder (Issue #47)
2525

26+
Fixes
27+
* always distribute frames over blocks so that no empty blocks are
28+
created ("balanced blocks", Issue #71)
29+
2630

2731
06/07/18 orbeckst
2832

conftest.py

+1
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,7 @@
1111
from dask import distributed, multiprocessing
1212
import pytest
1313

14+
1415
@pytest.fixture(scope="session", params=(1, 2))
1516
def client(tmpdir_factory, request):
1617
with tmpdir_factory.mktemp("dask_cluster").as_cwd():

docs/api.rst

+1
Original file line number | Diff line number | Diff line change
@@ -26,6 +26,7 @@ a single frame. If you need more flexibility you can use the
2626

2727
api/parallel
2828
api/custom
29+
api/util
2930

3031
.. _pre-defined-analysis-tasks:
3132

docs/api/util.rst

+2
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,2 @@
1+
.. automodule:: pmda.util
2+
:members:

pmda/parallel.py

+24-8
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,8 @@
1616
"""
1717
from __future__ import absolute_import, division
1818
from contextlib import contextmanager
19+
import warnings
20+
1921
from six.moves import range
2022

2123
import MDAnalysis as mda
@@ -24,7 +26,7 @@
2426
from joblib import cpu_count
2527
import numpy as np
2628

27-
from .util import timeit
29+
from .util import timeit, make_balanced_slices
2830

2931

3032
class Timing(object):
@@ -313,26 +315,36 @@ def run(self,
313315
start, stop, step = self._trajectory.check_slice_indices(
314316
start, stop, step)
315317
n_frames = len(range(start, stop, step))
316-
bsize = int(np.ceil(n_frames / float(n_blocks)))
318+
319+
if n_frames == 0:
320+
warnings.warn("run() analyses no frames: check start/stop/step")
321+
if n_frames < n_blocks:
322+
warnings.warn("run() uses more blocks than frames: "
323+
"decrease n_blocks")
324+
325+
slices = make_balanced_slices(n_frames, n_blocks,
326+
start=start, stop=stop, step=step)
317327

318328
with timeit() as total:
319329
with timeit() as prepare:
320330
self._prepare()
321331
time_prepare = prepare.elapsed
322332
blocks = []
323333
with self.readonly_attributes():
324-
for b in range(n_blocks):
334+
for bslice in slices:
325335
task = delayed(
326336
self._dask_helper, pure=False)(
327-
b * bsize * step + start,
328-
min(stop, (b + 1) * bsize * step + start),
329-
step,
337+
bslice,
330338
self._indices,
331339
self._top,
332340
self._traj, )
333341
blocks.append(task)
334342
blocks = delayed(blocks)
335343
res = blocks.compute(**scheduler_kwargs)
344+
# hack to handle n_frames == 0 in this framework
345+
if len(res) == 0:
346+
# everything else wants list of block tuples
347+
res = [([], [], [], 0)]
336348
self._results = np.asarray([el[0] for el in res])
337349
with timeit() as conclude:
338350
self._conclude()
@@ -343,7 +355,7 @@ def run(self,
343355
np.array([el[3] for el in res]), time_prepare, conclude.elapsed)
344356
return self
345357

346-
def _dask_helper(self, start, stop, step, indices, top, traj):
358+
def _dask_helper(self, bslice, indices, top, traj):
347359
"""helper function to actually setup dask graph"""
348360
with timeit() as b_universe:
349361
u = mda.Universe(top, traj)
@@ -352,8 +364,12 @@ def _dask_helper(self, start, stop, step, indices, top, traj):
352364
res = []
353365
times_io = []
354366
times_compute = []
355-
for i in range(start, stop, step):
367+
# NOTE: bslice.stop cannot be None! Always make sure
368+
# that it comes from _trajectory.check_slice_indices()!
369+
for i in range(bslice.start, bslice.stop, bslice.step):
356370
with timeit() as b_io:
371+
# explicit instead of 'for ts in u.trajectory[bslice]'
372+
# so that we can get accurate timing.
357373
ts = u.trajectory[i]
358374
with timeit() as b_compute:
359375
res = self._reduce(res, self._single_frame(ts, agroups))

pmda/test/test_parallel.py

+21
Original file line number | Diff line number | Diff line change
@@ -78,6 +78,27 @@ def test_sub_frames(analysis, n_jobs):
7878
np.testing.assert_almost_equal(analysis.res, [10, 20, 30, 40])
7979

8080

81+
@pytest.mark.parametrize('n_jobs', (1, 2, 3))
82+
def test_no_frames(analysis, n_jobs):
83+
u = mda.Universe(analysis._top, analysis._traj)
84+
n_frames = u.trajectory.n_frames
85+
with pytest.warns(UserWarning):
86+
analysis.run(start=n_frames, stop=n_frames+1, n_jobs=n_jobs)
87+
assert len(analysis.res) == 0
88+
np.testing.assert_equal(analysis.res, [])
89+
np.testing.assert_equal(analysis.timing.compute, [])
90+
np.testing.assert_equal(analysis.timing.io, [])
91+
assert analysis.timing.universe == 0
92+
93+
94+
def test_nframes_less_nblocks_warning(analysis):
95+
u = mda.Universe(analysis._top, analysis._traj)
96+
n_frames = u.trajectory.n_frames
97+
with pytest.warns(UserWarning):
98+
analysis.run(stop=2, n_blocks=4, n_jobs=2)
99+
assert len(analysis.res) == 2
100+
101+
81102
def test_scheduler(analysis, scheduler):
82103
analysis.run(scheduler=scheduler)
83104

pmda/test/test_util.py

+113-2
Original file line number | Diff line number | Diff line change
@@ -8,14 +8,125 @@
88
# Released under the GNU Public Licence, v2 or any higher version
99
from __future__ import absolute_import
1010

11+
from six.moves import range
12+
13+
import pytest
14+
1115
import time
12-
from numpy.testing import assert_almost_equal
16+
import numpy as np
17+
from numpy.testing import assert_almost_equal, assert_equal
1318

14-
from pmda.util import timeit
19+
from pmda.util import timeit, make_balanced_slices
1520

1621

1722
def test_timeit():
1823
with timeit() as timer:
1924
time.sleep(1)
2025

2126
assert_almost_equal(timer.elapsed, 1, decimal=2)
27+
28+
29+
@pytest.mark.parametrize("start", (None, 0, 1, 10))
30+
@pytest.mark.parametrize("n_frames,n_blocks,result", [
31+
(5, 1, [slice(0, None, 1)]),
32+
(5, 2, [slice(0, 3, 1), slice(3, None, 1)]),
33+
(5, 3, [slice(0, 2, 1), slice(2, 4, 1), slice(4, None, 1)]),
34+
(5, 4, [slice(0, 2, 1), slice(2, 3, 1), slice(3, 4, 1),
35+
slice(4, None, 1)]),
36+
(5, 5, [slice(0, 1, 1), slice(1, 2, 1), slice(2, 3, 1), slice(3, 4, 1),
37+
slice(4, None, 1)]),
38+
(10, 2, [slice(0, 5, 1), slice(5, None, 1)]),
39+
(10, 3, [slice(0, 4, 1), slice(4, 7, 1), slice(7, None, 1)]),
40+
(10, 7, [slice(0, 2, 1), slice(2, 4, 1), slice(4, 6, 1), slice(6, 7, 1),
41+
slice(7, 8, 1), slice(8, 9, 1), slice(9, None, 1)]),
42+
])
43+
def test_make_balanced_slices_step1(n_frames, n_blocks, start, result, step=1):
44+
assert step in (None, 1), "This test can only test step None or 1"
45+
46+
_start = start if start is not None else 0
47+
_result = [slice(sl.start + _start,
48+
sl.stop + _start if sl.stop is not None else None,
49+
sl.step) for sl in result]
50+
51+
slices = make_balanced_slices(n_frames, n_blocks,
52+
start=start, step=step)
53+
assert_equal(slices, _result)
54+
55+
56+
def _test_make_balanced_slices(n_blocks, start, stop, step, scale):
57+
_start = start if start is not None else 0
58+
59+
traj_frames = range(scale * stop)
60+
frames = traj_frames[start:stop:step]
61+
n_frames = len(frames)
62+
63+
slices = make_balanced_slices(n_frames, n_blocks,
64+
start=start, stop=stop, step=step)
65+
66+
assert len(slices) == n_blocks
67+
68+
# assemble frames again by blocks and show that we have all
69+
# the original frames; get the sizes of the blocks
70+
71+
block_frames = []
72+
block_sizes = []
73+
for bslice in slices:
74+
bframes = traj_frames[bslice]
75+
block_frames.extend(list(bframes))
76+
block_sizes.append(len(bframes))
77+
block_sizes = np.array(block_sizes)
78+
79+
# check that we have all the frames accounted for
80+
assert_equal(np.asarray(block_frames), np.asarray(frames))
81+
82+
# check that the distribution is balanced
83+
if n_frames >= n_blocks:
84+
assert np.all(block_sizes > 0)
85+
minsize = n_frames // n_blocks
86+
assert not np.setdiff1d(block_sizes, [minsize, minsize+1]), \
87+
"For n_blocks <= n_frames, block sizes are not balanced"
88+
else:
89+
# pathological case; we will have blocks with length 0
90+
# and n_frames blocks with 1 frame each
91+
zero_blocks = block_sizes == 0
92+
assert np.sum(zero_blocks) == n_blocks - n_frames
93+
assert np.sum(~zero_blocks) == n_frames
94+
assert not np.setdiff1d(block_sizes[~zero_blocks], [1]), \
95+
"For n_blocks>n_frames, some blocks contain != 1 frame"
96+
97+
98+
@pytest.mark.parametrize('n_blocks', [1, 2, 3, 4, 5, 7, 10, 11])
99+
@pytest.mark.parametrize('start', [0, 1, 10])
100+
@pytest.mark.parametrize('stop', [11, 100, 256])
101+
@pytest.mark.parametrize('step', [None, 1, 2, 3, 5, 7])
102+
@pytest.mark.parametrize('scale', [1, 2])
103+
def test_make_balanced_slices(n_blocks, start, stop, step, scale):
104+
return _test_make_balanced_slices(n_blocks, start, stop, step, scale)
105+
106+
107+
def test_make_balanced_slices_step_gt_stop(n_blocks=2, start=None,
108+
stop=5, step=6, scale=1):
109+
return _test_make_balanced_slices(n_blocks, start, stop, step, scale)
110+
111+
112+
@pytest.mark.parametrize('n_blocks', [1, 2])
113+
@pytest.mark.parametrize('start', [0, 10])
114+
@pytest.mark.parametrize('step', [None, 1, 2])
115+
def test_make_balanced_slices_empty(n_blocks, start, step):
116+
slices = make_balanced_slices(0, n_blocks, start=start, step=step)
117+
assert slices == []
118+
119+
120+
@pytest.mark.parametrize("n_frames,n_blocks,start,stop,step",
121+
[(-1, 5, None, None, None), (5, 0, None, None, None),
122+
(5, -1, None, None, None), (0, 0, None, None, None),
123+
(-1, -1, None, None, None),
124+
(5, 4, -1, None, None), (0, 5, -1, None, None),
125+
(5, 0, -1, None, None),
126+
(5, 4, None, -1, None), (5, 4, 3, 2, None),
127+
(5, 4, None, None, -1), (5, 4, None, None, 0)])
128+
def test_make_balanced_slices_ValueError(n_frames, n_blocks,
129+
start, stop, step):
130+
with pytest.raises(ValueError):
131+
make_balanced_slices(n_frames, n_blocks,
132+
start=start, stop=stop, step=step)

0 commit comments

Comments
 (0)