Skip to content

Commit

Permalink
Merge pull request #15 from mrc-ide/mrc-5377
Browse files Browse the repository at this point in the history
Utility class for saving history
  • Loading branch information
weshinsley authored May 22, 2024
2 parents f90cf40 + 2440860 commit 2a20cfd
Show file tree
Hide file tree
Showing 6 changed files with 295 additions and 0 deletions.
4 changes: 4 additions & 0 deletions R/cpp11.R
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@ test_resample_weight <- function(w, u) {
.Call(`_dust2_test_resample_weight`, w, u)
}

# Thin R-level wrapper: forwards its four arguments, unchanged and in
# order, to the registered native routine `_dust2_test_history` and
# returns whatever that routine returns.
test_history <- function(r_time, r_state, r_order, reorder) {
.Call(`_dust2_test_history`, r_time, r_state, r_order, reorder)
}

# Thin R-level wrapper: forwards its seven arguments, unchanged and in
# order, to the registered native routine `_dust2_dust2_cpu_walk_alloc`
# and returns whatever that routine returns.
dust2_cpu_walk_alloc <- function(r_pars, r_time, r_dt, r_n_particles, r_n_groups, r_seed, r_deterministic) {
.Call(`_dust2_dust2_cpu_walk_alloc`, r_pars, r_time, r_dt, r_n_particles, r_n_groups, r_seed, r_deterministic)
}
Expand Down
144 changes: 144 additions & 0 deletions inst/include/dust2/history.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
#pragma once

#include <algorithm>
#include <cstddef>
#include <stdexcept>
#include <vector>

namespace dust2 {

// We might want a version of this that saves a subset of state too,
// we can think about that later though. We also need a version that
// can allow working back through the graph of history.
/// Fixed-capacity store for a sequence of state snapshots.
///
/// Each snapshot ("entry") holds `n_state * n_particles * n_groups`
/// values of `real_type`, the time it was taken at, and optionally an
/// ordering vector of `n_particles * n_groups` indices. Entries are
/// appended with `add()` until the capacity (`n_times`) is reached;
/// `reset()` rewinds to the start and `resize()` changes the capacity.
///
/// On export the ordering vectors can be used to walk backwards
/// through the entries so that slot `i` of every exported entry
/// follows the chain of indices leading to slot `i` of the last one.
template <typename real_type>
class history {
public:
  /// @param n_state     number of state values per particle
  /// @param n_particles number of particles per group
  /// @param n_groups    number of groups
  /// @param n_times     capacity: the maximum number of entries
  history(size_t n_state, size_t n_particles, size_t n_groups, size_t n_times) :
    n_state_(n_state),
    n_particles_(n_particles),
    n_groups_(n_groups),
    n_times_(n_times),
    len_state_(n_state_ * n_particles_ * n_groups_),
    len_order_(n_particles_ * n_groups_),
    position_(0),
    times_(n_times_),
    state_(len_state_ * n_times_),
    order_(len_order_ * n_times_),
    reorder_(n_times_) {
  }

  /// Change the capacity to `n_times` entries and rewind to the start.
  ///
  /// Every per-entry buffer is resized together; leaving `times_` or
  /// `reorder_` at their old length would make a later `add()` write
  /// past the end of them after growing the capacity.
  void resize(size_t n_times) {
    n_times_ = n_times;
    times_.resize(n_times_);
    state_.resize(len_state_ * n_times_);
    order_.resize(len_order_ * n_times_);
    reorder_.resize(n_times_);
    reset();
  }

  /// Rewind so that the next `add()` writes the first entry. Storage
  /// is kept, so old contents remain until overwritten.
  void reset() {
    position_ = 0;
  }

  /// Append an entry that carries no ordering information.
  ///
  /// @param time time associated with the entry
  /// @param iter start of `n_state * n_particles * n_groups` values
  /// @throws std::out_of_range if the history is already full
  template <typename Iter>
  void add(real_type time, Iter iter) {
    check_capacity_();
    std::copy_n(iter, len_state_, state_.begin() + len_state_ * position_);
    times_[position_] = time;
    reorder_[position_] = false;
    position_++;
  }

