Skip to content

Commit d64c7d6

Browse files
authored
Merge pull request #102 from jrmaddison/jrmaddison/MPI
Use mpi4py.MPI instead of dolfin.MPI
2 parents 4def7d0 + f8b1348 commit d64c7d6

14 files changed

+43
-36
lines changed

aux/Uobs_from_momsolve.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,10 @@
1515
# You should have received a copy of the GNU Lesser General Public License
1616
# along with tlm_adjoint. If not, see <https://www.gnu.org/licenses/>.
1717

18-
from fenics_ice.backend import Function, HDF5File, MPI, Mesh, Point, \
18+
from fenics_ice.backend import Function, HDF5File, Mesh, Point, \
1919
RectangleMesh, VectorFunctionSpace, project
2020

21+
import mpi4py.MPI as MPI # noqa: N817
2122
import numpy as np
2223
from pathlib import Path
2324
import h5py
@@ -51,7 +52,7 @@ def main(dd, infile, outfile, noise_sdev, L, seed=0, ls=None):
5152
assert Path(outfile).suffix == ".h5"
5253
assert noise_sdev > 0.0
5354

54-
infile = HDF5File(MPI.comm_world, str(Path(dd)/infile), 'r')
55+
infile = HDF5File(MPI.COMM_WORLD, str(Path(dd)/infile), 'r')
5556

5657
# Get mesh from file
5758
mesh = Mesh()

aux/gen_rect_mesh.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
from fenics_ice.backend import File, Point, RectangleMesh
2424

25-
from mpi4py import MPI
25+
import mpi4py.MPI as MPI # noqa: N817
2626
import argparse
2727

2828
def gen_rect_mesh(nx, ny, xmin, xmax, ymin, ymax, outfile, direction='right'):

aux/plotting/plot_dq_ts.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,9 @@
1515
# You should have received a copy of the GNU Lesser General Public License
1616
# along with tlm_adjoint. If not, see <https://www.gnu.org/licenses/>.
1717

18-
from fenics_ice.backend import Function, FunctionSpace, HDF5File, Mesh, MPI
18+
from fenics_ice.backend import Function, FunctionSpace, HDF5File, Mesh
1919

20+
import mpi4py.MPI as MPI # noqa: N817
2021
import sys
2122
import pickle
2223
import numpy as np
@@ -93,7 +94,7 @@
9394
t = mesh.cells()
9495

9596
hdffile = str(next(outdir.glob("*dQ_ts.h5")))
96-
hdf5data = HDF5File(MPI.comm_world, hdffile, 'r')
97+
hdf5data = HDF5File(MPI.COMM_WORLD, hdffile, 'r')
9798
hdf5data.read(dQ, f'dQ/vector_{n_sens}')
9899

99100

aux/plotting/plot_leading_eigenfuncs.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,9 @@
1515
# You should have received a copy of the GNU Lesser General Public License
1616
# along with tlm_adjoint. If not, see <https://www.gnu.org/licenses/>.
1717

18-
from fenics_ice.backend import Function, FunctionSpace, HDF5File, MPI, Mesh
18+
from fenics_ice.backend import Function, FunctionSpace, HDF5File, Mesh
1919

20+
import mpi4py.MPI as MPI # noqa: N817
2021
import sys
2122
import pickle
2223
import numpy as np
@@ -91,7 +92,7 @@
9192
for j in range(4):
9293
k = j + e_offset
9394
vr_file = str(next((run_dir / 'output').glob("*vr.h5")))
94-
hdf5data = HDF5File(MPI.comm_world, vr_file, 'r')
95+
hdf5data = HDF5File(MPI.COMM_WORLD, vr_file, 'r')
9596
hdf5data.read(eigenfunc, f'v/vector_{k}')
9697

9798
sind = j+1+i*4

