From 198483adba4555b5cdba8f1b5632c7623fae7d6c Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Fri, 17 May 2024 16:19:57 +0100 Subject: [PATCH] Different order for rng --- inst/include/dust2/r/filter.hpp | 25 ++++++++++++++++++++++--- tests/testthat/test-filter.R | 10 +++++----- 2 files changed, 27 insertions(+), 8 deletions(-) diff --git a/inst/include/dust2/r/filter.hpp b/inst/include/dust2/r/filter.hpp index 1afdd16d..490910a5 100644 --- a/inst/include/dust2/r/filter.hpp +++ b/inst/include/dust2/r/filter.hpp @@ -174,9 +174,28 @@ cpp11::sexp dust2_cpu_filter_run(cpp11::sexp ptr, cpp11::sexp r_pars, template cpp11::sexp dust2_cpu_filter_rng_state(cpp11::sexp ptr) { auto *obj = cpp11::as_cpp>>(ptr).get(); - const auto state_filter = rng_state_as_raw(obj->rng_state()); - const auto state_model = rng_state_as_raw(obj->model.rng_state()); - return cpp11::writable::list{state_filter, state_model}; + using rng_state_type = typename T::rng_state_type; + + // Undo the construction as above so that the rng state comes out in + // the same format it goes in, as a single raw vector. + const auto& state_filter = obj->rng_state(); + const auto& state_model = obj->model.rng_state(); + const auto n_particles = obj->model.n_particles(); + const auto n_groups = obj->model.n_groups(); + const auto n_state = rng_state_type::size(); + const auto n_bytes = sizeof(typename rng_state_type::int_type); + const auto n_bytes_state = n_bytes * n_state; + cpp11::writable::raws ret(n_bytes * (state_filter.size() + state_model.size())); + for (size_t i = 0; i < n_groups; ++i) { + std::memcpy(RAW(ret) + i * n_bytes_state * (n_particles + 1), + state_filter.data() + i * n_state, + n_bytes_state); + std::memcpy(RAW(ret) + i * n_bytes_state * (n_particles + 1) + n_bytes_state, + state_model.data() + i * n_state * n_particles, + n_bytes_state * n_particles); + } + + return ret; } } diff --git a/tests/testthat/test-filter.R b/tests/testthat/test-filter.R index d5a79b4c..b9c6dc7c 100644 --- a/tests/testthat/test-filter.R +++ b/tests/testthat/test-filter.R @@ -136,7 +136,6 @@ test_that("can run particle filter", { obj <- dust2_cpu_sir_filter_alloc( pars, time_start, time, dt, data, n_particles, 0, seed) ptr <- obj[[1]] - s <- dust2_cpu_sir_filter_rng_state(ptr) res <- replicate(20, dust2_cpu_sir_filter_run(ptr, NULL, FALSE)) cmp_filter <- sir_filter_manual( @@ -166,10 +165,8 @@ test_that("can run a nested particle filter and get the same result", { ## Here, we can check the layout of the rng within the filter and model: n_streams <- (n_particles + 1) * 2 r <- mcstate2::mcstate_rng$new(n_streams = n_streams, seed = seed)$state() - rr <- array(r, c(length(r) / n_streams, n_particles + 1, 2)) s <- dust2_cpu_sir_filter_rng_state(ptr) - expect_equal(s[[1]], c(rr[, 1, ])) - expect_equal(s[[2]], c(rr[, -1, ])) + expect_equal(s, r) res <- replicate(20, dust2_cpu_sir_filter_run(ptr, NULL, TRUE)) @@ -181,12 +178,15 @@ test_that("can run a nested particle filter and get the same result", { s1 <- dust2_cpu_sir_filter_rng_state(ptr1) res1 <- replicate(20, dust2_cpu_sir_filter_run(ptr1, NULL, FALSE)) expect_equal(res1, res[1, ]) + expect_equal(s1, s[1:3232]) + seed2 <- r[3233:3264] data2 <- lapply(data, "[[", 2) obj2 <- dust2_cpu_sir_filter_alloc( - pars[[2]], time_start, time, dt, data2, n_particles, 0, rr[, 1, 2]) + pars[[2]], time_start, time, dt, data2, n_particles, 0, seed2) ptr2 <- obj2[[1]] s2 <- dust2_cpu_sir_filter_rng_state(ptr2) res2 <- replicate(20, dust2_cpu_sir_filter_run(ptr2, NULL, FALSE)) expect_equal(res2, res[2, ]) + expect_equal(s2, s[3233:6464]) })