Skip to content

Commit

Permalink
Different order for rng
Browse files Browse the repository at this point in the history
  • Loading branch information
richfitz committed May 17, 2024
1 parent 98738bd commit 198483a
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 8 deletions.
25 changes: 22 additions & 3 deletions inst/include/dust2/r/filter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -174,9 +174,28 @@ cpp11::sexp dust2_cpu_filter_run(cpp11::sexp ptr, cpp11::sexp r_pars,
template <typename T>
cpp11::sexp dust2_cpu_filter_rng_state(cpp11::sexp ptr) {
auto *obj = cpp11::as_cpp<cpp11::external_pointer<filter<T>>>(ptr).get();
const auto state_filter = rng_state_as_raw(obj->rng_state());
const auto state_model = rng_state_as_raw(obj->model.rng_state());
return cpp11::writable::list{state_filter, state_model};
using rng_state_type = typename T::rng_state_type;

// Undo the construction as above so that the rng state comes out in
// the same format it goes in, as a single raw vector.
const auto& state_filter = obj->rng_state();
const auto& state_model = obj->model.rng_state();
const auto n_particles = obj->model.n_particles();
const auto n_groups = obj->model.n_groups();
const auto n_state = rng_state_type::size();
const auto n_bytes = sizeof(typename rng_state_type::int_type);
const auto n_bytes_state = n_bytes * n_state;
cpp11::writable::raws ret(n_bytes * (state_filter.size() + state_model.size()));
for (size_t i = 0; i < n_groups; ++i) {
std::memcpy(RAW(ret) + i * n_bytes_state * (n_particles + 1),
state_filter.data() + i * n_state,
n_bytes_state);
std::memcpy(RAW(ret) + i * n_bytes_state * (n_particles + 1) + n_bytes_state,
state_model.data() + i * n_state * n_particles,
n_bytes_state * n_particles);
}

return ret;
}

}
Expand Down
10 changes: 5 additions & 5 deletions tests/testthat/test-filter.R
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,6 @@ test_that("can run particle filter", {
obj <- dust2_cpu_sir_filter_alloc(
pars, time_start, time, dt, data, n_particles, 0, seed)
ptr <- obj[[1]]
s <- dust2_cpu_sir_filter_rng_state(ptr)
res <- replicate(20, dust2_cpu_sir_filter_run(ptr, NULL, FALSE))

cmp_filter <- sir_filter_manual(
Expand Down Expand Up @@ -166,10 +165,8 @@ test_that("can run a nested particle filter and get the same result", {
## Here, we can check the layout of the rng within the filter and model:
n_streams <- (n_particles + 1) * 2
r <- mcstate2::mcstate_rng$new(n_streams = n_streams, seed = seed)$state()
rr <- array(r, c(length(r) / n_streams, n_particles + 1, 2))
s <- dust2_cpu_sir_filter_rng_state(ptr)
expect_equal(s[[1]], c(rr[, 1, ]))
expect_equal(s[[2]], c(rr[, -1, ]))
expect_equal(s, r)

res <- replicate(20, dust2_cpu_sir_filter_run(ptr, NULL, TRUE))

Expand All @@ -181,12 +178,15 @@ test_that("can run a nested particle filter and get the same result", {
s1 <- dust2_cpu_sir_filter_rng_state(ptr1)
res1 <- replicate(20, dust2_cpu_sir_filter_run(ptr1, NULL, FALSE))
expect_equal(res1, res[1, ])
expect_equal(s1, s[1:3232])

seed2 <- r[3233:3264]
data2 <- lapply(data, "[[", 2)
obj2 <- dust2_cpu_sir_filter_alloc(
pars[[2]], time_start, time, dt, data2, n_particles, 0, rr[, 1, 2])
pars[[2]], time_start, time, dt, data2, n_particles, 0, seed2)
ptr2 <- obj2[[1]]
s2 <- dust2_cpu_sir_filter_rng_state(ptr2)
res2 <- replicate(20, dust2_cpu_sir_filter_run(ptr2, NULL, FALSE))
expect_equal(res2, res[2, ])
expect_equal(s2, s[3233:6464])
})

0 comments on commit 198483a

Please sign in to comment.