  /// Append an entry together with its ordering vector.
  ///
  /// @param time       time associated with the entry
  /// @param iter_state start of `n_state * n_particles * n_groups` values
  /// @param iter_order start of `n_particles * n_groups` indices; each
  ///                   is used on export as a within-group index into
  ///                   the previous entry
  /// @throws std::out_of_range if the history is already full
  template <typename IterReal, typename IterSize>
  void add(real_type time, IterReal iter_state, IterSize iter_order) {
    // This does not delegate to add(time, iter): that overload
    // advances position_ and clears the reorder flag, both of which
    // we need to control here. The duplication is minimal.
    check_capacity_();
    std::copy_n(iter_state, len_state_,
                state_.begin() + position_ * len_state_);
    std::copy_n(iter_order, len_order_,
                order_.begin() + position_ * len_order_);
    times_[position_] = time;
    reorder_[position_] = true;
    position_++;
  }

  // These allow a consumer to allocate the right size structures for
  // time and state for the total that we've actually used.

  /// Number of entries added since construction or the last reset.
  auto size_time() const {
    return position_;
  }

  /// Total number of state values held by the entries added so far.
  auto size_state() const {
    return position_ * len_state_;
  }

  /// Copy the times of the entries added so far to `iter`, which must
  /// have room for `size_time()` values.
  template <typename Iter>
  void export_time(Iter iter) const {
    std::copy_n(times_.begin(), position_, iter);
  }

  /// Copy the state of the entries added so far to `iter`, which must
  /// have room for `size_state()` values.
  ///
  /// @param iter    destination
  /// @param reorder if true, and at least one stored entry carries an
  ///                ordering vector, entries are rewritten following
  ///                the ordering chain backwards from the last entry;
  ///                otherwise the stored values are copied verbatim.
  template <typename Iter>
  void export_state(Iter iter, bool reorder) const {
    // Only look at the flags of entries in use: after reset() the
    // tail of reorder_ can hold stale values from an earlier run.
    reorder = reorder && n_particles_ > 1 && position_ > 0 &&
      std::any_of(reorder_.begin(), reorder_.begin() + position_,
                  [](bool v) { return v; });
    if (!reorder) {
      // No reordering is requested or possible so just dump out directly:
      std::copy_n(state_.begin(), position_ * len_state_, iter);
      return;
    }

    // Start from the identity index within each group:
    std::vector<size_t> index_particle(n_particles_ * n_groups_);
    for (size_t i = 0, k = 0; i < n_groups_; ++i) {
      for (size_t j = 0; j < n_particles_; ++j, ++k) {
        index_particle[k] = j;
      }
    }

    // Walk from the last entry back to the first, updating the index
    // as we pass each entry that has an ordering vector.
    for (size_t i = position_; i-- > 0;) {
      const auto iter_order = order_.begin() + i * len_order_;
      const auto iter_state = state_.begin() + i * len_state_;
      // This bit here is independent among groups
      for (size_t j = 0; j < n_groups_; j++) {
        const auto offset_state = j * n_state_ * n_particles_;
        const auto offset_index = j * n_particles_;
        reorder_group_(iter_state + offset_state,
                       iter_order + offset_index,
                       reorder_[i],
                       iter + i * len_state_ + offset_state,
                       index_particle.begin() + offset_index);
      }
    }
  }

private:
  size_t n_state_;
  size_t n_particles_;
  size_t n_groups_;
  size_t n_times_;   // capacity, in entries
  size_t len_state_; // length of an update to state
  size_t len_order_; // length of an update to order
  size_t position_;  // index of the next entry to write
  std::vector<real_type> times_;
  std::vector<real_type> state_;
  std::vector<size_t> order_;
  std::vector<bool> reorder_; // does entry i have an ordering vector?

  // Refuse to write beyond the allocated capacity.
  void check_capacity_() const {
    if (position_ >= n_times_) {
      throw std::out_of_range("history is full: resize() or reset() before adding more entries");
    }
  }

