From 9562ada038113c613614a94c45b5a958e8110a52 Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Mon, 14 Oct 2024 16:34:54 +0100 Subject: [PATCH 1/2] Fix crash using internal Adds integration test too --- inst/include/dust2/properties.hpp | 11 +++-- tests/testthat/examples/internal.cpp | 66 ++++++++++++++++++++++++++++ tests/testthat/test-zzz-slow.R | 15 +++++++ 3 files changed, 89 insertions(+), 3 deletions(-) create mode 100644 tests/testthat/examples/internal.cpp diff --git a/inst/include/dust2/properties.hpp b/inst/include/dust2/properties.hpp index 981e27b6..a2fc223e 100644 --- a/inst/include/dust2/properties.hpp +++ b/inst/include/dust2/properties.hpp @@ -8,6 +8,11 @@ namespace dust2 { namespace internals { +template +struct test_has_build_internal: std::false_type {}; +template +struct test_has_build_internal>: std::true_type {}; + template struct test_has_packing_gradient: std::false_type {}; template @@ -32,7 +37,7 @@ struct test_has_rhs>: std::true_type {}; template struct properties { - using has_internal_state = std::is_same, std::false_type>; + using has_build_internal = internals::test_has_build_internal; using has_packing_gradient = internals::test_has_packing_gradient; using has_zero_every = internals::test_has_zero_every; using is_mixed_time = typename std::conditional::value && internals::test_has_update::value, std::true_type, std::false_type>::type; @@ -40,12 +45,12 @@ struct properties { // wrappers around some uses of member functions that may or may not // exist, centralising most of the weird into this file: -template ::has_internal_state::value, T>::type* = nullptr> +template ::has_build_internal::value, T>::type* = nullptr> typename T::internal_state do_build_internal(const typename T::shared_state &shared) { return T::build_internal(shared); } -template ::has_internal_state::value, T>::type* = nullptr> +template ::has_build_internal::value, T>::type* = nullptr> typename T::internal_state do_build_internal(const typename T::shared_state &shared) { return typename T::internal_state{}; } diff --git a/tests/testthat/examples/internal.cpp b/tests/testthat/examples/internal.cpp new file mode 100644 index 00000000..b18df146 --- /dev/null +++ b/tests/testthat/examples/internal.cpp @@ -0,0 +1,66 @@ +#include + +// [[dust2::class(walk)]] +// [[dust2::time_type(discrete)]] +// [[dust2::parameter(sd)]] +// [[dust2::parameter(len)]] +class walk { +public: + walk() = delete; + + using real_type = double; + using rng_state_type = monty::random::generator; + + struct shared_state { + size_t len; + real_type sd; + }; + + struct internal_state { + std::vector scratch; + }; + + static dust2::packing packing_state(const shared_state& shared) { + return dust2::packing{{"x", {}}}; + } + + static shared_state build_shared(cpp11::list pars) { + const auto len = dust2::r::read_size(pars, "len", 10); + const auto sd = dust2::r::read_real(pars, "sd"); + return shared_state{len, sd}; + } + + static void update_shared(cpp11::list pars, shared_state& shared) { + shared.sd = dust2::r::read_real(pars, "sd", shared.sd); + } + + static internal_state build_internal(const shared_state& shared) { + std::vector scratch(shared.len); + return internal_state{scratch}; + } + + static void initial(real_type time, + const shared_state& shared, + internal_state& internal, + rng_state_type& rng_state, + real_type * state_next) { + std::fill(state_next, state_next + shared.len, 0); + } + + static void update(real_type time, + real_type dt, + const real_type * state, + const shared_state& shared, + internal_state& internal, + rng_state_type& rng_state, + real_type * state_next) { + const auto x = state[0]; + for (size_t i = 0; i < shared.len; ++i) { + internal.scratch[i] = + monty::random::normal(rng_state, x, shared.sd * dt); + } + state_next[0] = std::accumulate(internal.scratch.begin(), + internal.scratch.end(), + static_cast(0.0)) / shared.len; + } +}; diff --git a/tests/testthat/test-zzz-slow.R b/tests/testthat/test-zzz-slow.R index 1562b97d..4732e311 100644 --- a/tests/testthat/test-zzz-slow.R +++ b/tests/testthat/test-zzz-slow.R @@ -22,3 +22,18 @@ test_that("can use two zero_every", { expect_equal(y[1, ], rep(1:2, length.out = 20)) expect_equal(y[2, ], rep(1:5, length.out = 20)) }) + + +test_that("can run a model with internal storage", { + skip_for_compilation() + gen <- dust_compile("examples/internal.cpp", quiet = TRUE, debug = TRUE) + + len <- 5 + set.seed(1) + sys <- dust_system_create(gen(), list(sd = 1, len = len), 1) + seed <- dust_system_rng_state(sys) + y <- drop(dust_system_simulate(sys, 1:20)) + + r <- monty::monty_rng$new(seed = seed)$normal(20 * len, 0, 1) + expect_equal(y, cumsum(colMeans(matrix(r, len, 20)))) +}) From 1cf24f2423243316b954a0f287ec7ae62bd86d3a Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Mon, 14 Oct 2024 16:35:30 +0100 Subject: [PATCH 2/2] Bump version --- DESCRIPTION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DESCRIPTION b/DESCRIPTION index 3fbd3df3..8d04fe88 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: dust2 Title: Next Generation dust -Version: 0.1.15 +Version: 0.1.16 Authors@R: c(person("Rich", "FitzJohn", role = c("aut", "cre"), email = "rich.fitzjohn@gmail.com"), person("Imperial College of Science, Technology and Medicine",