fenics_ice/inout.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,11 @@
1919
Module to handle model input & output
2020
"""
2121

22-
from .backend import File, Function, HDF5File, MPI, XDMFFile, \
22+
from .backend import File, Function, HDF5File, XDMFFile, \
2323
configure_checkpointing
2424
from tlm_adjoint.fenics.backend import backend_Function
2525

26+
import mpi4py.MPI as MPI # noqa: N817
2627
import sys
2728
import time
2829
import csv
@@ -48,7 +49,7 @@ class Writer(ABC):
4849
unnamed_re = re.compile("f_[0-9]+")
4950
suffix = None # the subclass specific file extension
5051

51-
def __init__(self, fpath, comm=MPI.comm_world):
52+
def __init__(self, fpath, comm=MPI.COMM_WORLD):
5253
assert comm is not None, "Need an MPI communicator"
5354
self._fpath = Path(fpath)
5455
self.comm = comm
@@ -245,7 +246,7 @@ def write_dqval(dQ_ts, cntrl_names, params):
245246
outdir_f = Path(outdir)/phase_name/phase_suffix
246247
# TODO add this file to diags once Dan makes his pull request
247248
vtkfile = File(str((diagdir_f/h5_filename).with_suffix(".pvd")))
248-
hdf5out = HDF5File(MPI.comm_world, str(outdir_f/h5_filename), 'w')
249+
hdf5out = HDF5File(MPI.COMM_WORLD, str(outdir_f/h5_filename), 'w')
249250
n = 0.0
250251

251252
# Loop dQ sample times ('num_sens')
@@ -304,15 +305,15 @@ def write_variable(var, params, name=None, outdir=None, phase_name='', phase_suf
304305
xml_fname = str(outfname.with_suffix(".xml"))
305306
File(xml_fname) << outvar
306307
if 'h5' in output_var_format:
307-
hdf5out = HDF5File(MPI.comm_world, str(outfname.with_suffix(".h5")), 'w')
308+
hdf5out = HDF5File(MPI.COMM_WORLD, str(outfname.with_suffix(".h5")), 'w')
308309
hdf5out.write(outvar, name)
309310
hdf5out.close()
310311
if 'all' in output_var_format:
311312
vtk_fname = str(outfname.with_suffix(".pvd"))
312313
xml_fname = str(outfname.with_suffix(".xml"))
313314
File(vtk_fname) << outvar
314315
File(xml_fname) << outvar
315-
hdf5out = HDF5File(MPI.comm_world, str(outfname.with_suffix(".h5")), 'w')
316+
hdf5out = HDF5File(MPI.COMM_WORLD, str(outfname.with_suffix(".h5")), 'w')
316317
hdf5out.write(outvar, name)
317318
hdf5out.close()
318319

fenics_ice/mesh.py

+9-8
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,12 @@
1919
Module to handle all things mesh to avoid code repetition in run scripts.
2020
"""
2121

22-
from .backend import FunctionSpace, MPI, Mesh, MeshFunction, \
23-
MeshValueCollection, VectorFunctionSpace, XDMFFile, parameters
22+
from .backend import FunctionSpace, Mesh, MeshFunction, MeshValueCollection, \
23+
VectorFunctionSpace, XDMFFile, parameters
2424

2525
from . import model
2626

27+
import mpi4py.MPI as MPI # noqa: N817
2728
import os
2829
import numpy as np
2930
from pathlib import Path
@@ -50,7 +51,7 @@ def get_mesh(params):
5051

5152
elif filetype == '.xdmf':
5253
mesh_in = Mesh()
53-
mesh_xdmf = XDMFFile(MPI.comm_world, str(meshfile))
54+
mesh_xdmf = XDMFFile(MPI.COMM_WORLD, str(meshfile))
5455
mesh_xdmf.read(mesh_in)
5556

5657
else:
@@ -68,10 +69,10 @@ def get_mesh_length(mesh):
6869

6970
comm = mesh.mpi_comm()
7071

