Skip to content

Commit

Permalink
Anticicipate saving subset of history
Browse files Browse the repository at this point in the history
  • Loading branch information
richfitz committed Oct 28, 2024
1 parent cddfcfe commit 951227c
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 13 deletions.
20 changes: 19 additions & 1 deletion inst/include/dust2/continuous/solver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ struct internals {
size_t n_steps;
size_t n_steps_accepted;
size_t n_steps_rejected;
std::vector<size_t> history_index;

internals(size_t n_variables) :
dydt(n_variables),
Expand All @@ -65,8 +66,25 @@ struct internals {
reset();
}

void set_history_index(const std::vector<size_t>& index) {
if (tools::is_trivial_index(index, dydt.size())) {
history_index.clear();
} else {
history_index = index;
}
}

void save_history(real_type t, real_type h) {
history_values.push_back({t, h, c1, c2, c3, c4, c5});
if (history_index.empty()) {
history_values.push_back({t, h, c1, c2, c3, c4, c5});
} else {
history_values.push_back({t, h,
tools::subset(c1, history_index),
tools::subset(c2, history_index),
tools::subset(c3, history_index),
tools::subset(c4, history_index),
tools::subset(c5, history_index)});
}
}

void reset() {
Expand Down
8 changes: 2 additions & 6 deletions inst/include/dust2/continuous/system.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class dust_continuous {

shared_(shared),
internal_(internal),
all_groups_(n_groups_),
all_groups_(tools::integer_sequence(n_groups_)),

time_(time),
zero_every_(zero_every_vec<T>(shared_)),
Expand All @@ -77,9 +77,6 @@ class dust_continuous {
// We don't check that the size is the same across all states;
// this should be done by the caller (similarly, we don't check
// that shared and internal have the same size).
for (size_t i = 0; i < n_groups_; ++i) {
all_groups_[i] = i;
}
}

template <typename mixed_time = typename dust2::properties<T>::is_mixed_time>
Expand Down Expand Up @@ -178,14 +175,13 @@ class dust_continuous {
try {
T::initial(time_, shared_[i], internal_i,
rng_.state(k), y);
solver_.initialise(time_, y, ode_internals_[k],
rhs_(shared_[i], internal_i));
} catch (std::exception const& e) {
errors_.capture(e, k);
}
}
}
errors_.report();
initialise_solver_(index_group);
// Assume not current, because most models would want to call output here()
update_output_is_current(index_group, false);
}
Expand Down
5 changes: 1 addition & 4 deletions inst/include/dust2/discrete/system.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class dust_discrete {
state_next_(n_state_ * n_particles_total_),
shared_(shared),
internal_(internal),
all_groups_(n_groups_),
all_groups_(tools::integer_sequence(n_groups_)),
time_(time),
dt_(dt),
zero_every_(zero_every_vec<T>(shared_)),
Expand All @@ -55,9 +55,6 @@ class dust_discrete {
// We don't check that the size is the same across all states;
// this should be done by the caller (similarly, we don't check
// that shared and internal have the same size).
for (size_t i = 0; i < n_groups_; ++i) {
all_groups_[i] = i;
}
}

auto run_to_time(real_type time, const std::vector<size_t>& index_group) {
Expand Down
3 changes: 1 addition & 2 deletions inst/include/dust2/r/continuous/control.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@ dust2::ode::control<real_type> validate_ode_control(cpp11::list r_time_control)
const auto step_size_max = dust2::r::read_real(ode_control, "step_size_max");
const bool debug_record_step_times =
dust2::r::read_bool(ode_control, "debug_record_step_times");
const bool save_history =
dust2::r::read_bool(ode_control, "save_history");
const auto save_history = dust2::r::read_bool(ode_control, "save_history");
return dust2::ode::control<real_type>(max_steps, atol, rtol, step_size_min,
step_size_max, save_history,
debug_record_step_times);
Expand Down
19 changes: 19 additions & 0 deletions inst/include/dust2/tools.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,25 @@ T prod(const std::vector<T>& x) {
return std::accumulate(x.begin(), x.end(), 1, std::multiplies<>{});
}

inline std::vector<size_t> integer_sequence(size_t n) {
std::vector<size_t> ret;
ret.reserve(n);
for (size_t i = 0; i < n; ++i) {
ret.push_back(i);
}
return ret;
}

template <typename T>
std::vector<T> subset(const std::vector<T>& x, const std::vector<size_t> index) {
std::vector<T> ret;
ret.reserve(index.size());
for (auto i : index) {
ret.push_back(x[i]);
}
return ret;
}

inline bool is_trivial_index(const std::vector<size_t>& index, size_t n) {
if (index.empty()) {
return true;
Expand Down

0 comments on commit 951227c

Please sign in to comment.