#pragma once

#include <algorithm>
#include <cstddef>
#include <stdexcept>
#include <vector>

namespace dust2 {

// Fixed-capacity store for the state of a set of particles at a
// series of times, optionally with the resampling order that was
// applied at each time.
//
// Each call to add() appends one snapshot of
// n_state * n_particles * n_groups values; within a snapshot the
// state index varies fastest, then particle, then group.  When a
// snapshot is added together with an order, the order holds one
// zero-based index per particle (within its group) and is used by
// export_state() to trace each final particle back through earlier
// snapshots.
//
// We might want a version of this that saves a subset of state too,
// we can think about that later though.  We also need a version that
// can allow working back through the graph of history.
template <typename real_type>
class history {
public:
  // Allocate space for up to `n_times` snapshots.
  history(size_t n_state, size_t n_particles, size_t n_groups, size_t n_times) :
    n_state_(n_state),
    n_particles_(n_particles),
    n_groups_(n_groups),
    n_times_(n_times),
    len_state_(n_state_ * n_particles_ * n_groups_),
    len_order_(n_particles_ * n_groups_),
    position_(0),
    times_(n_times_),
    state_(len_state_ * n_times_),
    order_(len_order_ * n_times_),
    reorder_(n_times_) {
  }

  // Change the capacity to `n_times` snapshots and discard anything
  // recorded so far.  Every per-time buffer is resized here; times_
  // and reorder_ are indexed by position_ in add() just like state_
  // and order_, so leaving them at their old length would let add()
  // write past their end after growing.
  void resize(size_t n_times) {
    n_times_ = n_times;
    times_.resize(n_times_);
    state_.resize(len_state_ * n_times_);
    order_.resize(len_order_ * n_times_);
    reorder_.resize(n_times_);
    reset();
  }

  // Forget all recorded snapshots; capacity is unchanged.
  void reset() {
    position_ = 0;
  }

  // Record a snapshot that was not reordered.  `iter` must provide
  // len_state_ values.  Throws std::out_of_range if the history is
  // already full.
  template <typename Iter>
  void add(real_type time, Iter iter) {
    check_capacity_();
    std::copy_n(iter, len_state_, state_.begin() + len_state_ * position_);
    times_[position_] = time;
    reorder_[position_] = false;
    position_++;
  }

  // Record a snapshot together with the order that produced it.
  // `iter_state` must provide len_state_ values and `iter_order`
  // len_order_ indices.  Throws std::out_of_range if the history is
  // already full.
  template <typename IterReal, typename IterSize>
  void add(real_type time, IterReal iter_state, IterSize iter_order) {
    // This can't easily call add(real_type, Iter) because we need to
    // read position_ and write reorder_; the duplication is minimal
    // though.
    check_capacity_();
    std::copy_n(iter_state, len_state_,
                state_.begin() + position_ * len_state_);
    std::copy_n(iter_order, len_order_,
                order_.begin() + position_ * len_order_);
    times_[position_] = time;
    reorder_[position_] = true;
    position_++;
  }

  // These allow a consumer to allocate the right size structures for
  // time and state for the total that we've actually used.
  auto size_time() const {
    return position_;
  }

  auto size_state() const {
    return position_ * len_state_;
  }

  // Copy the recorded times (size_time() values) to `iter`.
  template <typename Iter>
  void export_time(Iter iter) const {
    std::copy_n(times_.begin(), position_, iter);
  }

  // Copy the recorded state (size_state() values) to `iter`, which
  // must be a random access iterator.  If `reorder` is true and any
  // recorded snapshot carries an order, each particle in the output
  // is replaced at every time by the ancestor that the final particle
  // at that index descends from; otherwise the snapshots are copied
  // out exactly as they were added.
  template <typename Iter>
  void export_state(Iter iter, bool reorder) const {
    // Only look at the flags for snapshots we currently hold; entries
    // beyond position_ may be left over from before a reset().
    reorder = reorder && n_particles_ > 1 && position_ > 0 &&
      std::any_of(reorder_.begin(), reorder_.begin() + position_,
                  [](bool v) { return v; });
    if (!reorder) {
      // No reordering is requested or possible so just dump out directly:
      std::copy_n(state_.begin(), position_ * len_state_, iter);
      return;
    }

    // Default index: every particle initially points at itself
    // within its own group.
    std::vector<size_t> index_particle(n_particles_ * n_groups_);
    for (size_t i = 0, k = 0; i < n_groups_; ++i) {
      for (size_t j = 0; j < n_particles_; ++j, ++k) {
        index_particle[k] = j;
      }
    }

    // Walk backwards from the most recent snapshot, following the
    // order at each reordered time to find the ancestor at the
    // previous one.
    for (size_t i = position_; i-- > 0;) {
      const auto iter_order = order_.begin() + i * len_order_;
      const auto iter_state = state_.begin() + i * len_state_;
      // This bit here is independent among groups
      for (size_t j = 0; j < n_groups_; j++) {
        const auto offset_state = j * n_state_ * n_particles_;
        const auto offset_index = j * n_particles_;
        reorder_group_(iter_state + offset_state,
                       iter_order + offset_index,
                       reorder_[i],
                       iter + i * len_state_ + offset_state,
                       index_particle.begin() + offset_index);
      }
    }
  }

private:
  size_t n_state_;
  size_t n_particles_;
  size_t n_groups_;
  size_t n_times_;
  size_t len_state_; // length of an update to state
  size_t len_order_; // length of an update to order
  size_t position_;  // number of snapshots added since the last reset
  std::vector<real_type> times_;
  std::vector<real_type> state_;
  std::vector<size_t> order_;
  std::vector<bool> reorder_;

