Skip to content

Commit

Permalink
MultiFab: Static to Member Math Methods (#301)
Browse files Browse the repository at this point in the history
Replace `static` member functions that with regular member functions
that take store the result/destination in `self`. (Assume `dst` is
`self`.)

- [x] `FabArray<FArrayBox>`
- [x] `MultiFab`

Close #296
  • Loading branch information
ax3l authored Jan 29, 2025
1 parent 8e46613 commit 1f91a4d
Show file tree
Hide file tree
Showing 2 changed files with 184 additions and 105 deletions.
266 changes: 177 additions & 89 deletions src/Base/MultiFab.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <AMReX_FabArrayBase.H>
#include <AMReX_FabFactory.H>
#include <AMReX_MultiFab.H>
#include <AMReX_iMultiFab.H>

#include <memory>
#include <string>
Expand Down Expand Up @@ -215,23 +216,69 @@ void init_MultiFab(py::module &m)
py::arg("comp"), py::arg("ncomp"), py::arg("nghost")
)

.def_static("saxpy",
py::overload_cast< FabArray<FArrayBox> &, Real, FabArray<FArrayBox> const &, int, int, int, IntVect const & >(&FabArray<FArrayBox>::template Saxpy<FArrayBox>),
py::arg("y"), py::arg("a"), py::arg("x"), py::arg("xcomp"), py::arg("ycomp"), py::arg("ncomp"), py::arg("nghost"),
"y += a*x"
)
.def_static("xpay",
py::overload_cast< FabArray<FArrayBox> &, Real, FabArray<FArrayBox> const &, int, int, int, IntVect const & >(&FabArray<FArrayBox>::template Xpay<FArrayBox>),
py::arg("y"), py::arg("a"), py::arg("x"), py::arg("xcomp"), py::arg("ycomp"), py::arg("ncomp"), py::arg("nghost"),
"y = x + a*y"
)
.def_static("lin_comb",
py::overload_cast< FabArray<FArrayBox> &, Real, FabArray<FArrayBox> const &, int, Real, FabArray<FArrayBox> const &, int, int, int, IntVect const & >(&FabArray<FArrayBox>::template LinComb<FArrayBox>),
py::arg("dst"),
.def("saxpy",
[](FabArray<FArrayBox> & dst, Real a, FabArray<FArrayBox> const & x, int x_comp, int comp, int ncomp, IntVect const & nghost)
{
FabArray<FArrayBox>::Saxpy(dst, a, x, x_comp, comp, ncomp, nghost);
},
py::arg("a"), py::arg("x"), py::arg("x_comp"), py::arg("comp"), py::arg("ncomp"), py::arg("nghost"),
"self += a * x\n\n"
"Parameters\n"
"----------\n"
"a : scalar a\n"
"x : FabArray x\n"
"x_comp : starting component of x\n"
"comp : starting component of self\n"
"ncomp : number of components\n"
"nghost : number of ghost cells"
)
.def("xpay",
[](FabArray<FArrayBox> & self, Real a, FabArray<FArrayBox> const & x, int x_comp, int comp, int ncomp, IntVect const & nghost)
{
FabArray<FArrayBox>::Xpay(self, a, x, x_comp, comp, ncomp, nghost);
},
py::arg("a"), py::arg("x"), py::arg("xcomp"), py::arg("comp"), py::arg("ncomp"), py::arg("nghost"),
"self = x + a * self\n\n"
"Parameters\n"
"----------\n"
"a : scalar a\n"
"x : FabArray x\n"
"x_comp : starting component of x\n"
"comp : starting component of self\n"
"ncomp : number of components\n"
"nghost : number of ghost cells"
)
.def("lin_comb",
[](
FabArray<FArrayBox> & dst,
Real a, FabArray<FArrayBox> const & x, int x_comp,
Real b, FabArray<FArrayBox> const & y, int y_comp,
int comp, int ncomp, IntVect const & nghost)
{
FabArray<FArrayBox>::LinComb(dst, a, x, x_comp, b, y, y_comp, comp, ncomp, nghost);
},
py::arg("a"), py::arg("x"), py::arg("xcomp"),
py::arg("b"), py::arg("y"), py::arg("ycomp"),
py::arg("dstcomp"), py::arg("numcomp"), py::arg("nghost"),
"dst = a*x + b*y"
py::arg("comp"), py::arg("numcomp"), py::arg("nghost"),
"self = a * x + b * y\n\n"
"Parameters\n"
"----------\n"
"a : float\n"
" scalar a\n"
"x : FabArray\n"
"xcomp : int\n"
" starting component of x\n"
"b : float\n"
" scalar b\n"
"y : FabArray\n"
"ycomp : int\n"
" starting component of y\n"
"comp : int\n"
" starting component of self\n"
"numcomp : int\n"
" number of components\n"
"nghost : int\n"
" number of ghost cells"
)

.def("sum",
Expand Down Expand Up @@ -730,123 +777,164 @@ void init_MultiFab(py::module &m)
)

/* static (standalone) simple math functions */
.def_static("dot",
py::overload_cast< MultiFab const &, int, MultiFab const &, int, int, int, bool >(&MultiFab::Dot),
py::arg("x"), py::arg("xcomp"),
py::arg("y"), py::arg("ycomp"),
.def("dot",
[](MultiFab const & self, int comp, MultiFab const & y, int y_comp, int numcomp, int nghost, bool local) {
return MultiFab::Dot(self, comp, y, y_comp, numcomp, nghost, local);
},
py::arg("comp"),
py::arg("y"), py::arg("y_comp"),
py::arg("numcomp"), py::arg("nghost"), py::arg("local")=false,
"Returns the dot product of two MultiFabs."
"Returns the dot product of self with another MultiFab."
)
.def_static("dot",
py::overload_cast< MultiFab const &, int, int, int, bool >(&MultiFab::Dot),
py::arg("x"), py::arg("xcomp"),
.def("dot",
[](MultiFab const & self, int comp, int numcomp, int nghost, bool local) {
return MultiFab::Dot(self, comp, numcomp, nghost, local);
},
py::arg("comp"),
py::arg("numcomp"), py::arg("nghost"), py::arg("local")=false,
"Returns the dot product with itself."
)
/** TODO: Bind iMultiFab
.def("dot",
[](MultiFab const& self, const iMultiFab& mask, int comp, MultiFab const& y, int y_comp, int numcomp, int nghost, bool local) {
return MultiFab::Dot(mask, self, comp, y, y_comp, numcomp, nghost, local);
},
py::arg("mask"), py::arg("comp"), py::arg("y"), py::arg("y_comp"),
py::arg("numcomp"), py::arg("nghost"), py::arg("local")=false,
"Returns the dot product of a MultiFab with itself."
"Returns the dot product of self with another MultiFab where the mask is valid."
)
//.def_static("dot", py::overload_cast< iMultiFab const&, const MultiFab&, int, MultiFab const&, int, int, int, bool >(&MultiFab::Dot))
*/

.def_static("add",
py::overload_cast< MultiFab &, MultiFab const &, int, int, int, int >(&MultiFab::Add),
py::arg("dst"), py::arg("src"), py::arg("srccomp"), py::arg("dstcomp"), py::arg("numcomp"), py::arg("nghost"),
"Add src to dst including nghost ghost cells.\n"
.def("add",
[](MultiFab & self, MultiFab const & src, int srccomp, int comp, int numcomp, int nghost) {
MultiFab::Add(self, src, srccomp, comp, numcomp, nghost);
},
py::arg("src"), py::arg("srccomp"), py::arg("comp"), py::arg("numcomp"), py::arg("nghost"),
"Add src to self including nghost ghost cells.\n"
"The two MultiFabs MUST have the same underlying BoxArray."
)
.def_static("add",
py::overload_cast< MultiFab &, MultiFab const &, int, int, int, IntVect const & >(&MultiFab::Add),
py::arg("dst"), py::arg("src"), py::arg("srccomp"), py::arg("dstcomp"), py::arg("numcomp"), py::arg("nghost"),
"Add src to dst including nghost ghost cells.\n"
.def("add",
[](MultiFab & self, MultiFab const & src, int srccomp, int comp, int numcomp, IntVect const & nghost) {
MultiFab::Add(self, src, srccomp, comp, numcomp, nghost);
},
py::arg("src"), py::arg("srccomp"), py::arg("comp"), py::arg("numcomp"), py::arg("nghost"),
"Add src to self including nghost ghost cells.\n"
"The two MultiFabs MUST have the same underlying BoxArray."
)

.def_static("subtract",
py::overload_cast< MultiFab &, MultiFab const &, int, int, int, int >(&MultiFab::Subtract),
py::arg("dst"), py::arg("src"), py::arg("srccomp"), py::arg("dstcomp"), py::arg("numcomp"), py::arg("nghost"),
"Subtract src from dst including nghost ghost cells.\n"
.def("subtract",
[](MultiFab & self, MultiFab const & src, int srccomp, int comp, int numcomp, int nghost) {
MultiFab::Subtract(self, src, srccomp, comp, numcomp, nghost);
},
py::arg("src"), py::arg("srccomp"), py::arg("comp"), py::arg("numcomp"), py::arg("nghost"),
"Subtract src from self including nghost ghost cells.\n"
"The two MultiFabs MUST have the same underlying BoxArray."
)
.def_static("subtract",
py::overload_cast< MultiFab &, MultiFab const &, int, int, int, IntVect const & >(&MultiFab::Subtract),
py::arg("dst"), py::arg("src"), py::arg("srccomp"), py::arg("dstcomp"), py::arg("numcomp"), py::arg("nghost"),
"Subtract src from dst including nghost ghost cells.\n"
.def("subtract",
[](MultiFab & self, MultiFab const & src, int srccomp, int comp, int numcomp, IntVect const & nghost) {
MultiFab::Subtract(self, src, srccomp, comp, numcomp, nghost);
},
py::arg("src"), py::arg("srccomp"), py::arg("comp"), py::arg("numcomp"), py::arg("nghost"),
"Subtract src from self including nghost ghost cells.\n"
"The two MultiFabs MUST have the same underlying BoxArray."
)

.def_static("multiply",
py::overload_cast< MultiFab &, MultiFab const &, int, int, int, int >(&MultiFab::Multiply),
py::arg("dst"), py::arg("src"), py::arg("srccomp"), py::arg("dstcomp"), py::arg("numcomp"), py::arg("nghost"),
"Multiply dst by src including nghost ghost cells.\n"
.def("multiply",
[](MultiFab & self, MultiFab const & src, int srccomp, int comp, int numcomp, int nghost) {
MultiFab::Multiply(self, src, srccomp, comp, numcomp, nghost);
},
py::arg("src"), py::arg("srccomp"), py::arg("comp"), py::arg("numcomp"), py::arg("nghost"),
"Multiply self by src including nghost ghost cells.\n"
"The two MultiFabs MUST have the same underlying BoxArray."
)
.def_static("multiply",
py::overload_cast< MultiFab &, MultiFab const &, int, int, int, IntVect const & >(&MultiFab::Multiply),
py::arg("dst"), py::arg("src"), py::arg("srccomp"), py::arg("dstcomp"), py::arg("numcomp"), py::arg("nghost"),
"Multiply dst by src including nghost ghost cells.\n"
.def("multiply",
[](MultiFab & self, MultiFab const & src, int srccomp, int comp, int numcomp, IntVect const & nghost) {
MultiFab::Multiply(self, src, srccomp, comp, numcomp, nghost);
},
py::arg("src"), py::arg("srccomp"), py::arg("comp"), py::arg("numcomp"), py::arg("nghost"),
"Multiply self by src including nghost ghost cells.\n"
"The two MultiFabs MUST have the same underlying BoxArray."
)

.def_static("divide",
py::overload_cast< MultiFab &, MultiFab const &, int, int, int, int >(&MultiFab::Divide),
py::arg("dst"), py::arg("src"), py::arg("srccomp"), py::arg("dstcomp"), py::arg("numcomp"), py::arg("nghost"),
"Divide dst by src including nghost ghost cells.\n"
.def("divide",
[](MultiFab & self, MultiFab const & src, int srccomp, int comp, int numcomp, int nghost) {
MultiFab::Divide(self, src, srccomp, comp, numcomp, nghost);
},
py::arg("src"), py::arg("srccomp"), py::arg("comp"), py::arg("numcomp"), py::arg("nghost"),
"Divide self by src including nghost ghost cells.\n"
"The two MultiFabs MUST have the same underlying BoxArray."
)
.def_static("divide",
py::overload_cast< MultiFab &, MultiFab const &, int, int, int, IntVect const & >(&MultiFab::Divide),
py::arg("dst"), py::arg("src"), py::arg("srccomp"), py::arg("dstcomp"), py::arg("numcomp"), py::arg("nghost"),
"Divide dst by src including nghost ghost cells.\n"
.def("divide",
[](MultiFab & self, MultiFab const & src, int srccomp, int comp, int numcomp, IntVect const & nghost) {
MultiFab::Divide(self, src, srccomp, comp, numcomp, nghost);
},
py::arg("src"), py::arg("srccomp"), py::arg("comp"), py::arg("numcomp"), py::arg("nghost"),
"Divide self by src including nghost ghost cells.\n"
"The two MultiFabs MUST have the same underlying BoxArray."
)

.def_static("swap",
py::overload_cast< MultiFab &, MultiFab &, int, int, int, int >(&MultiFab::Swap),
py::arg("dst"), py::arg("src"), py::arg("srccomp"), py::arg("dstcomp"), py::arg("numcomp"), py::arg("nghost"),
"Swap from src to dst including nghost ghost cells.\n"
.def("swap",
[](MultiFab & self, MultiFab & src, int srccomp, int comp, int numcomp, int nghost) {
MultiFab::Swap(self, src, srccomp, comp, numcomp, nghost);
},
py::arg("src"), py::arg("srccomp"), py::arg("comp"), py::arg("numcomp"), py::arg("nghost"),
"Swap from src to self including nghost ghost cells.\n"
"The two MultiFabs MUST have the same underlying BoxArray.\n"
"The swap is local."
)
.def_static("swap",
py::overload_cast< MultiFab &, MultiFab &, int, int, int, IntVect const & >(&MultiFab::Swap),
py::arg("dst"), py::arg("src"), py::arg("srccomp"), py::arg("dstcomp"), py::arg("numcomp"), py::arg("nghost"),
"Swap from src to dst including nghost ghost cells.\n"
.def("swap",
[](MultiFab & self, MultiFab & src, int srccomp, int comp, int numcomp, IntVect const & nghost) {
MultiFab::Swap(self, src, srccomp, comp, numcomp, nghost);
},
py::arg("src"), py::arg("srccomp"), py::arg("comp"), py::arg("numcomp"), py::arg("nghost"),
"Swap from src to self including nghost ghost cells.\n"
"The two MultiFabs MUST have the same underlying BoxArray.\n"
"The swap is local."
)

.def_static("saxpy",
// py::overload_cast< MultiFab &, Real, MultiFab const &, int, int, int, int >(&MultiFab::Saxpy)
static_cast<void (*)(MultiFab &, Real, MultiFab const &, int, int, int, int)>(&MultiFab::Saxpy),
py::arg("dst"), py::arg("a"), py::arg("src"), py::arg("srccomp"), py::arg("dstcomp"), py::arg("numcomp"), py::arg("nghost"),
"dst += a*src"
.def("saxpy",
[](MultiFab & self, Real a, MultiFab const & src, int srccomp, int comp, int numcomp, int nghost) {
MultiFab::Saxpy(self, a, src, srccomp, comp, numcomp, nghost);
},
py::arg("a"), py::arg("src"), py::arg("srccomp"), py::arg("comp"), py::arg("numcomp"), py::arg("nghost"),
"self += a * src"
)

.def_static("xpay",
// py::overload_cast< MultiFab &, Real, MultiFab const &, int, int, int, int >(&MultiFab::Xpay)
static_cast<void (*)(MultiFab &, Real, MultiFab const &, int, int, int, int)>(&MultiFab::Xpay),
py::arg("dst"), py::arg("a"), py::arg("src"), py::arg("srccomp"), py::arg("dstcomp"), py::arg("numcomp"), py::arg("nghost"),
"dst = src + a*dst"
.def("xpay",
[](MultiFab & self, Real a, MultiFab const & src, int srccomp, int comp, int numcomp, int nghost) {
MultiFab::Xpay(self, a, src, srccomp, comp, numcomp, nghost);
},
py::arg("a"), py::arg("src"), py::arg("srccomp"), py::arg("comp"), py::arg("numcomp"), py::arg("nghost"),
"self = src + a * self"
)

.def_static("lin_comb",
// py::overload_cast< MultiFab &, Real, MultiFab const &, int, Real, MultiFab const &, int, int, int, int >(&MultiFab::LinComb)
static_cast<void (*)(MultiFab &, Real, MultiFab const &, int, Real, MultiFab const &, int, int, int, int)>(&MultiFab::LinComb),
py::arg("dst"),
.def("lin_comb",
[](MultiFab & self, Real a, MultiFab const & x, int x_comp, Real b, MultiFab const & y, int y_comp, int comp, int numcomp, int nghost) {
MultiFab::LinComb(self, a, x, x_comp, b, y, y_comp, comp, numcomp, nghost);
},
py::arg("a"), py::arg("x"), py::arg("x_comp"),
py::arg("b"), py::arg("y"), py::arg("y_comp"),
py::arg("dstcomp"), py::arg("numcomp"), py::arg("nghost"),
"dst = a*x + b*y"
py::arg("comp"), py::arg("numcomp"), py::arg("nghost"),
"self = a * x + b * y"
)

.def_static("add_product",
py::overload_cast< MultiFab &, MultiFab const &, int, MultiFab const &, int, int, int, int >(&MultiFab::AddProduct),
py::arg("dst"),
.def("add_product",
[](MultiFab & self, MultiFab const & src1, int comp1, MultiFab const & src2, int comp2, int comp, int numcomp, int nghost) {
MultiFab::AddProduct(self, src1, comp1, src2, comp2, comp, numcomp, nghost);
},
py::arg("src1"), py::arg("comp1"),
py::arg("src2"), py::arg("comp2"),
py::arg("dstcomp"), py::arg("numcomp"), py::arg("nghost"),
"dst += src1*src2"
py::arg("comp"), py::arg("numcomp"), py::arg("nghost"),
"self += src1 * src2"
)
.def_static("add_product",
py::overload_cast< MultiFab &, MultiFab const &, int, MultiFab const &, int, int, int, IntVect const & >(&MultiFab::AddProduct),
"dst += src1*src2"
.def("add_product",
[](MultiFab & self, MultiFab const & src1, int comp1, MultiFab const & src2, int comp2, int comp, int numcomp, IntVect const & nghost) {
MultiFab::AddProduct(self, src1, comp1, src2, comp2, comp, numcomp, nghost);
},
py::arg("src1"), py::arg("comp1"),
py::arg("src2"), py::arg("comp2"),
py::arg("comp"), py::arg("numcomp"), py::arg("nghost"),
"self += src1 * src2"
)

/* simple data validity checks */
Expand Down
23 changes: 7 additions & 16 deletions tests/test_multifab.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,30 +240,21 @@ def test_mfab_ops(boxarr, distmap, nghost):
src.set_val(30.0, 2, 1)
dst.set_val(0.0, 0, 1)

# dst.add(src, 2, 0, 1, nghost)
# dst.subtract(src, 1, 0, 1, nghost)
# dst.multiply(src, 0, 0, 1, nghost)
# dst.divide(src, 1, 0, 1, nghost)

dst.add(dst, src, 2, 0, 1, nghost)
dst.subtract(dst, src, 1, 0, 1, nghost)
dst.multiply(dst, src, 0, 0, 1, nghost)
dst.divide(dst, src, 1, 0, 1, nghost)
dst.add(src, 2, 0, 1, nghost)
dst.subtract(src, 1, 0, 1, nghost)
dst.multiply(src, 0, 0, 1, nghost)
dst.divide(src, 1, 0, 1, nghost)

print(dst.min(0))
np.testing.assert_allclose(dst.min(0), 5.0)
np.testing.assert_allclose(dst.max(0), 5.0)

# dst.xpay(2.0, src, 0, 0, 1, nghost)
# dst.saxpy(2.0, src, 1, 0, 1, nghost)
dst.xpay(dst, 2.0, src, 0, 0, 1, nghost)
dst.saxpy(dst, 2.0, src, 1, 0, 1, nghost)
dst.xpay(2.0, src, 0, 0, 1, nghost)
dst.saxpy(2.0, src, 1, 0, 1, nghost)
np.testing.assert_allclose(dst.min(0), 60.0)
np.testing.assert_allclose(dst.max(0), 60.0)

# dst.lin_comb(6.0, src, 1,
# 1.0, src, 2, 0, 1, nghost)
dst.lin_comb(dst, 6.0, src, 1, 1.0, src, 2, 0, 1, nghost)
dst.lin_comb(6.0, src, 1, 1.0, src, 2, 0, 1, nghost)
np.testing.assert_allclose(dst.min(0), 150.0)
np.testing.assert_allclose(dst.max(0), 150.0)

Expand Down

0 comments on commit 1f91a4d

Please sign in to comment.