From 951227c3004e6dcbb148423cbf25a2def9f112b5 Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Mon, 28 Oct 2024 10:18:45 +0000 Subject: [PATCH] Anticicipate saving subset of history --- inst/include/dust2/continuous/solver.hpp | 20 +++++++++++++++++++- inst/include/dust2/continuous/system.hpp | 8 ++------ inst/include/dust2/discrete/system.hpp | 5 +---- inst/include/dust2/r/continuous/control.hpp | 3 +-- inst/include/dust2/tools.hpp | 19 +++++++++++++++++++ 5 files changed, 42 insertions(+), 13 deletions(-) diff --git a/inst/include/dust2/continuous/solver.hpp b/inst/include/dust2/continuous/solver.hpp index 51a5ce1d..e6661c2f 100644 --- a/inst/include/dust2/continuous/solver.hpp +++ b/inst/include/dust2/continuous/solver.hpp @@ -53,6 +53,7 @@ struct internals { size_t n_steps; size_t n_steps_accepted; size_t n_steps_rejected; + std::vector history_index; internals(size_t n_variables) : dydt(n_variables), @@ -65,8 +66,25 @@ struct internals { reset(); } + void set_history_index(const std::vector& index) { + if (tools::is_trivial_index(index, dydt.size())) { + history_index.clear(); + } else { + history_index = index; + } + } + void save_history(real_type t, real_type h) { - history_values.push_back({t, h, c1, c2, c3, c4, c5}); + if (history_index.empty()) { + history_values.push_back({t, h, c1, c2, c3, c4, c5}); + } else { + history_values.push_back({t, h, + tools::subset(c1, history_index), + tools::subset(c2, history_index), + tools::subset(c3, history_index), + tools::subset(c4, history_index), + tools::subset(c5, history_index)}); + } } void reset() { diff --git a/inst/include/dust2/continuous/system.hpp b/inst/include/dust2/continuous/system.hpp index ded0c488..fdc91d05 100644 --- a/inst/include/dust2/continuous/system.hpp +++ b/inst/include/dust2/continuous/system.hpp @@ -61,7 +61,7 @@ class dust_continuous { shared_(shared), internal_(internal), - all_groups_(n_groups_), + all_groups_(tools::integer_sequence(n_groups_)), time_(time), zero_every_(zero_every_vec(shared_)), @@ -77,9 +77,6 @@ class dust_continuous { // We don't check that the size is the same across all states; // this should be done by the caller (similarly, we don't check // that shared and internal have the same size). - for (size_t i = 0; i < n_groups_; ++i) { - all_groups_[i] = i; - } } template ::is_mixed_time> @@ -178,14 +175,13 @@ class dust_continuous { try { T::initial(time_, shared_[i], internal_i, rng_.state(k), y); - solver_.initialise(time_, y, ode_internals_[k], - rhs_(shared_[i], internal_i)); } catch (std::exception const& e) { errors_.capture(e, k); } } } errors_.report(); + initialise_solver_(index_group); // Assume not current, because most models would want to call output here() update_output_is_current(index_group, false); } diff --git a/inst/include/dust2/discrete/system.hpp b/inst/include/dust2/discrete/system.hpp index c17d3273..32c37279 100644 --- a/inst/include/dust2/discrete/system.hpp +++ b/inst/include/dust2/discrete/system.hpp @@ -45,7 +45,7 @@ class dust_discrete { state_next_(n_state_ * n_particles_total_), shared_(shared), internal_(internal), - all_groups_(n_groups_), + all_groups_(tools::integer_sequence(n_groups_)), time_(time), dt_(dt), zero_every_(zero_every_vec(shared_)), @@ -55,9 +55,6 @@ class dust_discrete { // We don't check that the size is the same across all states; // this should be done by the caller (similarly, we don't check // that shared and internal have the same size). - for (size_t i = 0; i < n_groups_; ++i) { - all_groups_[i] = i; - } } auto run_to_time(real_type time, const std::vector& index_group) { diff --git a/inst/include/dust2/r/continuous/control.hpp b/inst/include/dust2/r/continuous/control.hpp index cffff75f..8dc553d2 100644 --- a/inst/include/dust2/r/continuous/control.hpp +++ b/inst/include/dust2/r/continuous/control.hpp @@ -18,8 +18,7 @@ dust2::ode::control validate_ode_control(cpp11::list r_time_control) const auto step_size_max = dust2::r::read_real(ode_control, "step_size_max"); const bool debug_record_step_times = dust2::r::read_bool(ode_control, "debug_record_step_times"); - const bool save_history = - dust2::r::read_bool(ode_control, "save_history"); + const auto save_history = dust2::r::read_bool(ode_control, "save_history"); return dust2::ode::control(max_steps, atol, rtol, step_size_min, step_size_max, save_history, debug_record_step_times); diff --git a/inst/include/dust2/tools.hpp b/inst/include/dust2/tools.hpp index 3a2f2deb..75b16381 100644 --- a/inst/include/dust2/tools.hpp +++ b/inst/include/dust2/tools.hpp @@ -52,6 +52,25 @@ T prod(const std::vector& x) { return std::accumulate(x.begin(), x.end(), 1, std::multiplies<>{}); } +inline std::vector integer_sequence(size_t n) { + std::vector ret; + ret.reserve(n); + for (size_t i = 0; i < n; ++i) { + ret.push_back(i); + } + return ret; +} + +template +std::vector subset(const std::vector& x, const std::vector index) { + std::vector ret; + ret.reserve(index.size()); + for (auto i : index) { + ret.push_back(x[i]); + } + return ret; +} + inline bool is_trivial_index(const std::vector& index, size_t n) { if (index.empty()) { return true;