From aff25a1abc067a70420db1b534e0427a2f141b1c Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Fri, 25 Oct 2024 17:50:12 +0100 Subject: [PATCH 1/4] Save history --- R/cpp11.R | 12 +++---- R/interface-continuous.R | 21 +++++++----- R/interface.R | 10 ++++-- inst/include/dust2/continuous/control.hpp | 3 ++ inst/include/dust2/continuous/solver.hpp | 22 +++++++++++- inst/include/dust2/r/continuous/control.hpp | 7 ++-- inst/include/dust2/r/continuous/system.hpp | 38 ++++++++++++++++++--- inst/template/continuous/system.cpp | 4 +-- src/cpp11.cpp | 24 ++++++------- src/logistic.cpp | 4 +-- src/malaria.cpp | 4 +-- src/sirode.cpp | 4 +-- tests/testthat/test-logistic.R | 23 +++++++++++++ 13 files changed, 133 insertions(+), 43 deletions(-) diff --git a/R/cpp11.R b/R/cpp11.R index 56e58b92..72d1c449 100644 --- a/R/cpp11.R +++ b/R/cpp11.R @@ -4,8 +4,8 @@ dust2_system_logistic_alloc <- function(r_pars, r_time, r_time_control, r_n_part .Call(`_dust2_dust2_system_logistic_alloc`, r_pars, r_time, r_time_control, r_n_particles, r_n_groups, r_seed, r_deterministic, r_n_threads) } -dust2_system_logistic_internals <- function(ptr, include_coefficients) { - .Call(`_dust2_dust2_system_logistic_internals`, ptr, include_coefficients) +dust2_system_logistic_internals <- function(ptr, include_coefficients, include_history) { + .Call(`_dust2_dust2_system_logistic_internals`, ptr, include_coefficients, include_history) } dust2_system_logistic_run_to_time <- function(ptr, r_time) { @@ -56,8 +56,8 @@ dust2_system_malaria_alloc <- function(r_pars, r_time, r_time_control, r_n_parti .Call(`_dust2_dust2_system_malaria_alloc`, r_pars, r_time, r_time_control, r_n_particles, r_n_groups, r_seed, r_deterministic, r_n_threads) } -dust2_system_malaria_internals <- function(ptr, include_coefficients) { - .Call(`_dust2_dust2_system_malaria_internals`, ptr, include_coefficients) +dust2_system_malaria_internals <- function(ptr, include_coefficients, include_history) { + .Call(`_dust2_dust2_system_malaria_internals`, ptr, include_coefficients, include_history) } dust2_system_malaria_run_to_time <- function(ptr, r_time) { @@ -268,8 +268,8 @@ dust2_system_sirode_alloc <- function(r_pars, r_time, r_time_control, r_n_partic .Call(`_dust2_dust2_system_sirode_alloc`, r_pars, r_time, r_time_control, r_n_particles, r_n_groups, r_seed, r_deterministic, r_n_threads) } -dust2_system_sirode_internals <- function(ptr, include_coefficients) { - .Call(`_dust2_dust2_system_sirode_internals`, ptr, include_coefficients) +dust2_system_sirode_internals <- function(ptr, include_coefficients, include_history) { + .Call(`_dust2_dust2_system_sirode_internals`, ptr, include_coefficients, include_history) } dust2_system_sirode_run_to_time <- function(ptr, r_time) { diff --git a/R/interface-continuous.R b/R/interface-continuous.R index 68278f8e..98aa1e4b 100644 --- a/R/interface-continuous.R +++ b/R/interface-continuous.R @@ -32,27 +32,32 @@ ##' debugging. Step times can be retrieved via ##' [dust_system_internals()]. ##' +##' @param save_history Optional vector of states for which we should +##' save index. This is intended only for debugging. +##' ##' @export ##' ##' @return A named list of class "dust_ode_control". Do not modify ##' this after creation. dust_ode_control <- function(max_steps = 10000, atol = 1e-6, rtol = 1e-6, step_size_min = 0, step_size_max = Inf, - debug_record_step_times = FALSE) { - call <- environment() + debug_record_step_times = FALSE, + save_history = NULL) { ctl <- list( max_steps = assert_scalar_size( - max_steps, allow_zero = FALSE, call = call), + max_steps, allow_zero = FALSE), atol = assert_scalar_positive_numeric( - atol, allow_zero = FALSE, call = call), + atol, allow_zero = FALSE), rtol = assert_scalar_positive_numeric( - rtol, allow_zero = FALSE, call = call), + rtol, allow_zero = FALSE), step_size_min = assert_scalar_positive_numeric( - step_size_min, allow_zero = TRUE, call = call), + step_size_min, allow_zero = TRUE), step_size_max = assert_scalar_positive_numeric( - step_size_max, allow_zero = TRUE, call = call), + step_size_max, allow_zero = TRUE), + save_history = assert_scalar_logical( + save_history), debug_record_step_times = assert_scalar_logical( - debug_record_step_times, call = call)) + debug_record_step_times)) class(ctl) <- "dust_ode_control" ctl } diff --git a/R/interface.R b/R/interface.R index 8f3daf36..e6fb277c 100644 --- a/R/interface.R +++ b/R/interface.R @@ -448,6 +448,8 @@ dust_system_reorder <- function(sys, index) { ##' coefficients should be included in the output. These are ##' intentionally undocumented for now. ##' +##' @param include_history Boolean, also undocumented. +##' ##' @return If `sys` is a discrete-time system, this function returns ##' `NULL`, as no internal data is stored. Otherwise, for a ##' continuous-time system we return a `data.frame` of statistics @@ -462,13 +464,14 @@ dust_system_reorder <- function(sys, index) { ##' (the structure of these may change over time, too). ##' ##' @export -dust_system_internals <- function(sys, include_coefficients = FALSE) { +dust_system_internals <- function(sys, include_coefficients = FALSE, + include_history = FALSE) { check_is_dust_system(sys) if (sys$properties$time_type == "discrete") { ## No internals for now, perhaps never? return(NULL) } - dat <- sys$methods$internals(sys$ptr, include_coefficients) + dat <- sys$methods$internals(sys$ptr, include_coefficients, include_history) ret <- data_frame( particle = seq_along(dat), dydt = I(lapply(dat, "[[", "dydt")), @@ -481,6 +484,9 @@ dust_system_internals <- function(sys, include_coefficients = FALSE) { if (include_coefficients) { ret$coefficients <- I(lapply(dat, "[[", "coefficients")) } + if (include_history) { + ret$history <- I(lapply(dat, "[[", "history")) + } ret } diff --git a/inst/include/dust2/continuous/control.hpp b/inst/include/dust2/continuous/control.hpp index d4cdbb6a..59df4472 100644 --- a/inst/include/dust2/continuous/control.hpp +++ b/inst/include/dust2/continuous/control.hpp @@ -18,14 +18,17 @@ struct control { real_type factor_max = 10.0; // from dopri5.f:281, retard.f:333 real_type beta = 0.04; real_type constant = 0.2 - 0.04 * 0.75; // 0.04 is beta + bool save_history = false; bool debug_record_step_times = false; control(size_t max_steps, real_type atol, real_type rtol, real_type step_size_min, real_type step_size_max, + bool save_history, bool debug_record_step_times) : max_steps(max_steps), atol(atol), rtol(rtol), step_size_min(step_size_min), step_size_max(step_size_max), + save_history(save_history), debug_record_step_times(debug_record_step_times) {} control() {} diff --git a/inst/include/dust2/continuous/solver.hpp b/inst/include/dust2/continuous/solver.hpp index 17de9c82..51a5ce1d 100644 --- a/inst/include/dust2/continuous/solver.hpp +++ b/inst/include/dust2/continuous/solver.hpp @@ -20,6 +20,17 @@ T clamp(T x, T min, T max) { return std::max(std::min(x, max), min); } +template +struct history { + real_type t; + real_type h; + std::vector c1; + std::vector c2; + std::vector c3; + std::vector c4; + std::vector c5; +}; + // This is the internal state separate from 'y' that defines the // system. This includes the derivatives. // @@ -29,6 +40,7 @@ template struct internals { std::vector dydt; std::vector step_times; + std::vector> history_values; // Interpolation coefficients std::vector c1; std::vector c2; @@ -53,6 +65,10 @@ struct internals { reset(); } + void save_history(real_type t, real_type h) { + history_values.push_back({t, h, c1, c2, c3, c4, c5}); + } + void reset() { last_step_size = 0; step_size = 0; @@ -60,7 +76,8 @@ struct internals { n_steps = 0; n_steps_accepted = 0; n_steps_rejected = 0; - step_times.resize(0); + step_times.clear(); + history_values.clear(); } }; @@ -175,6 +192,9 @@ class solver { if (control_.debug_record_step_times) { internals.step_times.push_back(truncated ? t_end : t + h); } + if (control_.save_history) { + internals.save_history(t, h); + } if (!truncated) { const auto fac_old = std::max(internals.error, static_cast(1e-4)); diff --git a/inst/include/dust2/r/continuous/control.hpp b/inst/include/dust2/r/continuous/control.hpp index cc7e83a5..cffff75f 100644 --- a/inst/include/dust2/r/continuous/control.hpp +++ b/inst/include/dust2/r/continuous/control.hpp @@ -16,10 +16,13 @@ dust2::ode::control validate_ode_control(cpp11::list r_time_control) std::max(static_cast(dust2::r::read_real(ode_control, "step_size_min")), std::numeric_limits::epsilon()); const auto step_size_max = dust2::r::read_real(ode_control, "step_size_max"); - const auto debug_record_step_times = + const bool debug_record_step_times = dust2::r::read_bool(ode_control, "debug_record_step_times"); + const bool save_history = + dust2::r::read_bool(ode_control, "save_history"); return dust2::ode::control(max_steps, atol, rtol, step_size_min, - step_size_max, debug_record_step_times); + step_size_max, save_history, + debug_record_step_times); } } diff --git a/inst/include/dust2/r/continuous/system.hpp b/inst/include/dust2/r/continuous/system.hpp index 3b651ef6..8172764c 100644 --- a/inst/include/dust2/r/continuous/system.hpp +++ b/inst/include/dust2/r/continuous/system.hpp @@ -55,7 +55,8 @@ SEXP dust2_continuous_alloc(cpp11::list r_pars, template cpp11::sexp ode_internals_to_sexp(const ode::internals& internals, - bool include_coefficients) { + bool include_coefficients, + bool include_history) { using namespace cpp11::literals; auto ret = cpp11::writable::list{ "dydt"_nm = cpp11::as_sexp(internals.dydt), @@ -65,7 +66,8 @@ cpp11::sexp ode_internals_to_sexp(const ode::internals& internals, "n_steps"_nm = cpp11::as_sexp(internals.n_steps), "n_steps_accepted"_nm = cpp11::as_sexp(internals.n_steps_accepted), "n_steps_rejected"_nm = cpp11::as_sexp(internals.n_steps_rejected), - "coefficients"_nm = R_NilValue}; + "coefficients"_nm = R_NilValue, + "history"_nm = R_NilValue}; if (include_coefficients) { auto r_coef = cpp11::writable::doubles_matrix<>(internals.c1.size(), 5); auto coef = REAL(r_coef); @@ -76,18 +78,46 @@ cpp11::sexp ode_internals_to_sexp(const ode::internals& internals, coef = std::copy(internals.c5.begin(), internals.c5.end(), coef); ret["coefficients"] = r_coef; } + if (include_history && !internals.history_values.empty()) { + const auto n_history_entries = internals.history_values.size(); + const auto n_history_state = internals.history_values[0].c1.size(); + auto r_history_coef = + cpp11::writable::doubles(n_history_state * 5 * n_history_entries); + auto r_history_time = cpp11::writable::doubles(n_history_entries); + auto r_history_size = cpp11::writable::doubles(n_history_entries); + auto history_coef = REAL(r_history_coef); + auto history_time = REAL(r_history_time); + auto history_size = REAL(r_history_size); + for (auto& h: internals.history_values) { + *history_time++ = h.t; + *history_size++ = h.h; + history_coef = std::copy(h.c1.begin(), h.c1.end(), history_coef); + history_coef = std::copy(h.c2.begin(), h.c2.end(), history_coef); + history_coef = std::copy(h.c3.begin(), h.c3.end(), history_coef); + history_coef = std::copy(h.c4.begin(), h.c4.end(), history_coef); + history_coef = std::copy(h.c5.begin(), h.c5.end(), history_coef); + } + set_array_dims(r_history_coef, {n_history_state, 5, n_history_entries}); + auto r_history = cpp11::writable::list{"time"_nm = r_history_time, + "size"_nm = r_history_size, + "coefficients"_nm = r_history_coef}; + ret["history"] = cpp11::as_sexp(r_history); + } return ret; } template -SEXP dust2_system_internals(cpp11::sexp ptr, bool include_coefficients) { +SEXP dust2_system_internals(cpp11::sexp ptr, bool include_coefficients, bool include_history) { auto *obj = cpp11::as_cpp>(ptr).get(); const auto& internals = obj->ode_internals(); cpp11::writable::list ret(internals.size()); for (size_t i = 0; i < internals.size(); ++i) { - ret[i] = ode_internals_to_sexp(internals[i], include_coefficients); + ret[i] = ode_internals_to_sexp(internals[i], + include_coefficients, + include_history); } return cpp11::as_sexp(ret); + } } diff --git a/inst/template/continuous/system.cpp b/inst/template/continuous/system.cpp index e40d4e38..c76fc4b7 100644 --- a/inst/template/continuous/system.cpp +++ b/inst/template/continuous/system.cpp @@ -4,6 +4,6 @@ SEXP dust2_system_{{name}}_alloc(cpp11::list r_pars, cpp11::sexp r_time, cpp11:: } [[cpp11::register]] -SEXP dust2_system_{{name}}_internals(cpp11::sexp ptr, bool include_coefficients) { - return dust2::r::dust2_system_internals>(ptr, include_coefficients); +SEXP dust2_system_{{name}}_internals(cpp11::sexp ptr, bool include_coefficients, bool include_history) { + return dust2::r::dust2_system_internals>(ptr, include_coefficients, include_history); } diff --git a/src/cpp11.cpp b/src/cpp11.cpp index 41a9c25d..7779a77e 100644 --- a/src/cpp11.cpp +++ b/src/cpp11.cpp @@ -13,10 +13,10 @@ extern "C" SEXP _dust2_dust2_system_logistic_alloc(SEXP r_pars, SEXP r_time, SEX END_CPP11 } // logistic.cpp -SEXP dust2_system_logistic_internals(cpp11::sexp ptr, bool include_coefficients); -extern "C" SEXP _dust2_dust2_system_logistic_internals(SEXP ptr, SEXP include_coefficients) { +SEXP dust2_system_logistic_internals(cpp11::sexp ptr, bool include_coefficients, bool include_history); +extern "C" SEXP _dust2_dust2_system_logistic_internals(SEXP ptr, SEXP include_coefficients, SEXP include_history) { BEGIN_CPP11 - return cpp11::as_sexp(dust2_system_logistic_internals(cpp11::as_cpp>(ptr), cpp11::as_cpp>(include_coefficients))); + return cpp11::as_sexp(dust2_system_logistic_internals(cpp11::as_cpp>(ptr), cpp11::as_cpp>(include_coefficients), cpp11::as_cpp>(include_history))); END_CPP11 } // logistic.cpp @@ -104,10 +104,10 @@ extern "C" SEXP _dust2_dust2_system_malaria_alloc(SEXP r_pars, SEXP r_time, SEXP END_CPP11 } // malaria.cpp -SEXP dust2_system_malaria_internals(cpp11::sexp ptr, bool include_coefficients); -extern "C" SEXP _dust2_dust2_system_malaria_internals(SEXP ptr, SEXP include_coefficients) { +SEXP dust2_system_malaria_internals(cpp11::sexp ptr, bool include_coefficients, bool include_history); +extern "C" SEXP _dust2_dust2_system_malaria_internals(SEXP ptr, SEXP include_coefficients, SEXP include_history) { BEGIN_CPP11 - return cpp11::as_sexp(dust2_system_malaria_internals(cpp11::as_cpp>(ptr), cpp11::as_cpp>(include_coefficients))); + return cpp11::as_sexp(dust2_system_malaria_internals(cpp11::as_cpp>(ptr), cpp11::as_cpp>(include_coefficients), cpp11::as_cpp>(include_history))); END_CPP11 } // malaria.cpp @@ -475,10 +475,10 @@ extern "C" SEXP _dust2_dust2_system_sirode_alloc(SEXP r_pars, SEXP r_time, SEXP END_CPP11 } // sirode.cpp -SEXP dust2_system_sirode_internals(cpp11::sexp ptr, bool include_coefficients); -extern "C" SEXP _dust2_dust2_system_sirode_internals(SEXP ptr, SEXP include_coefficients) { +SEXP dust2_system_sirode_internals(cpp11::sexp ptr, bool include_coefficients, bool include_history); +extern "C" SEXP _dust2_dust2_system_sirode_internals(SEXP ptr, SEXP include_coefficients, SEXP include_history) { BEGIN_CPP11 - return cpp11::as_sexp(dust2_system_sirode_internals(cpp11::as_cpp>(ptr), cpp11::as_cpp>(include_coefficients))); + return cpp11::as_sexp(dust2_system_sirode_internals(cpp11::as_cpp>(ptr), cpp11::as_cpp>(include_coefficients), cpp11::as_cpp>(include_history))); END_CPP11 } // sirode.cpp @@ -857,7 +857,7 @@ static const R_CallMethodDef CallEntries[] = { {"_dust2_dust2_filter_sirode_set_rng_state", (DL_FUNC) &_dust2_dust2_filter_sirode_set_rng_state, 2}, {"_dust2_dust2_filter_sirode_update_pars", (DL_FUNC) &_dust2_dust2_filter_sirode_update_pars, 3}, {"_dust2_dust2_system_logistic_alloc", (DL_FUNC) &_dust2_dust2_system_logistic_alloc, 8}, - {"_dust2_dust2_system_logistic_internals", (DL_FUNC) &_dust2_dust2_system_logistic_internals, 2}, + {"_dust2_dust2_system_logistic_internals", (DL_FUNC) &_dust2_dust2_system_logistic_internals, 3}, {"_dust2_dust2_system_logistic_reorder", (DL_FUNC) &_dust2_dust2_system_logistic_reorder, 2}, {"_dust2_dust2_system_logistic_rng_state", (DL_FUNC) &_dust2_dust2_system_logistic_rng_state, 1}, {"_dust2_dust2_system_logistic_run_to_time", (DL_FUNC) &_dust2_dust2_system_logistic_run_to_time, 2}, @@ -871,7 +871,7 @@ static const R_CallMethodDef CallEntries[] = { {"_dust2_dust2_system_logistic_update_pars", (DL_FUNC) &_dust2_dust2_system_logistic_update_pars, 2}, {"_dust2_dust2_system_malaria_alloc", (DL_FUNC) &_dust2_dust2_system_malaria_alloc, 8}, {"_dust2_dust2_system_malaria_compare_data", (DL_FUNC) &_dust2_dust2_system_malaria_compare_data, 4}, - {"_dust2_dust2_system_malaria_internals", (DL_FUNC) &_dust2_dust2_system_malaria_internals, 2}, + {"_dust2_dust2_system_malaria_internals", (DL_FUNC) &_dust2_dust2_system_malaria_internals, 3}, {"_dust2_dust2_system_malaria_reorder", (DL_FUNC) &_dust2_dust2_system_malaria_reorder, 2}, {"_dust2_dust2_system_malaria_rng_state", (DL_FUNC) &_dust2_dust2_system_malaria_rng_state, 1}, {"_dust2_dust2_system_malaria_run_to_time", (DL_FUNC) &_dust2_dust2_system_malaria_run_to_time, 2}, @@ -898,7 +898,7 @@ static const R_CallMethodDef CallEntries[] = { {"_dust2_dust2_system_sir_update_pars", (DL_FUNC) &_dust2_dust2_system_sir_update_pars, 2}, {"_dust2_dust2_system_sirode_alloc", (DL_FUNC) &_dust2_dust2_system_sirode_alloc, 8}, {"_dust2_dust2_system_sirode_compare_data", (DL_FUNC) &_dust2_dust2_system_sirode_compare_data, 4}, - {"_dust2_dust2_system_sirode_internals", (DL_FUNC) &_dust2_dust2_system_sirode_internals, 2}, + {"_dust2_dust2_system_sirode_internals", (DL_FUNC) &_dust2_dust2_system_sirode_internals, 3}, {"_dust2_dust2_system_sirode_reorder", (DL_FUNC) &_dust2_dust2_system_sirode_reorder, 2}, {"_dust2_dust2_system_sirode_rng_state", (DL_FUNC) &_dust2_dust2_system_sirode_rng_state, 1}, {"_dust2_dust2_system_sirode_run_to_time", (DL_FUNC) &_dust2_dust2_system_sirode_run_to_time, 2}, diff --git a/src/logistic.cpp b/src/logistic.cpp index 46e7c18b..6a99f99f 100644 --- a/src/logistic.cpp +++ b/src/logistic.cpp @@ -90,8 +90,8 @@ SEXP dust2_system_logistic_alloc(cpp11::list r_pars, cpp11::sexp r_time, cpp11:: } [[cpp11::register]] -SEXP dust2_system_logistic_internals(cpp11::sexp ptr, bool include_coefficients) { - return dust2::r::dust2_system_internals>(ptr, include_coefficients); +SEXP dust2_system_logistic_internals(cpp11::sexp ptr, bool include_coefficients, bool include_history) { + return dust2::r::dust2_system_internals>(ptr, include_coefficients, include_history); } [[cpp11::register]] SEXP dust2_system_logistic_run_to_time(cpp11::sexp ptr, cpp11::sexp r_time) { diff --git a/src/malaria.cpp b/src/malaria.cpp index 734fee77..dbd9aa11 100644 --- a/src/malaria.cpp +++ b/src/malaria.cpp @@ -156,8 +156,8 @@ SEXP dust2_system_malaria_alloc(cpp11::list r_pars, cpp11::sexp r_time, cpp11::l } [[cpp11::register]] -SEXP dust2_system_malaria_internals(cpp11::sexp ptr, bool include_coefficients) { - return dust2::r::dust2_system_internals>(ptr, include_coefficients); +SEXP dust2_system_malaria_internals(cpp11::sexp ptr, bool include_coefficients, bool include_history) { + return dust2::r::dust2_system_internals>(ptr, include_coefficients, include_history); } [[cpp11::register]] SEXP dust2_system_malaria_run_to_time(cpp11::sexp ptr, cpp11::sexp r_time) { diff --git a/src/sirode.cpp b/src/sirode.cpp index 09d45281..31f8d3aa 100644 --- a/src/sirode.cpp +++ b/src/sirode.cpp @@ -118,8 +118,8 @@ SEXP dust2_system_sirode_alloc(cpp11::list r_pars, cpp11::sexp r_time, cpp11::li } [[cpp11::register]] -SEXP dust2_system_sirode_internals(cpp11::sexp ptr, bool include_coefficients) { - return dust2::r::dust2_system_internals>(ptr, include_coefficients); +SEXP dust2_system_sirode_internals(cpp11::sexp ptr, bool include_coefficients, bool include_history) { + return dust2::r::dust2_system_internals>(ptr, include_coefficients, include_history); } [[cpp11::register]] SEXP dust2_system_sirode_run_to_time(cpp11::sexp ptr, cpp11::sexp r_time) { diff --git a/tests/testthat/test-logistic.R b/tests/testthat/test-logistic.R index 29af6424..c47df00a 100644 --- a/tests/testthat/test-logistic.R +++ b/tests/testthat/test-logistic.R @@ -345,3 +345,26 @@ test_that("can save step times for debugging", { expect_equal(d$step_times[[1]][[d$n_steps + 1]], 10) expect_false(any(diff(d$step_times[[1]]) <= 0)) }) + + + + +test_that("can save history", { + pars <- list(n = 3, r = c(0.1, 0.2, 0.3), K = rep(100, 3)) + ctl <- dust_ode_control(save_history = TRUE, debug_record_step_times = TRUE) + sys <- dust_system_create(logistic(), pars, n_particles = 1, + preserve_particle_dimension = TRUE, + deterministic = TRUE, ode_control = ctl) + dust_system_set_state_initial(sys) + dust_system_run_to_time(sys, 10) + d <- dust_system_internals(sys, + include_coefficients = TRUE, + include_history = TRUE) + expect_true("history" %in% names(d)) + history <- d$history[[1]] + n_steps <- d$n_steps_accepted + expect_length(history$time, n_steps) + expect_equal(history$time, d$step_times[[1]][-(n_steps + 1)]) + expect_equal(history$size, diff(d$step_times[[1]])) + expect_equal(history$coefficients[, , n_steps], d$coefficients[[1]]) +}) From cddfcfe4befc3583db1458f32ab935a5237a59f1 Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Mon, 28 Oct 2024 09:03:17 +0000 Subject: [PATCH 2/4] Fix interface --- R/interface-continuous.R | 8 +++++--- man/dust_ode_control.Rd | 8 +++++++- man/dust_system_internals.Rd | 8 +++++++- 3 files changed, 19 insertions(+), 5 deletions(-) diff --git a/R/interface-continuous.R b/R/interface-continuous.R index 98aa1e4b..f045f262 100644 --- a/R/interface-continuous.R +++ b/R/interface-continuous.R @@ -32,8 +32,10 @@ ##' debugging. Step times can be retrieved via ##' [dust_system_internals()]. ##' -##' @param save_history Optional vector of states for which we should -##' save index. This is intended only for debugging. +##' @param save_history Logical, indicating if we should save history +##' during running. This should only be enabled for debugging. +##' Data can be retrieved via [dust_system_internals()], but the +##' format is undocumented. ##' ##' @export ##' @@ -42,7 +44,7 @@ dust_ode_control <- function(max_steps = 10000, atol = 1e-6, rtol = 1e-6, step_size_min = 0, step_size_max = Inf, debug_record_step_times = FALSE, - save_history = NULL) { + save_history = FALSE) { ctl <- list( max_steps = assert_scalar_size( max_steps, allow_zero = FALSE), diff --git a/man/dust_ode_control.Rd b/man/dust_ode_control.Rd index 4e2884e7..977d3277 100644 --- a/man/dust_ode_control.Rd +++ b/man/dust_ode_control.Rd @@ -10,7 +10,8 @@ dust_ode_control( rtol = 1e-06, step_size_min = 0, step_size_max = Inf, - debug_record_step_times = FALSE + debug_record_step_times = FALSE, + save_history = FALSE ) } \arguments{ @@ -40,6 +41,11 @@ a smaller maximum step size here.} times should be recorded. This should only be enabled for debugging. Step times can be retrieved via \code{\link[=dust_system_internals]{dust_system_internals()}}.} + +\item{save_history}{Logical, indicating if we should save history +during running. This should only be enabled for debugging. +Data can be retrieved via \code{\link[=dust_system_internals]{dust_system_internals()}}, but the +format is undocumented.} } \value{ A named list of class "dust_ode_control". Do not modify diff --git a/man/dust_system_internals.Rd b/man/dust_system_internals.Rd index ec92bb66..957b6a07 100644 --- a/man/dust_system_internals.Rd +++ b/man/dust_system_internals.Rd @@ -4,7 +4,11 @@ \alias{dust_system_internals} \title{Fetch system internals} \usage{ -dust_system_internals(sys, include_coefficients = FALSE) +dust_system_internals( + sys, + include_coefficients = FALSE, + include_history = FALSE +) } \arguments{ \item{sys}{A \code{dust_system} object} @@ -12,6 +16,8 @@ dust_system_internals(sys, include_coefficients = FALSE) \item{include_coefficients}{Boolean, indicating if interpolation coefficients should be included in the output. These are intentionally undocumented for now.} + +\item{include_history}{Boolean, also undocumented.} } \value{ If \code{sys} is a discrete-time system, this function returns From 951227c3004e6dcbb148423cbf25a2def9f112b5 Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Mon, 28 Oct 2024 10:18:45 +0000 Subject: [PATCH 3/4] Anticicipate saving subset of history --- inst/include/dust2/continuous/solver.hpp | 20 +++++++++++++++++++- inst/include/dust2/continuous/system.hpp | 8 ++------ inst/include/dust2/discrete/system.hpp | 5 +---- inst/include/dust2/r/continuous/control.hpp | 3 +-- inst/include/dust2/tools.hpp | 19 +++++++++++++++++++ 5 files changed, 42 insertions(+), 13 deletions(-) diff --git a/inst/include/dust2/continuous/solver.hpp b/inst/include/dust2/continuous/solver.hpp index 51a5ce1d..e6661c2f 100644 --- a/inst/include/dust2/continuous/solver.hpp +++ b/inst/include/dust2/continuous/solver.hpp @@ -53,6 +53,7 @@ struct internals { size_t n_steps; size_t n_steps_accepted; size_t n_steps_rejected; + std::vector history_index; internals(size_t n_variables) : dydt(n_variables), @@ -65,8 +66,25 @@ struct internals { reset(); } + void set_history_index(const std::vector& index) { + if (tools::is_trivial_index(index, dydt.size())) { + history_index.clear(); + } else { + history_index = index; + } + } + void save_history(real_type t, real_type h) { - history_values.push_back({t, h, c1, c2, c3, c4, c5}); + if (history_index.empty()) { + history_values.push_back({t, h, c1, c2, c3, c4, c5}); + } else { + history_values.push_back({t, h, + tools::subset(c1, history_index), + tools::subset(c2, history_index), + tools::subset(c3, history_index), + tools::subset(c4, history_index), + tools::subset(c5, history_index)}); + } } void reset() { diff --git a/inst/include/dust2/continuous/system.hpp b/inst/include/dust2/continuous/system.hpp index ded0c488..fdc91d05 100644 --- a/inst/include/dust2/continuous/system.hpp +++ b/inst/include/dust2/continuous/system.hpp @@ -61,7 +61,7 @@ class dust_continuous { shared_(shared), internal_(internal), - all_groups_(n_groups_), + all_groups_(tools::integer_sequence(n_groups_)), time_(time), zero_every_(zero_every_vec(shared_)), @@ -77,9 +77,6 @@ class dust_continuous { // We don't check that the size is the same across all states; // this should be done by the caller (similarly, we don't check // that shared and internal have the same size). - for (size_t i = 0; i < n_groups_; ++i) { - all_groups_[i] = i; - } } template ::is_mixed_time> @@ -178,14 +175,13 @@ class dust_continuous { try { T::initial(time_, shared_[i], internal_i, rng_.state(k), y); - solver_.initialise(time_, y, ode_internals_[k], - rhs_(shared_[i], internal_i)); } catch (std::exception const& e) { errors_.capture(e, k); } } } errors_.report(); + initialise_solver_(index_group); // Assume not current, because most models would want to call output here() update_output_is_current(index_group, false); } diff --git a/inst/include/dust2/discrete/system.hpp b/inst/include/dust2/discrete/system.hpp index c17d3273..32c37279 100644 --- a/inst/include/dust2/discrete/system.hpp +++ b/inst/include/dust2/discrete/system.hpp @@ -45,7 +45,7 @@ class dust_discrete { state_next_(n_state_ * n_particles_total_), shared_(shared), internal_(internal), - all_groups_(n_groups_), + all_groups_(tools::integer_sequence(n_groups_)), time_(time), dt_(dt), zero_every_(zero_every_vec(shared_)), @@ -55,9 +55,6 @@ class dust_discrete { // We don't check that the size is the same across all states; // this should be done by the caller (similarly, we don't check // that shared and internal have the same size). - for (size_t i = 0; i < n_groups_; ++i) { - all_groups_[i] = i; - } } auto run_to_time(real_type time, const std::vector& index_group) { diff --git a/inst/include/dust2/r/continuous/control.hpp b/inst/include/dust2/r/continuous/control.hpp index cffff75f..8dc553d2 100644 --- a/inst/include/dust2/r/continuous/control.hpp +++ b/inst/include/dust2/r/continuous/control.hpp @@ -18,8 +18,7 @@ dust2::ode::control validate_ode_control(cpp11::list r_time_control) const auto step_size_max = dust2::r::read_real(ode_control, "step_size_max"); const bool debug_record_step_times = dust2::r::read_bool(ode_control, "debug_record_step_times"); - const bool save_history = - dust2::r::read_bool(ode_control, "save_history"); + const auto save_history = dust2::r::read_bool(ode_control, "save_history"); return dust2::ode::control(max_steps, atol, rtol, step_size_min, step_size_max, save_history, debug_record_step_times); diff --git a/inst/include/dust2/tools.hpp b/inst/include/dust2/tools.hpp index 3a2f2deb..75b16381 100644 --- a/inst/include/dust2/tools.hpp +++ b/inst/include/dust2/tools.hpp @@ -52,6 +52,25 @@ T prod(const std::vector& x) { return std::accumulate(x.begin(), x.end(), 1, std::multiplies<>{}); } +inline std::vector integer_sequence(size_t n) { + std::vector ret; + ret.reserve(n); + for (size_t i = 0; i < n; ++i) { + ret.push_back(i); + } + return ret; +} + +template +std::vector subset(const std::vector& x, const std::vector index) { + std::vector ret; + ret.reserve(index.size()); + for (auto i : index) { + ret.push_back(x[i]); + } + return ret; +} + inline bool is_trivial_index(const std::vector& index, size_t n) { if (index.empty()) { return true; From a85451ba6674e5386dfed6bb07b79a1bb28fb87e Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Mon, 28 Oct 2024 12:26:48 +0000 Subject: [PATCH 4/4] Add header --- inst/include/dust2/continuous/solver.hpp | 1 + 1 file changed, 1 insertion(+) diff --git a/inst/include/dust2/continuous/solver.hpp b/inst/include/dust2/continuous/solver.hpp index e6661c2f..85a9fce2 100644 --- a/inst/include/dust2/continuous/solver.hpp +++ b/inst/include/dust2/continuous/solver.hpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include