  // Guard used by both add() overloads so that a full history fails
  // loudly rather than writing past the end of its buffers.
  void check_capacity_() const {
    if (position_ >= n_times_) {
      throw std::out_of_range("history is full: cannot add another time");
    }
  }

  // Export one group at one time: copy the state of the particle that
  // each index currently points at, then (if this snapshot was
  // reordered) step each index back to its parent.
  //
  // Reference implementation for this is mcstate:::history_single and
  // mcstate::history_multiple
  template <typename Iter>
  void reorder_group_(typename std::vector<real_type>::const_iterator iter_state,
                      typename std::vector<size_t>::const_iterator iter_order,
                      bool reorder,
                      Iter iter_dest,
                      typename std::vector<size_t>::iterator index) const {
    for (size_t i = 0; i < n_particles_; ++i) {
      std::copy_n(iter_state + *(index + i) * n_state_,
                  n_state_,
                  iter_dest + i * n_state_);
      if (reorder) {
        const auto index_i = index + i;
        *index_i = *(iter_order + *index_i);
      }
    }
  }
};

}
dust2_cpu_walk_alloc(cpp11::list r_pars, cpp11::sexp r_time, cpp11::sexp r_dt, cpp11::sexp r_n_particles, cpp11::sexp r_n_groups, cpp11::sexp r_seed, cpp11::sexp r_deterministic); extern "C" SEXP _dust2_dust2_cpu_walk_alloc(SEXP r_pars, SEXP r_time, SEXP r_dt, SEXP r_n_particles, SEXP r_n_groups, SEXP r_seed, SEXP r_deterministic) { @@ -239,6 +246,7 @@ static const R_CallMethodDef CallEntries[] = { {"_dust2_dust2_cpu_walk_state", (DL_FUNC) &_dust2_dust2_cpu_walk_state, 2}, {"_dust2_dust2_cpu_walk_time", (DL_FUNC) &_dust2_dust2_cpu_walk_time, 1}, {"_dust2_dust2_cpu_walk_update_pars", (DL_FUNC) &_dust2_dust2_cpu_walk_update_pars, 3}, + {"_dust2_test_history", (DL_FUNC) &_dust2_test_history, 4}, {"_dust2_test_resample_weight", (DL_FUNC) &_dust2_test_resample_weight, 2}, {NULL, NULL, 0} }; diff --git a/src/test.cpp b/src/test.cpp index b758d910..096b1cc0 100644 --- a/src/test.cpp +++ b/src/test.cpp @@ -1,7 +1,10 @@ #include +#include +#include #include #include +#include [[cpp11::register]] cpp11::integers test_resample_weight(std::vector w, double u) { @@ -11,3 +14,39 @@ cpp11::integers test_resample_weight(std::vector w, double u) { cpp11::writable::integers ret(idx.begin(), idx.end()); return ret; } + +// Simple driver for exercising the history saving outside of any +// particle filter. 
+[[cpp11::register]] +cpp11::sexp test_history(cpp11::doubles r_time, cpp11::list r_state, + cpp11::sexp r_order, bool reorder) { + const size_t n_times = r_time.size(); + cpp11::sexp el0 = r_state[0]; + + auto r_dim = cpp11::as_cpp(el0.attr("dim")); + const size_t n_state = r_dim[0]; + const size_t n_particles = r_dim[1]; + const size_t n_groups = r_dim[2]; + + dust2::history h(n_state, n_particles, n_groups, n_times); + for (size_t i = 0; i < static_cast(r_state.size()); ++i) { + if (r_order == R_NilValue) { + h.add(r_time[i], REAL(r_state[i])); + } else { + cpp11::sexp el = cpp11::as_cpp(r_order)[i]; + if (el == R_NilValue) { + h.add(r_time[i], REAL(r_state[i])); + } else { + h.add(r_time[i], REAL(r_state[i]), INTEGER(el)); + } + } + } + + cpp11::writable::doubles ret_time(static_cast(h.size_time())); + cpp11::writable::doubles ret_state(static_cast(h.size_state())); + h.export_time(REAL(ret_time)); + h.export_state(REAL(ret_state), reorder); + dust2::r::set_array_dims(ret_state, {n_state, n_particles, n_groups, h.size_time()}); + + return cpp11::writable::list{ret_time, ret_state}; +} diff --git a/tests/testthat/test-filter-details.R b/tests/testthat/test-filter-details.R index caab81f4..613a8f73 100644 --- a/tests/testthat/test-filter-details.R +++ b/tests/testthat/test-filter-details.R @@ -14,3 +14,102 @@ test_that("Resampling works as expected", { test_resample_weight(w, u) + 1L, ref_resample_weight(w, u)) }) + + +test_that("can use history", { + time <- seq(0, 10, length.out = 11) + n_state <- 6 + n_particles <- 7 + n_groups <- 3 + n_time <- length(time) + s <- lapply(seq_along(time), function(i) { + array(runif(n_state * n_particles * n_groups), + c(n_state, n_particles, n_groups)) + }) + s_arr <- array(unlist(s), c(n_state, n_particles, n_groups, n_time)) + expect_equal(test_history(time, s, NULL, TRUE), + list(time, s_arr)) + expect_equal(test_history(time, s, NULL, FALSE), + list(time, s_arr)) + expect_equal(test_history(time, s[1:3], NULL, TRUE), + 
## Simulate a history forward in time so that the reordering tests
## have an independently computed expected result.  At every even
## step after the first, particles are resampled (with replacement)
## within each group; the resampled state is shifted by 1 and the
## "true" history so far is rewritten so that each particle carries
## its ancestors' values.
##
## Returns a list with:
##   state:     list of state arrays as they would be saved at each time
##   order:     list of zero-based resampling indices (NULL where no
##              resampling happened)
##   state_arr: `state` bound into a 4d array, i.e. the history
##              with no reordering applied
##   true:      the expected history after reordering
simulate_history <- function(time, n_state, n_particles, n_groups) {
  n_time <- length(time)
  state <- vector("list", n_time)
  order <- vector("list", n_time)
  true <- array(NA_real_, c(n_state, n_particles, n_groups, n_time))
  s <- array(0, c(n_state, n_particles, n_groups))
  set.seed(1)
  for (i in seq_along(time)) {
    s <- s + runif(length(s))
    if (i > 1 && i %% 2 == 0) {
      k <- replicate(n_groups, sample(n_particles, replace = TRUE))
      order[[i]] <- as.integer(k - 1L)
      for (j in seq_len(n_groups)) {
        s[, , j] <- s[, k[, j], j] + 1
        true[, , j, seq_len(i - 1)] <- true[, k[, j], j, seq_len(i - 1)]
      }
    }
    state[[i]] <- s
    true[, , , i] <- s
  }
  list(state = state,
       order = order,
       state_arr = array(unlist(state), dim(true)),
       true = true)
}


test_that("can use history", {
  time <- seq(0, 10, length.out = 11)
  n_state <- 6
  n_particles <- 7
  n_groups <- 3
  n_time <- length(time)
  s <- lapply(seq_along(time), function(i) {
    array(runif(n_state * n_particles * n_groups),
          c(n_state, n_particles, n_groups))
  })
  s_arr <- array(unlist(s), c(n_state, n_particles, n_groups, n_time))
  ## With no order at all, the reorder flag makes no difference:
  expect_equal(test_history(time, s, NULL, TRUE),
               list(time, s_arr))
  expect_equal(test_history(time, s, NULL, FALSE),
               list(time, s_arr))
  ## Only part of the allocated history is filled:
  expect_equal(test_history(time, s[1:3], NULL, TRUE),
               list(time[1:3], s_arr[, , , 1:3]))
  ## An order list where every element is NULL is the same as no order:
  expect_equal(test_history(time, s, vector("list", length(time)), TRUE),
               list(time, s_arr))
})


test_that("can reorder history with no groups", {
  ## This is really hard to get right so let's actually simulate forward:
  time <- seq(0, 10, length.out = 11)
  sim <- simulate_history(time, n_state = 6, n_particles = 7, n_groups = 1)
  state <- sim$state
  order <- sim$order
  state_arr <- sim$state_arr

  ## Pass in, but ignore index
  expect_equal(test_history(time, state, order, FALSE),
               list(time, state_arr))

  ## Really simple, add an index that does not reorder anything:
  expect_equal(test_history(time, state[1], order[1], TRUE),
               list(time[1], state_arr[, , , 1, drop = FALSE]))
  expect_equal(test_history(time, state[1:2], list(NULL, 0:6), TRUE),
               list(time[1:2], state_arr[, , , 1:2, drop = FALSE]))

  ## Proper reordering with the full index:
  expect_equal(test_history(time, state, order, TRUE),
               list(time, sim$true))
})


test_that("can reorder history on the way out", {
  ## This is really hard to get right so let's actually simulate forward:
  time <- seq(0, 10, length.out = 11)
  sim <- simulate_history(time, n_state = 6, n_particles = 7, n_groups = 3)

  expect_equal(test_history(time, sim$state, sim$order, FALSE),
               list(time, sim$state_arr))
  expect_equal(test_history(time, sim$state, sim$order, TRUE),
               list(time, sim$true))
})