Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Save history in ode systems #109

Merged
merged 4 commits into from
Oct 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions R/cpp11.R

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

23 changes: 15 additions & 8 deletions R/interface-continuous.R
Original file line number Diff line number Diff line change
Expand Up @@ -32,27 +32,34 @@
##' debugging. Step times can be retrieved via
##' [dust_system_internals()].
##'
##' @param save_history Logical, indicating if we should save history
##' during running. This should only be enabled for debugging.
##' Data can be retrieved via [dust_system_internals()], but the
##' format is undocumented.
##'
##' @export
##'
##' @return A named list of class "dust_ode_control". Do not modify
##' this after creation.
dust_ode_control <- function(max_steps = 10000, atol = 1e-6, rtol = 1e-6,
step_size_min = 0, step_size_max = Inf,
debug_record_step_times = FALSE) {
call <- environment()
debug_record_step_times = FALSE,
save_history = FALSE) {
ctl <- list(
max_steps = assert_scalar_size(
max_steps, allow_zero = FALSE, call = call),
max_steps, allow_zero = FALSE),
atol = assert_scalar_positive_numeric(
atol, allow_zero = FALSE, call = call),
atol, allow_zero = FALSE),
rtol = assert_scalar_positive_numeric(
rtol, allow_zero = FALSE, call = call),
rtol, allow_zero = FALSE),
step_size_min = assert_scalar_positive_numeric(
step_size_min, allow_zero = TRUE, call = call),
step_size_min, allow_zero = TRUE),
step_size_max = assert_scalar_positive_numeric(
step_size_max, allow_zero = TRUE, call = call),
step_size_max, allow_zero = TRUE),
save_history = assert_scalar_logical(
save_history),
debug_record_step_times = assert_scalar_logical(
debug_record_step_times, call = call))
debug_record_step_times))
class(ctl) <- "dust_ode_control"
ctl
}
10 changes: 8 additions & 2 deletions R/interface.R
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,8 @@ dust_system_reorder <- function(sys, index) {
##' coefficients should be included in the output. These are
##' intentionally undocumented for now.
##'
##' @param include_history Boolean, also undocumented.
##'
##' @return If `sys` is a discrete-time system, this function returns
##' `NULL`, as no internal data is stored. Otherwise, for a
##' continuous-time system we return a `data.frame` of statistics
Expand All @@ -462,13 +464,14 @@ dust_system_reorder <- function(sys, index) {
##' (the structure of these may change over time, too).
##'
##' @export
dust_system_internals <- function(sys, include_coefficients = FALSE) {
dust_system_internals <- function(sys, include_coefficients = FALSE,
include_history = FALSE) {
check_is_dust_system(sys)
if (sys$properties$time_type == "discrete") {
## No internals for now, perhaps never?
return(NULL)
}
dat <- sys$methods$internals(sys$ptr, include_coefficients)
dat <- sys$methods$internals(sys$ptr, include_coefficients, include_history)
ret <- data_frame(
particle = seq_along(dat),
dydt = I(lapply(dat, "[[", "dydt")),
Expand All @@ -481,6 +484,9 @@ dust_system_internals <- function(sys, include_coefficients = FALSE) {
if (include_coefficients) {
ret$coefficients <- I(lapply(dat, "[[", "coefficients"))
}
if (include_history) {
ret$history <- I(lapply(dat, "[[", "history"))
}
ret
}

Expand Down
3 changes: 3 additions & 0 deletions inst/include/dust2/continuous/control.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,17 @@ struct control {
real_type factor_max = 10.0; // from dopri5.f:281, retard.f:333
real_type beta = 0.04;
real_type constant = 0.2 - 0.04 * 0.75; // 0.04 is beta
bool save_history = false;
bool debug_record_step_times = false;

control(size_t max_steps, real_type atol, real_type rtol,
real_type step_size_min, real_type step_size_max,
bool save_history,
bool debug_record_step_times) :
max_steps(max_steps), atol(atol), rtol(rtol),
step_size_min(step_size_min),
step_size_max(step_size_max),
save_history(save_history),
debug_record_step_times(debug_record_step_times) {}

control() {}
Expand Down
41 changes: 40 additions & 1 deletion inst/include/dust2/continuous/solver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <cmath>
#include <stdexcept>
#include <vector>
#include <dust2/tools.hpp>
#include <dust2/zero.hpp>
#include <dust2/continuous/control.hpp>

Expand All @@ -20,6 +21,17 @@ T clamp(T x, T min, T max) {
return std::max(std::min(x, max), min);
}

template <typename real_type>
struct history {
real_type t;
real_type h;
std::vector<real_type> c1;
std::vector<real_type> c2;
std::vector<real_type> c3;
std::vector<real_type> c4;
std::vector<real_type> c5;
};

// This is the internal state separate from 'y' that defines the
// system. This includes the derivatives.
//
Expand All @@ -29,6 +41,7 @@ template <typename real_type>
struct internals {
std::vector<real_type> dydt;
std::vector<real_type> step_times;
std::vector<history<real_type>> history_values;
// Interpolation coefficients
std::vector<real_type> c1;
std::vector<real_type> c2;
Expand All @@ -41,6 +54,7 @@ struct internals {
size_t n_steps;
size_t n_steps_accepted;
size_t n_steps_rejected;
std::vector<size_t> history_index;

internals(size_t n_variables) :
dydt(n_variables),
Expand All @@ -53,14 +67,36 @@ struct internals {
reset();
}

void set_history_index(const std::vector<size_t>& index) {
if (tools::is_trivial_index(index, dydt.size())) {
history_index.clear();
} else {
history_index = index;
}
}

void save_history(real_type t, real_type h) {
if (history_index.empty()) {
history_values.push_back({t, h, c1, c2, c3, c4, c5});
} else {
history_values.push_back({t, h,
tools::subset(c1, history_index),
tools::subset(c2, history_index),
tools::subset(c3, history_index),
tools::subset(c4, history_index),
tools::subset(c5, history_index)});
}
}

void reset() {
last_step_size = 0;
step_size = 0;
error = 0;
n_steps = 0;
n_steps_accepted = 0;
n_steps_rejected = 0;
step_times.resize(0);
step_times.clear();
history_values.clear();
}
};

Expand Down Expand Up @@ -175,6 +211,9 @@ class solver {
if (control_.debug_record_step_times) {
internals.step_times.push_back(truncated ? t_end : t + h);
}
if (control_.save_history) {
internals.save_history(t, h);
}
if (!truncated) {
const auto fac_old =
std::max(internals.error, static_cast<real_type>(1e-4));
Expand Down
8 changes: 2 additions & 6 deletions inst/include/dust2/continuous/system.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class dust_continuous {

shared_(shared),
internal_(internal),
all_groups_(n_groups_),
all_groups_(tools::integer_sequence(n_groups_)),

time_(time),
zero_every_(zero_every_vec<T>(shared_)),
Expand All @@ -77,9 +77,6 @@ class dust_continuous {
// We don't check that the size is the same across all states;
// this should be done by the caller (similarly, we don't check
// that shared and internal have the same size).
for (size_t i = 0; i < n_groups_; ++i) {
all_groups_[i] = i;
}
}

template <typename mixed_time = typename dust2::properties<T>::is_mixed_time>
Expand Down Expand Up @@ -178,14 +175,13 @@ class dust_continuous {
try {
T::initial(time_, shared_[i], internal_i,
rng_.state(k), y);
solver_.initialise(time_, y, ode_internals_[k],
rhs_(shared_[i], internal_i));
} catch (std::exception const& e) {
errors_.capture(e, k);
}
}
}
errors_.report();
initialise_solver_(index_group);
// Assume not current, because most models would want to call output here()
update_output_is_current(index_group, false);
}
Expand Down
5 changes: 1 addition & 4 deletions inst/include/dust2/discrete/system.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class dust_discrete {
state_next_(n_state_ * n_particles_total_),
shared_(shared),
internal_(internal),
all_groups_(n_groups_),
all_groups_(tools::integer_sequence(n_groups_)),
time_(time),
dt_(dt),
zero_every_(zero_every_vec<T>(shared_)),
Expand All @@ -55,9 +55,6 @@ class dust_discrete {
// We don't check that the size is the same across all states;
// this should be done by the caller (similarly, we don't check
// that shared and internal have the same size).
for (size_t i = 0; i < n_groups_; ++i) {
all_groups_[i] = i;
}
}

auto run_to_time(real_type time, const std::vector<size_t>& index_group) {
Expand Down
6 changes: 4 additions & 2 deletions inst/include/dust2/r/continuous/control.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@ dust2::ode::control<real_type> validate_ode_control(cpp11::list r_time_control)
std::max(static_cast<real_type>(dust2::r::read_real(ode_control, "step_size_min")),
std::numeric_limits<real_type>::epsilon());
const auto step_size_max = dust2::r::read_real(ode_control, "step_size_max");
const auto debug_record_step_times =
const bool debug_record_step_times =
dust2::r::read_bool(ode_control, "debug_record_step_times");
const auto save_history = dust2::r::read_bool(ode_control, "save_history");
return dust2::ode::control<real_type>(max_steps, atol, rtol, step_size_min,
step_size_max, debug_record_step_times);
step_size_max, save_history,
debug_record_step_times);
}

}
Expand Down
38 changes: 34 additions & 4 deletions inst/include/dust2/r/continuous/system.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ SEXP dust2_continuous_alloc(cpp11::list r_pars,

template <typename real_type>
cpp11::sexp ode_internals_to_sexp(const ode::internals<real_type>& internals,
bool include_coefficients) {
bool include_coefficients,
bool include_history) {
using namespace cpp11::literals;
auto ret = cpp11::writable::list{
"dydt"_nm = cpp11::as_sexp(internals.dydt),
Expand All @@ -65,7 +66,8 @@ cpp11::sexp ode_internals_to_sexp(const ode::internals<real_type>& internals,
"n_steps"_nm = cpp11::as_sexp(internals.n_steps),
"n_steps_accepted"_nm = cpp11::as_sexp(internals.n_steps_accepted),
"n_steps_rejected"_nm = cpp11::as_sexp(internals.n_steps_rejected),
"coefficients"_nm = R_NilValue};
"coefficients"_nm = R_NilValue,
"history"_nm = R_NilValue};
if (include_coefficients) {
auto r_coef = cpp11::writable::doubles_matrix<>(internals.c1.size(), 5);
auto coef = REAL(r_coef);
Expand All @@ -76,18 +78,46 @@ cpp11::sexp ode_internals_to_sexp(const ode::internals<real_type>& internals,
coef = std::copy(internals.c5.begin(), internals.c5.end(), coef);
ret["coefficients"] = r_coef;
}
if (include_history && !internals.history_values.empty()) {
const auto n_history_entries = internals.history_values.size();
const auto n_history_state = internals.history_values[0].c1.size();
auto r_history_coef =
cpp11::writable::doubles(n_history_state * 5 * n_history_entries);
auto r_history_time = cpp11::writable::doubles(n_history_entries);
auto r_history_size = cpp11::writable::doubles(n_history_entries);
auto history_coef = REAL(r_history_coef);
auto history_time = REAL(r_history_time);
auto history_size = REAL(r_history_size);
for (auto& h: internals.history_values) {
*history_time++ = h.t;
*history_size++ = h.h;
history_coef = std::copy(h.c1.begin(), h.c1.end(), history_coef);
history_coef = std::copy(h.c2.begin(), h.c2.end(), history_coef);
history_coef = std::copy(h.c3.begin(), h.c3.end(), history_coef);
history_coef = std::copy(h.c4.begin(), h.c4.end(), history_coef);
history_coef = std::copy(h.c5.begin(), h.c5.end(), history_coef);
}
set_array_dims(r_history_coef, {n_history_state, 5, n_history_entries});
auto r_history = cpp11::writable::list{"time"_nm = r_history_time,
"size"_nm = r_history_size,
"coefficients"_nm = r_history_coef};
ret["history"] = cpp11::as_sexp(r_history);
}
return ret;
}

template <typename T>
SEXP dust2_system_internals(cpp11::sexp ptr, bool include_coefficients) {
SEXP dust2_system_internals(cpp11::sexp ptr, bool include_coefficients, bool include_history) {
auto *obj = cpp11::as_cpp<cpp11::external_pointer<T>>(ptr).get();
const auto& internals = obj->ode_internals();
cpp11::writable::list ret(internals.size());
for (size_t i = 0; i < internals.size(); ++i) {
ret[i] = ode_internals_to_sexp(internals[i], include_coefficients);
ret[i] = ode_internals_to_sexp(internals[i],
include_coefficients,
include_history);
}
return cpp11::as_sexp(ret);

}

}
Expand Down
19 changes: 19 additions & 0 deletions inst/include/dust2/tools.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,25 @@ T prod(const std::vector<T>& x) {
return std::accumulate(x.begin(), x.end(), 1, std::multiplies<>{});
}

inline std::vector<size_t> integer_sequence(size_t n) {
std::vector<size_t> ret;
ret.reserve(n);
for (size_t i = 0; i < n; ++i) {
ret.push_back(i);
}
return ret;
}

template <typename T>
std::vector<T> subset(const std::vector<T>& x, const std::vector<size_t> index) {
std::vector<T> ret;
ret.reserve(index.size());
for (auto i : index) {
ret.push_back(x[i]);
}
return ret;
}

inline bool is_trivial_index(const std::vector<size_t>& index, size_t n) {
if (index.empty()) {
return true;
Expand Down
4 changes: 2 additions & 2 deletions inst/template/continuous/system.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@ SEXP dust2_system_{{name}}_alloc(cpp11::list r_pars, cpp11::sexp r_time, cpp11::
}

[[cpp11::register]]
SEXP dust2_system_{{name}}_internals(cpp11::sexp ptr, bool include_coefficients) {
return dust2::r::dust2_system_internals<dust2::dust_{{time_type}}<{{class}}>>(ptr, include_coefficients);
SEXP dust2_system_{{name}}_internals(cpp11::sexp ptr, bool include_coefficients, bool include_history) {
return dust2::r::dust2_system_internals<dust2::dust_{{time_type}}<{{class}}>>(ptr, include_coefficients, include_history);
}
Loading
Loading