  // Export one group of one entry. For each slot i, copy the state of
  // stored particle index[i] to destination slot i; then, if this
  // entry has an ordering vector, replace index[i] with
  // order[index[i]] so that it applies to the previous entry.
  //
  // Reference implementation for this is mcstate:::history_single and
  // mcstate::history_multiple
  template <typename Iter>
  void reorder_group_(typename std::vector<real_type>::const_iterator iter_state,
                      typename std::vector<size_t>::const_iterator iter_order,
                      bool reorder,
                      Iter iter_dest,
                      typename std::vector<size_t>::iterator index) const {
    for (size_t i = 0; i < n_particles_; ++i) {
      const auto index_i = index + i;
      std::copy_n(iter_state + *index_i * n_state_,
                  n_state_,
                  iter_dest + i * n_state_);
      if (reorder) {
        *index_i = *(iter_order + *index_i);
      }
    }
  }
};

}
1 change: 1 addition & 0 deletions inst/include/dust2/r/helpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <vector>
#include <dust2/common.hpp>
#include <dust2/cpu.hpp>
#include <cpp11.hpp>

namespace dust2 {
namespace r {
Expand Down
8 changes: 8 additions & 0 deletions src/cpp11.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,13 @@ extern "C" SEXP _dust2_test_resample_weight(SEXP w, SEXP u) {
return cpp11::as_sexp(test_resample_weight(cpp11::as_cpp<cpp11::decay_t<std::vector<double>>>(w), cpp11::as_cpp<cpp11::decay_t<double>>(u)));
END_CPP11
}
// test.cpp
// Forward declaration of the C++ function that the wrapper below calls.
cpp11::sexp test_history(cpp11::doubles r_time, cpp11::list r_state, cpp11::sexp r_order, bool reorder);
// C-linkage entry point for test_history: converts each SEXP argument
// to the C++ type in the declaration above with cpp11::as_cpp, calls
// test_history, and converts the result back with cpp11::as_sexp.
// NOTE(review): BEGIN_CPP11/END_CPP11 are cpp11 macros, presumably
// bracketing the body with C++-exception-to-R-error handling; confirm
// against the cpp11 documentation.
extern "C" SEXP _dust2_test_history(SEXP r_time, SEXP r_state, SEXP r_order, SEXP reorder) {
BEGIN_CPP11
return cpp11::as_sexp(test_history(cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(r_time), cpp11::as_cpp<cpp11::decay_t<cpp11::list>>(r_state), cpp11::as_cpp<cpp11::decay_t<cpp11::sexp>>(r_order), cpp11::as_cpp<cpp11::decay_t<bool>>(reorder)));
END_CPP11
}
// walk.cpp
SEXP dust2_cpu_walk_alloc(cpp11::list r_pars, cpp11::sexp r_time, cpp11::sexp r_dt, cpp11::sexp r_n_particles, cpp11::sexp r_n_groups, cpp11::sexp r_seed, cpp11::sexp r_deterministic);
extern "C" SEXP _dust2_dust2_cpu_walk_alloc(SEXP r_pars, SEXP r_time, SEXP r_dt, SEXP r_n_particles, SEXP r_n_groups, SEXP r_seed, SEXP r_deterministic) {
Expand Down Expand Up @@ -239,6 +246,7 @@ static const R_CallMethodDef CallEntries[] = {
{"_dust2_dust2_cpu_walk_state", (DL_FUNC) &_dust2_dust2_cpu_walk_state, 2},
{"_dust2_dust2_cpu_walk_time", (DL_FUNC) &_dust2_dust2_cpu_walk_time, 1},
{"_dust2_dust2_cpu_walk_update_pars", (DL_FUNC) &_dust2_dust2_cpu_walk_update_pars, 3},
{"_dust2_test_history", (DL_FUNC) &_dust2_test_history, 4},
{"_dust2_test_resample_weight", (DL_FUNC) &_dust2_test_resample_weight, 2},
{NULL, NULL, 0}
};
Expand Down
39 changes: 39 additions & 0 deletions src/test.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
#include <dust2/filter_details.hpp>
#include <dust2/history.hpp>
#include <dust2/r/helpers.hpp>

#include <cpp11/integers.hpp>
#include <cpp11/doubles.hpp>
#include <cpp11/list.hpp>

[[cpp11::register]]
cpp11::integers test_resample_weight(std::vector<double> w, double u) {
Expand All @@ -11,3 +14,39 @@ cpp11::integers test_resample_weight(std::vector<double> w, double u) {
cpp11::writable::integers ret(idx.begin(), idx.end());
return ret;
}

// Simple driver for exercising the history saving outside of any
// particle filter.
//
// r_time:  times; its length sets the capacity of the history
// r_state: list of arrays with a length-3 "dim" attribute
//          (state, particle, group), one per entry to add; must be
//          non-empty and no longer than r_time
// r_order: NULL, or a list at least as long as r_state whose elements
//          are either NULL (add the entry without an ordering) or an
//          integer vector passed as that entry's ordering
// reorder: passed through to history::export_state
//
// Returns list(time, state), where state is given dimensions
// (state, particle, group, number of entries added).
[[cpp11::register]]
cpp11::sexp test_history(cpp11::doubles r_time, cpp11::list r_state,
                         cpp11::sexp r_order, bool reorder) {
  const size_t n_times = r_time.size();
  const size_t n_entries = r_state.size();
  // Validate up front: the history itself is sized from r_time, and
  // the dimensions are read from the first element of r_state.
  if (n_entries == 0) {
    cpp11::stop("Expected at least one element in 'r_state'");
  }
  if (n_entries > n_times) {
    cpp11::stop("'r_state' must not be longer than 'r_time'");
  }

  cpp11::sexp el0 = r_state[0];
  auto r_dim = cpp11::as_cpp<cpp11::integers>(el0.attr("dim"));
  if (r_dim.size() != 3) {
    cpp11::stop("Expected elements of 'r_state' to be 3-dimensional arrays");
  }
  const size_t n_state = r_dim[0];
  const size_t n_particles = r_dim[1];
  const size_t n_groups = r_dim[2];

  dust2::history<double> h(n_state, n_particles, n_groups, n_times);
  if (r_order == R_NilValue) {
    for (size_t i = 0; i < n_entries; ++i) {
      h.add(r_time[i], REAL(r_state[i]));
    }
  } else {
    // Convert once, rather than on every iteration.
    const auto order = cpp11::as_cpp<cpp11::list>(r_order);
    if (static_cast<size_t>(order.size()) < n_entries) {
      cpp11::stop("'r_order' must be at least as long as 'r_state'");
    }
    for (size_t i = 0; i < n_entries; ++i) {
      cpp11::sexp el = order[i];
      if (el == R_NilValue) {
        h.add(r_time[i], REAL(r_state[i]));
      } else {
        h.add(r_time[i], REAL(r_state[i]), INTEGER(el));
      }
    }
  }

  cpp11::writable::doubles ret_time(static_cast<int>(h.size_time()));
  cpp11::writable::doubles ret_state(static_cast<int>(h.size_state()));
  h.export_time(REAL(ret_time));
  h.export_state(REAL(ret_state), reorder);
  dust2::r::set_array_dims(ret_state, {n_state, n_particles, n_groups, h.size_time()});

  return cpp11::writable::list{ret_time, ret_state};
}
99 changes: 99 additions & 0 deletions tests/testthat/test-filter-details.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,102 @@ test_that("Resampling works as expected", {
test_resample_weight(w, u) + 1L,
ref_resample_weight(w, u))
})


test_that("can use history", {
  ## Each snapshot is a (state, particle, group) array
  snapshot_dim <- c(6, 7, 3)
  time <- seq(0, 10, length.out = 11)
  n_time <- length(time)

  ## One random snapshot per time point, drawn in time order
  snapshots <- replicate(
    n_time,
    array(runif(prod(snapshot_dim)), snapshot_dim),
    simplify = FALSE)
  stacked <- array(unlist(snapshots), c(snapshot_dim, n_time))
  everything <- list(time, stacked)

  ## With no ordering supplied, the reorder flag makes no difference
  expect_equal(test_history(time, snapshots, NULL, TRUE), everything)
  expect_equal(test_history(time, snapshots, NULL, FALSE), everything)

  ## Adding fewer snapshots than there are times returns just those
  first <- 1:3
  expect_equal(test_history(time, snapshots[first], NULL, TRUE),
               list(time[first], stacked[, , , first]))

  ## A list of orderings that are all NULL behaves like no ordering
  no_order <- rep(list(NULL), n_time)
  expect_equal(test_history(time, snapshots, no_order, TRUE), everything)
})


test_that("can reorder history with no groups", {
  ## Build the expected answer by simulating forward and rewriting the
  ## past every time the particles are resampled.
  n_state <- 6
  n_particles <- 7
  n_groups <- 1
  time <- seq(0, 10, length.out = 11)
  n_time <- length(time)
  snapshot_dim <- c(n_state, n_particles, n_groups)

  saved <- rep(list(NULL), n_time)
  index <- rep(list(NULL), n_time)
  expected <- array(NA_real_, c(snapshot_dim, n_time))
  current <- array(0, snapshot_dim)

  set.seed(1)
  for (step in seq_len(n_time)) {
    current <- current + runif(length(current))
    if (step > 1 && step %% 2 == 0) {
      ## Resample on even steps; indices are passed on as 0-based
      earlier <- seq_len(step - 1)
      pick <- replicate(n_groups, sample(n_particles, replace = TRUE))
      index[[step]] <- as.integer(pick - 1L)
      for (g in seq_len(n_groups)) {
        parent <- pick[, g]
        current[, , g] <- current[, parent, g] + 1
        expected[, , g, earlier] <- expected[, parent, g, earlier]
      }
    }
    saved[[step]] <- current
    expected[, , , step] <- current
  }
  saved_arr <- array(unlist(saved), dim(expected))

  ## Index supplied but reordering switched off
  expect_equal(test_history(time, saved, index, FALSE),
               list(time, saved_arr))

  ## Trivial cases: a single entry, then an identity index
  expect_equal(test_history(time, saved[1], index[1], TRUE),
               list(time[1], saved_arr[, , , 1, drop = FALSE]))
  expect_equal(test_history(time, saved[1:2], list(NULL, 0:6), TRUE),
               list(time[1:2], saved_arr[, , , 1:2, drop = FALSE]))

  ## Full reordering against the simulated truth
  expect_equal(test_history(time, saved, index, TRUE),
               list(time, expected))
})


test_that("can reorder history on the way out", {
  ## Build the expected answer by simulating forward and rewriting the
  ## past every time the particles are resampled, this time with
  ## several groups that are resampled independently.
  n_state <- 6
  n_particles <- 7
  n_groups <- 3
  time <- seq(0, 10, length.out = 11)
  n_time <- length(time)
  snapshot_dim <- c(n_state, n_particles, n_groups)

  saved <- rep(list(NULL), n_time)
  index <- rep(list(NULL), n_time)
  expected <- array(NA_real_, c(snapshot_dim, n_time))
  current <- array(0, snapshot_dim)

  set.seed(1)
  for (step in seq_len(n_time)) {
    current <- current + runif(length(current))
    if (step > 1 && step %% 2 == 0) {
      ## Resample on even steps; indices are passed on as 0-based
      earlier <- seq_len(step - 1)
      pick <- replicate(n_groups, sample(n_particles, replace = TRUE))
      index[[step]] <- as.integer(pick - 1L)
      for (g in seq_len(n_groups)) {
        parent <- pick[, g]
        current[, , g] <- current[, parent, g] + 1
        expected[, , g, earlier] <- expected[, parent, g, earlier]
      }
    }
    saved[[step]] <- current
    expected[, , , step] <- current
  }
  saved_arr <- array(unlist(saved), dim(expected))

  ## Without reordering we get back exactly what was saved
  expect_equal(test_history(time, saved, index, FALSE),
               list(time, saved_arr))
  ## With reordering we get the simulated truth
  expect_equal(test_history(time, saved, index, TRUE),
               list(time, expected))
})

0 comments on commit 2a20cfd

Please sign in to comment.