Skip to content

Commit ac145ec

Browse files
talagayevyuxuanzhuangorbeckst
authored
'MDAnalysis.analysis.density' parallelization (#4729)
* Fixes #4677 * parallelize DensityAnalysis * Changes made in this Pull Request: - added backends and aggregators to DensityAnalysis in analysis.density - added client_DensityAnalysis in conftest.py - added client_DensityAnalysis to the tests in test_density.py * Update CHANGELOG --------- Co-authored-by: Yuxuan Zhuang <yuxuan.zhuang@dbb.su.se> Co-authored-by: Oliver Beckstein <orbeckst@gmail.com>
1 parent a3672f2 commit ac145ec

File tree

4 files changed

+145
-62
lines changed

4 files changed

+145
-62
lines changed

package/CHANGELOG

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ Fixes
2525
the function to prevent shared state. (Issue #4655)
2626

2727
Enhancements
28+
* Enables parallelization for analysis.density.DensityAnalysis (Issue #4677, PR #4729)
2829
* Enables parallelization for analysis.contacts.Contacts (Issue #4660)
2930
* Enable parallelization for analysis.nucleicacids.NucPairDist (Issue #4670)
3031
* Add check and warning for empty (all zero) coordinates in RDKit converter (PR #4824)

package/MDAnalysis/analysis/density.py

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@
169169
from ..lib import distances
170170
from MDAnalysis.lib.log import ProgressBar
171171

172-
from .base import AnalysisBase
172+
from .base import AnalysisBase, ResultsGroup
173173

174174
import logging
175175

@@ -395,8 +395,16 @@ class DensityAnalysis(AnalysisBase):
395395
:func:`_set_user_grid` is now a method of :class:`DensityAnalysis`.
396396
:class:`Density` results are now stored in a
397397
:class:`MDAnalysis.analysis.base.Results` instance.
398-
398+
.. versionchanged:: 2.9.0
399+
Introduced :meth:`get_supported_backends` allowing
400+
for parallel execution on :mod:`multiprocessing`
401+
and :mod:`dask` backends.
399402
"""
403+
_analysis_algorithm_is_parallelizable = True
404+
405+
@classmethod
406+
def get_supported_backends(cls):
407+
return ('serial', 'multiprocessing', 'dask')
400408

401409
def __init__(self, atomgroup, delta=1.0,
402410
metadata=None, padding=2.0,
@@ -412,7 +420,12 @@ def __init__(self, atomgroup, delta=1.0,
412420
self._ydim = ydim
413421
self._zdim = zdim
414422

415-
def _prepare(self):
423+
# The grid with its dimensions has to be set up in __init__
424+
# so that parallel analysis works correctly: each process
425+
# needs to have a results._grid of the same size and the
426+
# same self._bins and self._arange (so this cannot happen
427+
# in _prepare(), which is executed in parallel on different
428+
# parts of the trajectory).
416429
coord = self._atomgroup.positions
417430
if (self._gridcenter is not None or
418431
any([self._xdim, self._ydim, self._zdim])):
@@ -465,7 +478,7 @@ def _prepare(self):
465478
grid, edges = np.histogramdd(np.zeros((1, 3)), bins=bins,
466479
range=arange, density=False)
467480
grid *= 0.0
468-
self._grid = grid
481+
self.results._grid = grid
469482
self._edges = edges
470483
self._arange = arange
471484
self._bins = bins
@@ -474,21 +487,22 @@ def _single_frame(self):
474487
h, _ = np.histogramdd(self._atomgroup.positions,
475488
bins=self._bins, range=self._arange,
476489
density=False)
477-
# reduce (proposed change #2542 to match the parallel version in pmda.density)
478-
# return self._reduce(self._grid, h)
479-
#
480-
# serial code can simply do
481-
self._grid += h
490+
self.results._grid += h
482491

483492
def _conclude(self):
484493
# average:
485-
self._grid /= float(self.n_frames)
486-
density = Density(grid=self._grid, edges=self._edges,
494+
self.results._grid /= float(self.n_frames)
495+
density = Density(grid=self.results._grid, edges=self._edges,
487496
units={'length': "Angstrom"},
488497
parameters={'isDensity': False})
489498
density.make_density()
490499
self.results.density = density
491500

501+
def _get_aggregator(self):
502+
return ResultsGroup(lookup={
503+
'_grid': ResultsGroup.ndarray_sum}
504+
)
505+
492506
@property
493507
def density(self):
494508
wmsg = ("The `density` attribute was deprecated in MDAnalysis 2.0.0 "

testsuite/MDAnalysisTests/analysis/conftest.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
)
1717
from MDAnalysis.analysis.nucleicacids import NucPairDist
1818
from MDAnalysis.analysis.contacts import Contacts
19+
from MDAnalysis.analysis.density import DensityAnalysis
1920
from MDAnalysis.lib.util import is_installed
2021

2122

@@ -157,3 +158,10 @@ def client_NucPairDist(request):
157158
@pytest.fixture(scope="module", params=params_for_cls(Contacts))
158159
def client_Contacts(request):
159160
return request.param
161+
162+
163+
# MDAnalysis.analysis.density
164+
165+
@pytest.fixture(scope='module', params=params_for_cls(DensityAnalysis))
166+
def client_DensityAnalysis(request):
167+
return request.param

testsuite/MDAnalysisTests/analysis/test_density.py

Lines changed: 111 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -230,12 +230,20 @@ def universe(self):
230230

231231

232232
class TestDensityAnalysis(DensityParameters):
233-
def check_DensityAnalysis(self, ag, ref_meandensity,
234-
tmpdir, runargs=None, **kwargs):
233+
def check_DensityAnalysis(
234+
self,
235+
ag,
236+
ref_meandensity,
237+
tmpdir,
238+
client_DensityAnalysis,
239+
runargs=None,
240+
**kwargs
241+
):
235242
runargs = runargs if runargs else {}
236243
with tmpdir.as_cwd():
237-
D = density.DensityAnalysis(
238-
ag, delta=self.delta, **kwargs).run(**runargs)
244+
D = density.DensityAnalysis(ag, delta=self.delta, **kwargs).run(
245+
**runargs, **client_DensityAnalysis
246+
)
239247
assert_almost_equal(D.results.density.grid.mean(), ref_meandensity,
240248
err_msg="mean density does not match")
241249
D.results.density.export(self.outfile)
@@ -247,125 +255,176 @@ def check_DensityAnalysis(self, ag, ref_meandensity,
247255
)
248256

249257
@pytest.mark.parametrize("mode", ("static", "dynamic"))
250-
def test_run(self, mode, universe, tmpdir):
258+
def test_run(self, mode, universe, tmpdir, client_DensityAnalysis):
251259
updating = (mode == "dynamic")
252260
self.check_DensityAnalysis(
253261
universe.select_atoms(self.selections[mode], updating=updating),
254262
self.references[mode]['meandensity'],
255-
tmpdir=tmpdir
263+
tmpdir=tmpdir,
264+
client_DensityAnalysis=client_DensityAnalysis,
256265
)
257266

258-
def test_sliced(self, universe, tmpdir):
267+
def test_sliced(self, universe, tmpdir, client_DensityAnalysis):
259268
self.check_DensityAnalysis(
260269
universe.select_atoms(self.selections['static']),
261270
self.references['static_sliced']['meandensity'],
262271
tmpdir=tmpdir,
272+
client_DensityAnalysis=client_DensityAnalysis,
263273
runargs=dict(start=1, stop=-1, step=2),
264274
)
265275

266-
def test_userdefn_eqbox(self, universe, tmpdir):
276+
def test_userdefn_eqbox(self, universe, tmpdir, client_DensityAnalysis):
267277
with warnings.catch_warnings():
268278
# Do not need to see UserWarning that box is too small
269279
warnings.simplefilter("ignore")
270280
self.check_DensityAnalysis(
271281
universe.select_atoms(self.selections['static']),
272282
self.references['static_defined']['meandensity'],
273283
tmpdir=tmpdir,
284+
client_DensityAnalysis=client_DensityAnalysis,
274285
gridcenter=self.gridcenters['static_defined'],
275286
xdim=10.0,
276287
ydim=10.0,
277288
zdim=10.0,
278289
)
279290

280-
def test_userdefn_neqbox(self, universe, tmpdir):
291+
def test_userdefn_neqbox(self, universe, tmpdir, client_DensityAnalysis):
281292
self.check_DensityAnalysis(
282293
universe.select_atoms(self.selections['static']),
283294
self.references['static_defined_unequal']['meandensity'],
284295
tmpdir=tmpdir,
296+
client_DensityAnalysis=client_DensityAnalysis,
285297
gridcenter=self.gridcenters['static_defined'],
286298
xdim=10.0,
287299
ydim=15.0,
288300
zdim=20.0,
289301
)
290302

291-
def test_userdefn_boxshape(self, universe):
303+
def test_userdefn_boxshape(self, universe, client_DensityAnalysis):
292304
D = density.DensityAnalysis(
293-
universe.select_atoms(self.selections['static']),
294-
delta=1.0, xdim=8.0, ydim=12.0, zdim=17.0,
295-
gridcenter=self.gridcenters['static_defined']).run()
305+
universe.select_atoms(self.selections["static"]),
306+
delta=1.0,
307+
xdim=8.0,
308+
ydim=12.0,
309+
zdim=17.0,
310+
gridcenter=self.gridcenters["static_defined"],
311+
).run(**client_DensityAnalysis)
296312
assert D.results.density.grid.shape == (8, 12, 17)
297313

298-
def test_warn_userdefn_padding(self, universe):
314+
def test_warn_userdefn_padding(self, universe, client_DensityAnalysis):
299315
regex = (r"Box padding \(currently set at 1\.0\) is not used "
300316
r"in user defined grids\.")
301317
with pytest.warns(UserWarning, match=regex):
302318
D = density.DensityAnalysis(
303-
universe.select_atoms(self.selections['static']),
304-
delta=self.delta, xdim=100.0, ydim=100.0, zdim=100.0, padding=1.0,
305-
gridcenter=self.gridcenters['static_defined']).run(step=5)
306-
307-
def test_warn_userdefn_smallgrid(self, universe):
319+
universe.select_atoms(self.selections["static"]),
320+
delta=self.delta,
321+
xdim=100.0,
322+
ydim=100.0,
323+
zdim=100.0,
324+
padding=1.0,
325+
gridcenter=self.gridcenters["static_defined"],
326+
).run(step=5, **client_DensityAnalysis)
327+
328+
def test_warn_userdefn_smallgrid(self, universe, client_DensityAnalysis):
308329
regex = ("Atom selection does not fit grid --- "
309330
"you may want to define a larger box")
310331
with pytest.warns(UserWarning, match=regex):
311332
D = density.DensityAnalysis(
312-
universe.select_atoms(self.selections['static']),
313-
delta=self.delta, xdim=1.0, ydim=2.0, zdim=2.0, padding=0.0,
314-
gridcenter=self.gridcenters['static_defined']).run(step=5)
315-
316-
def test_ValueError_userdefn_gridcenter_shape(self, universe):
333+
universe.select_atoms(self.selections["static"]),
334+
delta=self.delta,
335+
xdim=1.0,
336+
ydim=2.0,
337+
zdim=2.0,
338+
padding=0.0,
339+
gridcenter=self.gridcenters["static_defined"],
340+
).run(step=5, **client_DensityAnalysis)
341+
342+
def test_ValueError_userdefn_gridcenter_shape(
343+
self, universe, client_DensityAnalysis
344+
):
317345
# Test len(gridcenter) != 3
318346
with pytest.raises(ValueError, match="Gridcenter must be a 3D coordinate"):
319347
D = density.DensityAnalysis(
320-
universe.select_atoms(self.selections['static']),
321-
delta=self.delta, xdim=10.0, ydim=10.0, zdim=10.0,
322-
gridcenter=self.gridcenters['error1']).run(step=5)
348+
universe.select_atoms(self.selections["static"]),
349+
delta=self.delta,
350+
xdim=10.0,
351+
ydim=10.0,
352+
zdim=10.0,
353+
gridcenter=self.gridcenters["error1"],
354+
).run(step=5, **client_DensityAnalysis)
323355

324-
def test_ValueError_userdefn_gridcenter_type(self, universe):
356+
def test_ValueError_userdefn_gridcenter_type(
357+
self, universe, client_DensityAnalysis
358+
):
325359
# Test gridcenter includes non-numeric strings
326360
with pytest.raises(ValueError, match="Gridcenter must be a 3D coordinate"):
327361
D = density.DensityAnalysis(
328-
universe.select_atoms(self.selections['static']),
329-
delta=self.delta, xdim=10.0, ydim=10.0, zdim=10.0,
330-
gridcenter=self.gridcenters['error2']).run(step=5)
362+
universe.select_atoms(self.selections["static"]),
363+
delta=self.delta,
364+
xdim=10.0,
365+
ydim=10.0,
366+
zdim=10.0,
367+
gridcenter=self.gridcenters["error2"],
368+
).run(step=5, **client_DensityAnalysis)
331369

332-
def test_ValueError_userdefn_gridcenter_missing(self, universe):
370+
def test_ValueError_userdefn_gridcenter_missing(
371+
self, universe, client_DensityAnalysis
372+
):
333373
# Test no gridcenter provided when grid dimensions are given
334374
regex = ("Gridcenter or grid dimensions are not provided")
335375
with pytest.raises(ValueError, match=regex):
336376
D = density.DensityAnalysis(
337-
universe.select_atoms(self.selections['static']),
338-
delta=self.delta, xdim=10.0, ydim=10.0, zdim=10.0).run(step=5)
377+
universe.select_atoms(self.selections["static"]),
378+
delta=self.delta,
379+
xdim=10.0,
380+
ydim=10.0,
381+
zdim=10.0,
382+
).run(step=5, **client_DensityAnalysis)
339383

340-
def test_ValueError_userdefn_xdim_type(self, universe):
384+
def test_ValueError_userdefn_xdim_type(self, universe,
385+
client_DensityAnalysis):
341386
# Test xdim != int or float
342387
with pytest.raises(ValueError, match="xdim, ydim, and zdim must be numbers"):
343388
D = density.DensityAnalysis(
344-
universe.select_atoms(self.selections['static']),
345-
delta=self.delta, xdim="MDAnalysis", ydim=10.0, zdim=10.0,
346-
gridcenter=self.gridcenters['static_defined']).run(step=5)
389+
universe.select_atoms(self.selections["static"]),
390+
delta=self.delta,
391+
xdim="MDAnalysis",
392+
ydim=10.0,
393+
zdim=10.0,
394+
gridcenter=self.gridcenters["static_defined"],
395+
).run(step=5, **client_DensityAnalysis)
347396

348-
def test_ValueError_userdefn_xdim_nanvalue(self, universe):
397+
def test_ValueError_userdefn_xdim_nanvalue(self, universe,
398+
client_DensityAnalysis):
349399
# Test xdim set to NaN value
350400
regex = ("Gridcenter or grid dimensions have NaN element")
351401
with pytest.raises(ValueError, match=regex):
352402
D = density.DensityAnalysis(
353-
universe.select_atoms(self.selections['static']),
354-
delta=self.delta, xdim=np.nan, ydim=10.0, zdim=10.0,
355-
gridcenter=self.gridcenters['static_defined']).run(step=5)
403+
universe.select_atoms(self.selections["static"]),
404+
delta=self.delta,
405+
xdim=np.nan,
406+
ydim=10.0,
407+
zdim=10.0,
408+
gridcenter=self.gridcenters["static_defined"],
409+
).run(step=5, **client_DensityAnalysis)
356410

357-
def test_warn_noatomgroup(self, universe):
411+
def test_warn_noatomgroup(self, universe, client_DensityAnalysis):
358412
regex = ("No atoms in AtomGroup at input time frame. "
359413
"This may be intended; please ensure that "
360414
"your grid selection covers the atomic "
361415
"positions you wish to capture.")
362416
with pytest.warns(UserWarning, match=regex):
363417
D = density.DensityAnalysis(
364-
universe.select_atoms(self.selections['none']),
365-
delta=self.delta, xdim=1.0, ydim=2.0, zdim=2.0, padding=0.0,
366-
gridcenter=self.gridcenters['static_defined']).run(step=5)
367-
368-
def test_ValueError_noatomgroup(self, universe):
418+
universe.select_atoms(self.selections["none"]),
419+
delta=self.delta,
420+
xdim=1.0,
421+
ydim=2.0,
422+
zdim=2.0,
423+
padding=0.0,
424+
gridcenter=self.gridcenters["static_defined"],
425+
).run(step=5, **client_DensityAnalysis)
426+
427+
def test_ValueError_noatomgroup(self, universe, client_DensityAnalysis):
369428
with pytest.raises(ValueError, match="No atoms in AtomGroup at input"
370429
" time frame. Grid for density"
371430
" could not be automatically"
@@ -374,12 +433,13 @@ def test_ValueError_noatomgroup(self, universe):
374433
" defined grid will "
375434
"need to be provided instead."):
376435
D = density.DensityAnalysis(
377-
universe.select_atoms(self.selections['none'])).run(step=5)
436+
universe.select_atoms(self.selections["none"])
437+
).run(step=5, **client_DensityAnalysis)
378438

379-
def test_warn_results_deprecated(self, universe):
439+
def test_warn_results_deprecated(self, universe, client_DensityAnalysis):
380440
D = density.DensityAnalysis(
381441
universe.select_atoms(self.selections['static']))
382-
D.run(stop=1)
442+
D.run(stop=1, **client_DensityAnalysis)
383443
wmsg = "The `density` attribute was deprecated in MDAnalysis 2.0.0"
384444
with pytest.warns(DeprecationWarning, match=wmsg):
385445
assert_equal(D.density.grid, D.results.density.grid)

0 commit comments

Comments
 (0)