Skip to content

Commit

Permalink
Fix crash using internal
Browse files Browse the repository at this point in the history
Adds integration test too
  • Loading branch information
richfitz committed Oct 14, 2024
1 parent b1d07fe commit 9562ada
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 3 deletions.
11 changes: 8 additions & 3 deletions inst/include/dust2/properties.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@ namespace dust2 {

namespace internals {

template <class T, class = void>
struct test_has_build_internal: std::false_type {};
template <class T>
struct test_has_build_internal<T, std::void_t<decltype(T::build_internal)>>: std::true_type {};

template <class T, class = void>
struct test_has_packing_gradient: std::false_type {};
template <class T>
Expand All @@ -32,20 +37,20 @@ struct test_has_rhs<T, std::void_t<decltype(T::rhs)>>: std::true_type {};

template<typename T>
struct properties {
using has_internal_state = std::is_same<std::is_empty<typename T::internal_state>, std::false_type>;
using has_build_internal = internals::test_has_build_internal<T>;
using has_packing_gradient = internals::test_has_packing_gradient<T>;
using has_zero_every = internals::test_has_zero_every<T>;
using is_mixed_time = typename std::conditional<internals::test_has_rhs<T>::value && internals::test_has_update<T>::value, std::true_type, std::false_type>::type;
};

// wrappers around some uses of member functions that may or may not
// exist, centralising most of the weird into this file:
template <typename T, typename std::enable_if<properties<T>::has_internal_state::value, T>::type* = nullptr>
template <typename T, typename std::enable_if<properties<T>::has_build_internal::value, T>::type* = nullptr>
typename T::internal_state do_build_internal(const typename T::shared_state &shared) {
return T::build_internal(shared);
}

template <typename T, typename std::enable_if<!properties<T>::has_internal_state::value, T>::type* = nullptr>
template <typename T, typename std::enable_if<!properties<T>::has_build_internal::value, T>::type* = nullptr>
typename T::internal_state do_build_internal(const typename T::shared_state &shared) {
return typename T::internal_state{};
}
Expand Down
66 changes: 66 additions & 0 deletions tests/testthat/examples/internal.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
#include <dust2/common.hpp>

// [[dust2::class(walk)]]
// [[dust2::time_type(discrete)]]
// [[dust2::parameter(sd)]]
// [[dust2::parameter(len)]]
class walk {
public:
walk() = delete;

using real_type = double;
using rng_state_type = monty::random::generator<real_type>;

struct shared_state {
size_t len;
real_type sd;
};

struct internal_state {
std::vector<real_type> scratch;
};

static dust2::packing packing_state(const shared_state& shared) {
return dust2::packing{{"x", {}}};
}

static shared_state build_shared(cpp11::list pars) {
const auto len = dust2::r::read_size(pars, "len", 10);
const auto sd = dust2::r::read_real(pars, "sd");
return shared_state{len, sd};
}

static void update_shared(cpp11::list pars, shared_state& shared) {
shared.sd = dust2::r::read_real(pars, "sd", shared.sd);
}

static internal_state build_internal(const shared_state& shared) {
std::vector<real_type> scratch(shared.len);
return internal_state{scratch};
}

static void initial(real_type time,
const shared_state& shared,
internal_state& internal,
rng_state_type& rng_state,
real_type * state_next) {
std::fill(state_next, state_next + shared.len, 0);
}

static void update(real_type time,
real_type dt,
const real_type * state,
const shared_state& shared,
internal_state& internal,
rng_state_type& rng_state,
real_type * state_next) {
const auto x = state[0];
for (size_t i = 0; i < shared.len; ++i) {
internal.scratch[i] =
monty::random::normal(rng_state, x, shared.sd * dt);
}
state_next[0] = std::accumulate(internal.scratch.begin(),
internal.scratch.end(),
static_cast<real_type>(0.0)) / shared.len;
}
};
15 changes: 15 additions & 0 deletions tests/testthat/test-zzz-slow.R
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,18 @@ test_that("can use two zero_every", {
expect_equal(y[1, ], rep(1:2, length.out = 20))
expect_equal(y[2, ], rep(1:5, length.out = 20))
})


test_that("can run a model with internal storage", {
skip_for_compilation()
gen <- dust_compile("examples/internal.cpp", quiet = TRUE, debug = TRUE)

len <- 5
set.seed(1)
sys <- dust_system_create(gen(), list(sd = 1, len = len), 1)
seed <- dust_system_rng_state(sys)
y <- drop(dust_system_simulate(sys, 1:20))

r <- monty::monty_rng$new(seed = seed)$normal(20 * len, 0, 1)
expect_equal(y, cumsum(colMeans(matrix(r, len, 20))))
})

0 comments on commit 9562ada

Please sign in to comment.