diff --git a/.covrignore b/.covrignore index 635cf5df..7e497293 100644 --- a/.covrignore +++ b/.covrignore @@ -7,3 +7,4 @@ src/malaria.cpp src/sir.cpp src/sirode.cpp src/walk.cpp +inst/include/lostturnip.hpp diff --git a/.gitattributes b/.gitattributes index 3614828a..886f2359 100644 --- a/.gitattributes +++ b/.gitattributes @@ -7,3 +7,4 @@ src/sir.cpp linguist-generated=true src/sirode.cpp linguist-generated=true src/walk.cpp linguist-generated=true R/import-*.R linguist-generated=true +inst/include/lostturnip.hpp linguist-vendored=true linguist-generated=true diff --git a/DESCRIPTION b/DESCRIPTION index b205b64a..5976dfcb 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: dust2 Title: Next Generation dust -Version: 0.3.10 +Version: 0.3.11 Authors@R: c(person("Rich", "FitzJohn", role = c("aut", "cre"), email = "rich.fitzjohn@gmail.com"), person("Imperial College of Science, Technology and Medicine", diff --git a/R/interface.R b/R/interface.R index 9ca7160b..4f65338d 100644 --- a/R/interface.R +++ b/R/interface.R @@ -485,7 +485,10 @@ dust_system_internals <- function(sys, include_coefficients = FALSE, error = vnapply(dat, "[[", "error"), n_steps = viapply(dat, "[[", "n_steps"), n_steps_accepted = viapply(dat, "[[", "n_steps_accepted"), - n_steps_rejected = viapply(dat, "[[", "n_steps_rejected")) + n_steps_rejected = viapply(dat, "[[", "n_steps_rejected"), + events = I(lapply(dat, function(x) { + if (is.null(x$events)) NULL else as.data.frame(x$events) + }))) if (include_coefficients) { ret$coefficients <- I(lapply(dat, "[[", "coefficients")) } diff --git a/inst/include/dust2/common.hpp b/inst/include/dust2/common.hpp index 04c87007..18cb1b41 100644 --- a/inst/include/dust2/common.hpp +++ b/inst/include/dust2/common.hpp @@ -9,6 +9,7 @@ #include #include #include +#include #include // In an odd place, we might update that later too diff --git a/inst/include/dust2/continuous/events.hpp b/inst/include/dust2/continuous/events.hpp new file mode 100644 
index 00000000..3a13ed85 --- /dev/null +++ b/inst/include/dust2/continuous/events.hpp @@ -0,0 +1,63 @@ +#pragma once + +#include +#include + +namespace dust2 { +namespace ode { + +// Do we detect an event when passing through the root as we increase +// (negative to positive), decrease (positive to negative) or both: +enum class root_type { + both, + increase, + decrease +}; + +// The actual logic for the above +template +bool is_root(const real_type a, const real_type b, const root_type& root) { + switch(root) { + case root_type::both: + return a * b < 0; + case root_type::increase: + return a < 0 && b > 0; + case root_type::decrease: + return a > 0 && b < 0; + } + return false; +} + +template +struct event { + using test_type = std::function; + using action_type = std::function; + std::vector index; + root_type root = root_type::both; + test_type test; + action_type action; + + event(const std::vector& index, test_type test, action_type action, root_type root = root_type::both) : + index(index), root(root), test(test), action(action) { + } + + event(size_t index, action_type action, root_type root = root_type::both) : + event({index}, [](real_type t, const real_type* y) { return y[0]; }, action, root) { + } +}; + +template +using events_type = std::vector>; + +template +struct event_history_element { + real_type time; + size_t index; + real_type sign; +}; + +template +using event_history = std::vector>; + +} +} diff --git a/inst/include/dust2/continuous/solver.hpp b/inst/include/dust2/continuous/solver.hpp index e3397adb..07e9d460 100644 --- a/inst/include/dust2/continuous/solver.hpp +++ b/inst/include/dust2/continuous/solver.hpp @@ -4,9 +4,12 @@ #include #include #include +#include + #include #include #include +#include #include namespace dust2 { @@ -31,6 +34,7 @@ template struct internals { history_step last; history history_values; + event_history events; std::vector dydt; std::vector step_times; @@ -147,10 +151,12 @@ class solver { // Take a single step 
template real_type step(real_type t, real_type t_end, real_type* y, + const events_type& events, ode::internals& internals, Rhs rhs) { auto success = false; auto reject = false; auto truncated = false; + auto event = false; auto h = internals.step_size; while (!success) { @@ -178,13 +184,22 @@ class solver { if (err <= 1) { success = true; update_interpolation(t, h, y, internals); + if (!events.empty()) { + const auto t_next = apply_events(t, h, y, events, internals); + if (t_next < t + h) { + event = true; + truncated = false; + h = t_next - t; + rhs(t_next, y_next_.data(), k2_.data()); + } + } accept(t, h, y, internals); internals.n_steps_accepted++; if (control_.debug_record_step_times) { internals.step_times.push_back(truncated ? t_end : t + h); } internals.save_history(); - if (!truncated) { + if (!truncated && !event) { const auto fac_old = std::max(internals.error, static_cast(1e-4)); auto fac = fac11 / std::pow(fac_old, control_.beta); @@ -209,11 +224,12 @@ class solver { template void run(real_type t, real_type t_end, real_type* y, zero_every_type& zero_every, + const events_type& events, ode::internals& internals, Rhs rhs) { if (control_.critical_times.empty()) { while (t < t_end) { apply_zero_every(t, y, zero_every, internals); - t = step(t, t_end, y, internals, rhs); + t = step(t, t_end, y, events, internals, rhs); } } else { // Slightly more complex loop which ensures we never integrate @@ -224,7 +240,7 @@ class solver { auto t_end_i = (tc == tc_end || *tc >= t_end) ? t_end : *tc; while (t < t_end) { apply_zero_every(t, y, zero_every, internals); - t = step(t, t_end_i, y, internals, rhs); + t = step(t, t_end_i, y, events, internals, rhs); if (t >= t_end_i && t < t_end) { ++tc; t_end_i = (tc == tc_end || *tc >= t_end) ? 
t_end : *tc; @@ -295,8 +311,6 @@ class solver { private: void update_interpolation(real_type t, real_type h, real_type* y, ode::internals& internals) { - // We might want to only do this bit if we'll actually use the - // history, but it's pretty cheap really. internals.last.t0 = t; internals.last.t1 = t + h; internals.last.h = h; @@ -315,6 +329,54 @@ class solver { std::copy_n(y_next_.begin(), n_variables_, y); } + // Look for events within the step [t0, t0 + h], testing each event + // against the interpolated solution. Every root found shortens the + // search interval, so the last root found is the earliest in time. + // That event's action is applied (once) to the state at the root. + // Returns the event time, or t0 + h if no event was found. + real_type apply_events(real_type t0, real_type h, const real_type* y, + const events_type& events, + ode::internals& internals) { + size_t idx_first = events.size(); + real_type t1 = t0 + h; + real_type sign = 0; + + for (size_t idx_event = 0; idx_event < events.size(); ++idx_event) { + const auto& e = events[idx_event]; + // Use y_stiff as temporary space here, it's only used + // transiently and within the step + real_type * y_t = y_stiff_.data(); + auto fn = [&](auto t) { + internals.last.interpolate(t, e.index, y_t); + return e.test(t, y_t); + }; + const auto f_t0 = fn(t0); + const auto f_t1 = fn(t1); + if (is_root(f_t0, f_t1, e.root)) { + // These probably should move into the ode control, but there + // shouldn't really be any great need to change them, and the + // interpolation is expected to be quite fast and accurate. + constexpr real_type eps = 1e-6; + constexpr size_t steps = 100; + auto root = lostturnip::find_result(fn, t0, t1, eps, steps); + idx_first = idx_event; + t1 = root.x; + sign = f_t0 < 0 ? 1 : -1; + } + } + // Only act once every event has been tested, so that a single event + // is applied and recorded per step. + if (idx_first < events.size()) { + internals.last.interpolate(t1, y_next_.data()); + events[idx_first].action(t1, sign, y_next_.data()); + // We need to modify the history here so that search will find + // the right point.
+ internals.last.t1 = t1; + internals.events.push_back({t1, idx_first, sign}); + } + return t1; + } + void apply_zero_every(real_type t, real_type* y, const zero_every_type& zero_every, ode::internals& internals) { diff --git a/inst/include/dust2/continuous/system.hpp b/inst/include/dust2/continuous/system.hpp index c9021fd3..2120a01c 100644 --- a/inst/include/dust2/continuous/system.hpp +++ b/inst/include/dust2/continuous/system.hpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include #include @@ -67,6 +68,7 @@ class dust_continuous { errors_(n_particles_total_), rng_(n_particles_total_, seed, deterministic), delays_(do_delays(shared_)), + events_(do_events(shared_, internal_)), solver_(n_groups_ * n_threads_, {n_state_ode_, control_}), output_is_current_(n_groups_), requires_initialise_(n_groups_, true) { @@ -89,7 +91,7 @@ class dust_continuous { const auto offset = k * n_state_; real_type * y = state_data + offset; try { - solver_[i].run(time_, time, y, zero_every_[group], + solver_[i].run(time_, time, y, zero_every_[group], events_[i], ode_internals_[k], rhs_(particle, group, thread)); } catch (std::exception const& e) { @@ -130,7 +132,7 @@ class dust_continuous { for (size_t step = 0; step < n_steps; ++step) { const real_type t0 = t1; t1 = (step == n_steps - 1) ?
time : time_ + step * dt_; - solver_[i].run(t0, t1, y, zero_every_[group], + solver_[i].run(t0, t1, y, zero_every_[group], events_[i], ode_internals_[k], rhs_(particle, group, thread)); std::copy_n(y, n_state_ode_, y_other); @@ -389,6 +391,7 @@ class dust_continuous { dust2::utils::errors errors_; monty::random::prng rng_; std::vector> delays_; + std::vector> events_; std::vector> solver_; std::vector output_is_current_; std::vector requires_initialise_; diff --git a/inst/include/dust2/properties.hpp b/inst/include/dust2/properties.hpp index 61eb37f9..c4dbe6f9 100644 --- a/inst/include/dust2/properties.hpp +++ b/inst/include/dust2/properties.hpp @@ -2,6 +2,7 @@ #include #include +#include #include #include @@ -44,6 +45,11 @@ struct test_has_delays: std::false_type {}; template struct test_has_delays>: std::true_type {}; +template +struct test_has_events: std::false_type {}; +template +struct test_has_events>: std::true_type {}; + // These test that the signature of rhs and output consume the delays // argument. Not especially lovely to read! template @@ -90,6 +96,7 @@ struct properties { // Because of the above these are now actual numbers rather than types; we may make this change everywhere... 
static constexpr bool rhs_uses_delays = internals::test_rhs_uses_delays(); static constexpr bool output_uses_delays = internals::test_output_uses_delays(); + using has_events = internals::test_has_events; }; // wrappers around some uses of member functions that may or may not @@ -161,4 +168,37 @@ auto do_delays(const std::vector& shared) { return std::vector>(shared.size(), dust2::ode::delays{{}}); } +// Build the events for every (thread, group) pair by calling +// T::events with that pair's shared and internal state. The result +// is ordered by thread, then by group within each thread. +template ::has_events::value, T>::type* = nullptr> +auto do_events(const std::vector& shared, + std::vector& internal) { + using real_type = typename T::real_type; + std::vector> ret; + + const auto n_groups = shared.size(); + // The loop below consumes one internal state per (thread, group) + // pair, so the thread count is the length of internal over n_groups. + const auto n_threads = n_groups == 0 ? 0 : internal.size() / n_groups; + ret.reserve(n_threads * n_groups); + auto iter = internal.begin(); + for (size_t i = 0; i < n_threads; ++i) { + for (auto& s : shared) { + ret.push_back(T::events(s, *iter)); + ++iter; + } + } + return ret; +} + +template ::has_events::value, T>::type* = nullptr> +auto do_events(const std::vector& shared, + std::vector& internal) { + using real_type = typename T::real_type; + const auto len = shared.size() * internal.size(); + const auto empty = dust2::ode::events_type{{}}; + return std::vector>(len, empty); +} + } diff --git a/inst/include/dust2/r/continuous/system.hpp b/inst/include/dust2/r/continuous/system.hpp index 17680651..9ff53981 100644 --- a/inst/include/dust2/r/continuous/system.hpp +++ b/inst/include/dust2/r/continuous/system.hpp @@ -67,7 +67,9 @@ cpp11::sexp ode_internals_to_sexp(const ode::internals& internals,
"coefficients"_nm = std::move(r_history_coef)}; ret["history"] = cpp11::as_sexp(r_history); } + if (!internals.events.empty()) { + const auto n_events = internals.events.size(); + auto r_event_time = cpp11::writable::doubles(n_events); + auto r_event_index = cpp11::writable::integers(n_events); + auto r_event_sign = cpp11::writable::doubles(n_events); + for (size_t i = 0; i < n_events; ++i) { + r_event_time[i] = internals.events[i].time; + r_event_index[i] = static_cast(internals.events[i].index) + 1; + r_event_sign[i] = internals.events[i].sign; + } + auto r_events = cpp11::writable::list{"time"_nm = std::move(r_event_time), + "index"_nm = std::move(r_event_index), + "sign"_nm = std::move(r_event_sign)}; + ret["events"] = cpp11::as_sexp(r_events); + } return ret; } diff --git a/inst/include/lostturnip.hpp b/inst/include/lostturnip.hpp new file mode 100644 index 00000000..3e1ea5d8 --- /dev/null +++ b/inst/include/lostturnip.hpp @@ -0,0 +1,169 @@ +#pragma once +#include +#include +#include + +namespace lostturnip { + +// Declaring these here, rather than within the find_result, as +// otherwise we get a compiler warning about using experimental cuda +// features. It will be equivalent though, but does require C++14. 
+namespace { +template +constexpr real_type na = std::numeric_limits::quiet_NaN(); + +template +constexpr real_type eps = std::numeric_limits::epsilon(); +} + +template +struct result { + real_type x; + real_type fx; + int iterations; + bool converged; +}; + +// From zeroin.c, in brent.shar +template +#ifdef __NVCC__ +__host__ __device__ +#endif +result find_result(F f, real_type a, real_type b, + real_type tol, int max_iterations) { + real_type fa = f(a); + real_type fb = f(b); + int iterations = 0; + bool converged = false; + + if (fa == 0) { + b = a; + fb = fa; + converged = true; + } else if (fb == 0) { + converged = true; + } else if (fa * fb > 0) { + // Same sign; can't find root with this: + b = na; + fb = na; + converged = false; + } else { + real_type c = a; + real_type fc = fa; // c = a, f(c) = f(a) + + for (; iterations < max_iterations; ++iterations) { // Main iteration loop + // Distance from the last but one to the last approximation + const real_type prev_step = b - a; + + // Interpolation step is calculated in the form p/q; division + // operations is dlayed until the last moment + real_type p; + real_type q; + + if (std::abs(fc) < std::abs(fb)) { + // Swap data for b to be the best approximation + a = b; + b = c; + c = a; + fa = fb; + fb = fc; + fc = fa; + } + + // Actual tolerance + const real_type tol_act = 2 * eps * std::abs(b) + tol / 2; + // Step at this iteration + real_type new_step = (c - b) / 2; + + if (std::abs(new_step) <= tol_act || fb == 0) { + // Acceptable approximation is found + converged = true; + break; + } + + // increase readability below, avoids many repeated static casts + const real_type one = 1; + + // Decide if the interpolation can be tried + // + // If prev_step was large enough and was in true direction, then + // interpolation can be tried + if (std::abs(prev_step) >= tol_act && std::abs(fa) > std::abs(fb)) { + // interpolation + const real_type cb = c - b; + if (a == c) { + // If we have only two distinct points 
linear interpolation + // can only be applied + const real_type t1 = fb / fa; + p = cb * t1; + q = one - t1; + } else { + // Quadric inverse interpolation + q = fa / fc; + const real_type t1 = fb / fc; + const real_type t2 = fb / fa; + p = t2 * (cb * q * (q - t1) - (b - a) * (t1 - one)); + q = (q - one) * (t1 - one) * (t2 - one); + } + if (p > 0) { + // p was calculated with the opposite sign; make p positive + // and assign possible minus to q + q = -q; + } else { + p = -p; + } + + // If b + p / q falls in [b, c] and isn't too large it is + // accepted + // + // If p / q is too large then the bissection procedure can + // reduce [b,c] range to more extent + if (p < (static_cast(0.75) * cb * q - std::abs(tol_act * q) / 2) && + p < std::abs(prev_step * q / 2)) { + new_step = p / q; + } + } + + // Adjust the step to be not less than tolerance + if (std::abs(new_step) < tol_act) { + new_step = std::copysign(tol_act, new_step); + } + + // Save the previous approximation + a = b; + fa = fb; + // Do step to a new approximation + b += new_step; + fb = f(b); + if ((fb > 0 && fc > 0) || (fb < 0 && fc < 0)) { + // Adjust c for it to have a sign opposite to that of b + c = a; fc = fa; + } + } + } + +#ifdef __CUDA_ARCH__ + __syncwarp(); +#endif + return result{b, fb, iterations, converged}; +} + +template +#ifdef __NVCC__ +__host__ __device__ +#endif +real_type find(F f, real_type a, real_type b, + real_type tol, int max_iterations) { + const auto result = find_result(f, a, b, tol, max_iterations); + if (!result.converged) { +#ifdef __CUDA_ARCH__ + printf("some error\n"); + __trap(); +#else + throw std::runtime_error("some error"); +#endif + } + return result.x; +} + +} diff --git a/tests/testthat/examples/event.cpp b/tests/testthat/examples/event.cpp new file mode 100644 index 00000000..378524c2 --- /dev/null +++ b/tests/testthat/examples/event.cpp @@ -0,0 +1,66 @@ +#include + +// [[dust2::class(bounce)]] +// [[dust2::time_type(continuous)]] +// [[dust2::parameter(height, 
rank = 0)]] +// [[dust2::parameter(velocity, rank = 0)]] +class bounce { +public: + bounce() = delete; + + using real_type = double; + + struct shared_state { + real_type g; + real_type height; + real_type velocity; + real_type damp; + }; + + struct internal_state {}; + + using rng_state_type = monty::random::generator; + + static dust2::packing packing_state(const shared_state& shared) { + return dust2::packing{{"height", {}}, {"velocity", {}}}; + } + + static void initial(real_type time, + const shared_state& shared, + internal_state& internal, + rng_state_type& rng_state, + real_type * state) { + state[0] = shared.height; + state[1] = shared.velocity; + } + + static void rhs(real_type time, + const real_type * state, + const shared_state& shared, + internal_state& internal, + real_type * state_deriv) { + state_deriv[0] = state[1]; + state_deriv[1] = -shared.g; + } + + static shared_state build_shared(cpp11::list pars) { + const real_type g = 9.8; + const real_type height = dust2::r::read_real(pars, "height", 0); + const real_type velocity = dust2::r::read_real(pars, "velocity", 10); + const real_type damp = dust2::r::read_real(pars, "damp", 0.9); + return shared_state{g, height, velocity, damp}; + } + + static void update_shared(cpp11::list pars, shared_state& shared) { + shared.damp = dust2::r::read_real(pars, "damp", shared.damp); + } + + static auto events(const shared_state& shared, internal_state& internal) { + auto action = [&](const double t, const double sign, double* y) { + y[0] = 0; + y[1] = -shared.damp * y[1]; + }; + dust2::ode::event e(0, action); + return dust2::ode::events_type({e}); + } +}; diff --git a/tests/testthat/helper-dust.R b/tests/testthat/helper-dust.R index 28244ec0..af91195f 100644 --- a/tests/testthat/helper-dust.R +++ b/tests/testthat/helper-dust.R @@ -83,3 +83,14 @@ local_sir_generator <- function() { options(dust.testing.local_sir_generator = gen) gen } + + +example_bounce_analytic <- function(t, v0 = 10, damp = 0.9, g = 9.8) { + 
t0 <- 0 + while (last(t0) < last(t)) { + t0 <- c(t0, last(t0) + 2 * v0 * damp^(length(t0) - 1) / g) + } + i <- findInterval(t, t0) + y <- v0 * damp^(i - 1) * (t - t0[i]) - 0.5 * g * (t - t0[i])^2 + list(y = y, roots = t0[t0 > 0 & t0 < last(t)]) +} diff --git a/tests/testthat/test-zzz-events.R b/tests/testthat/test-zzz-events.R new file mode 100644 index 00000000..393f3dd3 --- /dev/null +++ b/tests/testthat/test-zzz-events.R @@ -0,0 +1,30 @@ +test_that("can run system with roots and events", { + gen <- dust_compile("examples/event.cpp", quiet = TRUE, debug = TRUE) + + control <- dust_ode_control(debug_record_step_times = TRUE, save_history = TRUE) + sys <- dust_system_create(gen, ode_control = control) + dust_system_set_state_initial(sys) + + ## Use relatively few points for output here as this exacerbates + ## problems, even though the solution looks silly. + t <- seq(0, 6, length.out = 60) + y <- dust_system_simulate(sys, t) + cmp <- example_bounce_analytic(t) + + info <- dust_system_internals(sys, include_history = TRUE) + + ## Find all roots: + r <- info$events[[1]] + expect_equal(nrow(r), 3) + expect_equal(r$time, cmp$roots, tolerance = 1e-6) + expect_equal(r$index, rep(1L, 3)) + expect_equal(r$sign, rep(-1, 3)) + + ## Stop at all roots: + h <- info$history[[1]] + expect_true(all(r$time %in% h$t0)) + expect_true(all(r$time %in% h$t1)) + + ## Overall solution: + expect_equal(y[1, ], cmp$y, tolerance = 1e-6) +})