Skip to content

Commit

Permalink
template pivoting in the linear algebra
Browse files Browse the repository at this point in the history
now we can do `integrator.linalg_do_pivoting=0` to disable pivoting
  • Loading branch information
zingale committed Jan 27, 2024
1 parent a4d7ab4 commit 59ba531
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 30 deletions.
17 changes: 15 additions & 2 deletions integration/BackwardEuler/be_integrator.H
Original file line number Diff line number Diff line change
Expand Up @@ -96,14 +96,27 @@ int single_step (BurnT& state, BeT& be, const Real dt)

int ierr_linpack;
IArray1D pivot;
dgefa<int_neqs>(be.jac, pivot, ierr_linpack);

if (integrator_rp::linalg_do_pivoting == 1) {
constexpr bool allow_pivot{true};
dgefa<int_neqs, allow_pivot>(be.jac, pivot, ierr_linpack);
} else {
constexpr bool allow_pivot{false};
dgefa<int_neqs, allow_pivot>(be.jac, pivot, ierr_linpack);
}

if (ierr_linpack != 0) {
ierr = IERR_LU_DECOMPOSITION_ERROR;
break;
}

dgesl<int_neqs>(be.jac, pivot, b);
if (integrator_rp::linalg_do_pivoting == 1) {
constexpr bool allow_pivot{true};
dgesl<int_neqs, allow_pivot>(be.jac, pivot, b);
} else {
constexpr bool allow_pivot{false};
dgesl<int_neqs, allow_pivot>(be.jac, pivot, b);
}

// update our current guess for the solution

Expand Down
8 changes: 7 additions & 1 deletion integration/VODE/vode_dvjac.H
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,13 @@ void dvjac (IArray1D& pivot, int& IERPJ, BurnT& state, DvodeT& vstate)
RHS::dgefa(vstate.jac);
IER = 0;
#else
dgefa<int_neqs>(vstate.jac, pivot, IER);
if (integrator_rp::linalg_do_pivoting == 1) {
constexpr bool allow_pivot{true};
dgefa<int_neqs, allow_pivot>(vstate.jac, pivot, IER);
} else {
constexpr bool allow_pivot{false};
dgefa<int_neqs, allow_pivot>(vstate.jac, pivot, IER);
}
#endif

if (IER != 0) {
Expand Down
8 changes: 7 additions & 1 deletion integration/VODE/vode_dvnlsd.H
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,13 @@ Real dvnlsd (IArray1D& pivot, int& NFLAG, BurnT& state, DvodeT& vstate)
#ifdef NEW_NETWORK_IMPLEMENTATION
RHS::dgesl(vstate.jac, vstate.y);
#else
dgesl<int_neqs>(vstate.jac, pivot, vstate.y);
if (integrator_rp::linalg_do_pivoting == 1) {
constexpr bool allow_pivot{true};
dgesl<int_neqs, allow_pivot>(vstate.jac, pivot, vstate.y);
} else {
constexpr bool allow_pivot{false};
dgesl<int_neqs, allow_pivot>(vstate.jac, pivot, vstate.y);
}
#endif

if (vstate.RC != 1.0_rt) {
Expand Down
3 changes: 3 additions & 0 deletions integration/_parameters
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,6 @@ nse_deriv_dt_factor real 0.05

# for NSE update, do we include the weak rate neutrino losses?
nse_include_enu_weak integer 1

# for the linear algebra, do we allow pivoting?
linalg_do_pivoting integer 1
67 changes: 41 additions & 26 deletions util/linpack.H
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

#include <ArrayUtilities.H>

template <int num_eqs>
template <int num_eqs, bool allow_pivot>
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
void dgesl (RArray2D& a, IArray1D& pivot, RArray1D& b)
{
Expand All @@ -17,11 +17,17 @@ void dgesl (RArray2D& a, IArray1D& pivot, RArray1D& b)
// first solve l * y = b
if (nm1 >= 1) {
for (int k = 1; k <= nm1; ++k) {
int l = pivot(k);
Real t = b(l);
if (l != k) {
b(l) = b(k);
b(k) = t;

Real t{};
if constexpr (allow_pivot) {
int l = pivot(k);
t = b(l);
if (l != k) {
b(l) = b(k);
b(k) = t;
}
} else {
t = b(k);
}

for (int j = k+1; j <= num_eqs; ++j) {
Expand All @@ -45,7 +51,7 @@ void dgesl (RArray2D& a, IArray1D& pivot, RArray1D& b)



template <int num_eqs>
template <int num_eqs, bool allow_pivot>
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
void dgefa (RArray2D& a, IArray1D& pivot, int& info)
{
Expand All @@ -68,24 +74,29 @@ void dgefa (RArray2D& a, IArray1D& pivot, int& info)

// find l = pivot index
int l = k;
Real dmax = std::abs(a(k,k));
for (int i = k+1; i <= num_eqs; ++i) {
if (std::abs(a(i,k)) > dmax) {
l = i;
dmax = std::abs(a(i,k));

if constexpr (allow_pivot) {
Real dmax = std::abs(a(k,k));
for (int i = k+1; i <= num_eqs; ++i) {
if (std::abs(a(i,k)) > dmax) {
l = i;
dmax = std::abs(a(i,k));
}
}
}

pivot(k) = l;
pivot(k) = l;
}

// zero pivot implies this column already triangularized
if (a(l,k) != 0.0e0_rt) {

// interchange if necessary
if (l != k) {
t = a(l,k);
a(l,k) = a(k,k);
a(k,k) = t;
if constexpr (allow_pivot) {
// interchange if necessary
if (l != k) {
t = a(l,k);
a(l,k) = a(k,k);
a(k,k) = t;
}
}

// compute multipliers
Expand All @@ -97,26 +108,30 @@ void dgefa (RArray2D& a, IArray1D& pivot, int& info)
// row elimination with column indexing
for (int j = k+1; j <= num_eqs; ++j) {
t = a(l,j);
if (l != k) {
a(l,j) = a(k,j);
a(k,j) = t;

if constexpr (allow_pivot) {
if (l != k) {
a(l,j) = a(k,j);
a(k,j) = t;
}
}

for (int i = k+1; i <= num_eqs; ++i) {
a(i,j) += t * a(i,k);
}
}
}
else {

} else {
info = k;

}

}

}

pivot(num_eqs) = num_eqs;
if constexpr (allow_pivot) {
pivot(num_eqs) = num_eqs;
}

if (a(num_eqs,num_eqs) == 0.0e0_rt) {
info = num_eqs;
Expand Down

0 comments on commit 59ba531

Please sign in to comment.