71-
xmin = MPI.min(comm, xmin)
72-
xmax = MPI.max(comm, xmax)
73-
ymin = MPI.min(comm, ymin)
74-
ymax = MPI.max(comm, ymax)
72+
xmin = comm.allreduce(xmin, op=MPI.MIN)
73+
xmax = comm.allreduce(xmax, op=MPI.MAX)
74+
ymin = comm.allreduce(ymin, op=MPI.MIN)
75+
ymax = comm.allreduce(ymax, op=MPI.MAX)
7576

7677
L1 = xmax - xmin
7778
L2 = ymax - ymin
@@ -128,7 +129,7 @@ def get_ff_from_file(params, model, fill_val=0):
128129

129130
# Read the MeshValueCollection (sparse)
130131
ff_mvc = MeshValueCollection("size_t", model.mesh, dim=dim-1)
131-
ff_xdmf = XDMFFile(MPI.comm_world, str(ff_file))
132+
ff_xdmf = XDMFFile(MPI.COMM_WORLD, str(ff_file))
132133
ff_xdmf.read(ff_mvc)
133134

134135
# Create FacetFunction filled w/ default

fenics_ice/model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def __init__(self, mesh_in, input_data, param_in, init_fields=True,
4646
self.params = param_in
4747
self.input_data = input_data
4848
self.solvers = []
49-
self.parallel = MPI.size(mesh_in.mpi_comm()) > 1
49+
self.parallel = mesh_in.mpi_comm().size > 1
5050

5151
# Generate Domain and Function Spaces
5252
self.mesh = mesh_in

fenics_ice/solver.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
line_search_rank0_scipy_scalar_search_wolfe1 as line_search_rank0
2424

2525
import logging
26-
import mpi4py.MPI as MPI
26+
import mpi4py.MPI as MPI # noqa: N817
2727
import numpy as np
2828
from pathlib import Path
2929
import time

runs/run_errorprop.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,13 @@
1515
# You should have received a copy of the GNU Lesser General Public License
1616
# along with tlm_adjoint. If not, see <https://www.gnu.org/licenses/>.
1717

18-
from fenics_ice.backend import Function, HDF5File, MPI
18+
from fenics_ice.backend import Function, HDF5File
1919

2020
import os
2121
os.environ["OMP_NUM_THREADS"] = "1"
2222
os.environ["OPENBLAS_NUM_THREADS"] = "1"
2323

24+
import mpi4py.MPI as MPI # noqa: N817
2425
from pathlib import Path
2526
import pickle
2627
import numpy as np
@@ -102,7 +103,7 @@ def run_errorprop(config_file):
102103
# and eigenvectors from .h5 file
103104
eps = params.constants.float_eps
104105
W = []
105-
with HDF5File(MPI.comm_world, str(outdir_e/vecfile), 'r') as hdf5data:
106+
with HDF5File(MPI.COMM_WORLD, str(outdir_e/vecfile), 'r') as hdf5data:
106107
for i in range(nlam):
107108
w = Function(space)
108109
hdf5data.read(w, f'v/vector_{i}')
@@ -120,7 +121,7 @@ def run_errorprop(config_file):
120121

121122
# File containing dQoi_dCntrl (i.e. Jacobian of parameter to observable (Qoi))
122123
outdir_qoi = Path(outdir)/phase_time/phase_suffix_qoi
123-
hdf5data = HDF5File(MPI.comm_world, str(outdir_qoi/dqoi_h5file), 'r')
124+
hdf5data = HDF5File(MPI.COMM_WORLD, str(outdir_qoi/dqoi_h5file), 'r')
124125

125126
dQ_cntrl = Function(space, space_type="conjugate_dual")
126127

@@ -181,7 +182,7 @@ def run_errorprop(config_file):
181182
diag_dir = Path(params.io.diagnostics_dir)/phase_err/phase_suffix_err
182183
outdir_err = Path(params.io.output_dir)/phase_err/phase_suffix_err
183184

184-
# if(MPI.comm_world.rank == 0):
185+
# if(MPI.COMM_WORLD.rank == 0):
185186
plt.semilogy(sigma_steps, sigma_conv)
186187
plt.title("Convergence of sigmaQoI")
187188
plt.ylabel("sigma QoI")

runs/run_invsigma.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,13 @@
1616
# along with tlm_adjoint. If not, see <https://www.gnu.org/licenses/>.
1717

1818
from fenics_ice.backend import FiniteElement, Function, FunctionSpace, \
19-
HDF5File, MPI, TestFunction, assemble, assign, inner, dx
19+
HDF5File, TestFunction, assemble, assign, inner, dx
2020

2121
import os
2222
os.environ["OMP_NUM_THREADS"] = "1"
2323
os.environ["OPENBLAS_NUM_THREADS"] = "1"
2424

25+
import mpi4py.MPI as MPI # noqa: N817
2526
from pathlib import Path
2627
import pickle
2728
import numpy as np
@@ -43,8 +44,8 @@ def patch_fun(mesh_in, params):
4344
import random
4445
from scipy.spatial import KDTree
4546

46-
comm = MPI.comm_world
47-
rank = MPI.rank(comm)
47+
comm = MPI.COMM_WORLD
48+
rank = comm.rank
4849
root = rank == 0
4950

5051
# Test DG function
@@ -114,7 +115,7 @@ def patch_fun(mesh_in, params):
114115
def run_invsigma(config_file):
115116
"""Compute control sigma values from eigendecomposition"""
116117

117-
comm = MPI.comm_world
118+
comm = MPI.COMM_WORLD
118119

119120
# Read run config file
120121
params = ConfigParser(config_file)

runs/run_sample.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,13 @@
1515
# You should have received a copy of the GNU Lesser General Public License
1616
# along with tlm_adjoint. If not, see <https://www.gnu.org/licenses/>.
1717

18-
from fenics_ice.backend import Function, HDF5File, MPI, project
18+
from fenics_ice.backend import Function, HDF5File, project
1919

2020
import os
2121
os.environ["OMP_NUM_THREADS"] = "1"
2222
os.environ["OPENBLAS_NUM_THREADS"] = "1"
2323

24+
import mpi4py.MPI as MPI # noqa: N817
2425
import sys
2526
import numpy as np
2627
import pickle
@@ -120,7 +121,7 @@ def run_sample(config_file):
120121
lam = lam[:max_lam]
121122

122123
y = Function(space)
123-
with HDF5File(MPI.comm_world, str(outdir_e/vecfile), 'r') as hdf5data:
124+
with HDF5File(MPI.COMM_WORLD, str(outdir_e/vecfile), 'r') as hdf5data:
124125
for i in range(len(lam)):
125126
w = Function(space)
126127
hdf5data.read(w, f'v/vector_{i}')

tests/conftest.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ def pytest_configure(config):
129129
pytest.parallel = MPI.COMM_WORLD.size > 1
130130

131131
comm = MPI.COMM_WORLD
132-
rank = comm.Get_rank()
132+
rank = comm.rank
133133

134134
# Clone the input data once even when test are run in parallel
135135
if rank == 0:

tests/test_config.py

-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717

1818
import fenics_ice.backend as fe
1919

20-
from mpi4py import MPI
2120
import pytest
2221
import numpy as np
2322
import fenics_ice as fice

tests/test_runs.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from runs import run_inv, run_forward, run_eigendec, run_errorprop, run_invsigma
2525
from fenics_ice import config
2626
import shutil
27-
from mpi4py import MPI
27+
2828

2929
def EQReset():
3030
"""Take care of tlm_adjoint EquationManager"""
@@ -396,4 +396,4 @@ def test_run_smith_error_prop(temp_model, monkeypatch):
396396

397397
pytest.check_float_result(Q_sigma,
398398
Q_sigma_expected_at_50_eival,
399-
work_dir, 'expected_Q_sigma')
399+
work_dir, 'expected_Q_sigma')

0 commit comments

Comments
 (0)