Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Utility class for saving history #15

Merged
merged 3 commits into from
May 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions R/cpp11.R
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@ test_resample_weight <- function(w, u) {
.Call(`_dust2_test_resample_weight`, w, u)
}

# Thin R-side wrapper: forwards its four arguments unchanged to the
# compiled routine registered as `_dust2_test_history` via .Call and
# returns whatever that routine returns.
test_history <- function(r_time, r_state, r_order, reorder) {
.Call(`_dust2_test_history`, r_time, r_state, r_order, reorder)
}

# Thin R-side wrapper: forwards its seven arguments unchanged to the
# compiled routine registered as `_dust2_dust2_cpu_walk_alloc` via
# .Call and returns whatever that routine returns.
dust2_cpu_walk_alloc <- function(r_pars, r_time, r_dt, r_n_particles, r_n_groups, r_seed, r_deterministic) {
.Call(`_dust2_dust2_cpu_walk_alloc`, r_pars, r_time, r_dt, r_n_particles, r_n_groups, r_seed, r_deterministic)
}
Expand Down
144 changes: 144 additions & 0 deletions inst/include/dust2/history.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
#pragma once

#include <algorithm>
#include <cstddef>
#include <stdexcept>
#include <vector>

namespace dust2 {

// We might want a version of this that saves a subset of state too,
// we can think about that later though. We also need a version that
// can allow working back through the graph of history.
//
// Fixed-capacity store for the trajectory of a set of particles.
//
// Each call to add() appends one "time slot" holding:
//   * the time,
//   * a full copy of the state (n_state * n_particles * n_groups values,
//     laid out state-within-particle-within-group), and
//   * optionally an "order" vector (n_particles * n_groups values).
//     Within a group, order[p] is the index, at the previously stored
//     time, of the particle that particle p was copied from.
//
// export_state() can either dump the states exactly as stored, or walk
// backwards through the stored orders so that each output particle
// holds the states of its own ancestors at every stored time.
template <typename real_type>
class history {
public:
  // Allocate storage for up to `n_times` slots; no slots are in use
  // after construction.
  history(size_t n_state, size_t n_particles, size_t n_groups, size_t n_times) :
    n_state_(n_state),
    n_particles_(n_particles),
    n_groups_(n_groups),
    n_times_(n_times),
    len_state_(n_state_ * n_particles_ * n_groups_),
    len_order_(n_particles_ * n_groups_),
    position_(0),
    times_(n_times_),
    state_(len_state_ * n_times_),
    order_(len_order_ * n_times_),
    reorder_(n_times_) {
  }

  // Change the capacity to `n_times` slots and discard everything
  // recorded so far. All four buffers are resized together; leaving
  // times_ or reorder_ at their old size would make a later add()
  // write past their end after growing.
  void resize(size_t n_times) {
    n_times_ = n_times;
    times_.resize(n_times_);
    state_.resize(len_state_ * n_times_);
    order_.resize(len_order_ * n_times_);
    reorder_.resize(n_times_);
    reset();
  }

  // Forget all recorded slots; capacity is unchanged and the old
  // contents are simply overwritten by subsequent calls to add().
  void reset() {
    position_ = 0;
  }

  // Append a slot with no reordering information.
  //
  // `iter` must be readable for len_state_ values.
  // Throws std::out_of_range if all slots are already in use.
  template <typename Iter>
  void add(real_type time, Iter iter) {
    check_capacity_();
    std::copy_n(iter, len_state_, state_.begin() + len_state_ * position_);
    times_[position_] = time;
    reorder_[position_] = false;
    ++position_;
  }

  // Append a slot together with its order vector.
  //
  // `iter_state` must be readable for len_state_ values and
  // `iter_order` for len_order_ values; order values are converted to
  // size_t on copy.
  // Throws std::out_of_range if all slots are already in use.
  template <typename IterReal, typename IterSize>
  void add(real_type time, IterReal iter_state, IterSize iter_order) {
    // This can't easily call add(real_type, Iter) because we need to
    // read position_ and write reorder_; the duplication is minimal
    // though.
    check_capacity_();
    std::copy_n(iter_state, len_state_,
                state_.begin() + position_ * len_state_);
    std::copy_n(iter_order, len_order_,
                order_.begin() + position_ * len_order_);
    times_[position_] = time;
    reorder_[position_] = true;
    ++position_;
  }

  // These allow a consumer to allocate the right size structures for
  // time and state for the total that we've actually used.
  auto size_time() const {
    return position_;
  }

  auto size_state() const {
    return position_ * len_state_;
  }

  // Copy the recorded times (size_time() values) to `iter`.
  template <typename Iter>
  void export_time(Iter iter) const {
    std::copy_n(times_.begin(), position_, iter);
  }

  // Copy the recorded states (size_state() values) to `iter`.
  //
  // If `reorder` is true, and at least one slot in use was added with
  // an order vector, states are rearranged so that output particle p
  // at every time holds the state of the ancestor of particle p at the
  // final recorded time. Otherwise the states are copied as stored.
  template <typename Iter>
  void export_state(Iter iter, bool reorder) const {
    // Only the slots in use are inspected: entries past position_ may
    // be stale leftovers from before a reset().
    reorder = reorder && n_particles_ > 1 && position_ > 0 &&
      std::any_of(reorder_.begin(), reorder_.begin() + position_,
                  [](bool v) { return v; });
    if (!reorder) {
      // No reordering is requested or possible so just dump out directly:
      std::copy_n(state_.begin(), position_ * len_state_, iter);
      return;
    }

    // Default index: within each group, particle j maps to itself.
    std::vector<size_t> index_particle(n_particles_ * n_groups_);
    for (size_t i = 0, k = 0; i < n_groups_; ++i) {
      for (size_t j = 0; j < n_particles_; ++j, ++k) {
        index_particle[k] = j;
      }
    }

    // Walk from the last recorded slot back to the first, carrying the
    // ancestry index along.
    for (size_t i = position_; i-- > 0;) {
      const auto iter_order = order_.begin() + i * len_order_;
      const auto iter_state = state_.begin() + i * len_state_;
      const bool reorder_i = reorder_[i];
      // This bit here is independent among groups
      for (size_t j = 0; j < n_groups_; ++j) {
        const auto offset_state = j * n_state_ * n_particles_;
        const auto offset_index = j * n_particles_;
        reorder_group_(iter_state + offset_state,
                       iter_order + offset_index,
                       reorder_i,
                       iter + i * len_state_ + offset_state,
                       index_particle.begin() + offset_index);
      }
    }
  }

private:
  size_t n_state_;
  size_t n_particles_;
  size_t n_groups_;
  size_t n_times_;    // capacity, in slots
  size_t len_state_;  // length of an update to state
  size_t len_order_;  // length of an update to order
  size_t position_;   // number of slots currently in use
  std::vector<real_type> times_;
  std::vector<real_type> state_;
  std::vector<size_t> order_;
  std::vector<bool> reorder_;

  // Guard used by add(): turns what would otherwise be a write past
  // the end of the buffers into an exception.
  void check_capacity_() const {
    if (position_ >= n_times_) {
      throw std::out_of_range("history is full: cannot add another time");
    }
  }

  // Reference implementation for this is mcstate:::history_single and
  // mcstate::history_multiple
  //
  // Handles one group at one time slot: writes particle index[i]'s
  // state to output position i, then (if this slot carried an order
  // vector) replaces index[i] by that particle's ancestor so that the
  // next-earlier slot is read from the right place.
  template <typename Iter>
  void reorder_group_(typename std::vector<real_type>::const_iterator iter_state,
                      typename std::vector<size_t>::const_iterator iter_order,
                      bool reorder,
                      Iter iter_dest,
                      typename std::vector<size_t>::iterator index) const {
    for (size_t i = 0; i < n_particles_; ++i) {
      std::copy_n(iter_state + *(index + i) * n_state_,
                  n_state_,
                  iter_dest + i * n_state_);
      if (reorder) {
        const auto index_i = index + i;
        *index_i = *(iter_order + *index_i);
      }
    }
  }
};

}
1 change: 1 addition & 0 deletions inst/include/dust2/r/helpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <vector>
#include <dust2/common.hpp>
#include <dust2/cpu.hpp>
#include <cpp11.hpp>

namespace dust2 {
namespace r {
Expand Down
8 changes: 8 additions & 0 deletions src/cpp11.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,13 @@ extern "C" SEXP _dust2_test_resample_weight(SEXP w, SEXP u) {
return cpp11::as_sexp(test_resample_weight(cpp11::as_cpp<cpp11::decay_t<std::vector<double>>>(w), cpp11::as_cpp<cpp11::decay_t<double>>(u)));
END_CPP11
}
// test.cpp
// Forward declaration of the C++ implementation that the entry point
// below forwards to.
cpp11::sexp test_history(cpp11::doubles r_time, cpp11::list r_state, cpp11::sexp r_order, bool reorder);
// C entry point listed in CallEntries under the name
// "_dust2_test_history" with four arguments. Each SEXP argument is
// converted with cpp11::as_cpp to the type test_history expects, and
// the result is converted back with cpp11::as_sexp; the call is
// wrapped in the BEGIN_CPP11 / END_CPP11 macros.
extern "C" SEXP _dust2_test_history(SEXP r_time, SEXP r_state, SEXP r_order, SEXP reorder) {
BEGIN_CPP11
return cpp11::as_sexp(test_history(cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(r_time), cpp11::as_cpp<cpp11::decay_t<cpp11::list>>(r_state), cpp11::as_cpp<cpp11::decay_t<cpp11::sexp>>(r_order), cpp11::as_cpp<cpp11::decay_t<bool>>(reorder)));
END_CPP11
}
// walk.cpp
SEXP dust2_cpu_walk_alloc(cpp11::list r_pars, cpp11::sexp r_time, cpp11::sexp r_dt, cpp11::sexp r_n_particles, cpp11::sexp r_n_groups, cpp11::sexp r_seed, cpp11::sexp r_deterministic);
extern "C" SEXP _dust2_dust2_cpu_walk_alloc(SEXP r_pars, SEXP r_time, SEXP r_dt, SEXP r_n_particles, SEXP r_n_groups, SEXP r_seed, SEXP r_deterministic) {
Expand Down Expand Up @@ -239,6 +246,7 @@ static const R_CallMethodDef CallEntries[] = {
{"_dust2_dust2_cpu_walk_state", (DL_FUNC) &_dust2_dust2_cpu_walk_state, 2},
{"_dust2_dust2_cpu_walk_time", (DL_FUNC) &_dust2_dust2_cpu_walk_time, 1},
{"_dust2_dust2_cpu_walk_update_pars", (DL_FUNC) &_dust2_dust2_cpu_walk_update_pars, 3},
{"_dust2_test_history", (DL_FUNC) &_dust2_test_history, 4},
{"_dust2_test_resample_weight", (DL_FUNC) &_dust2_test_resample_weight, 2},
{NULL, NULL, 0}
};
Expand Down
39 changes: 39 additions & 0 deletions src/test.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
#include <dust2/filter_details.hpp>
#include <dust2/history.hpp>
#include <dust2/r/helpers.hpp>

#include <cpp11/integers.hpp>
#include <cpp11/doubles.hpp>
#include <cpp11/list.hpp>

[[cpp11::register]]
cpp11::integers test_resample_weight(std::vector<double> w, double u) {
Expand All @@ -11,3 +14,39 @@ cpp11::integers test_resample_weight(std::vector<double> w, double u) {
cpp11::writable::integers ret(idx.begin(), idx.end());
return ret;
}

// Simple driver for exercising the history saving outside of any
// particle filter.
//
// r_time:  times; its length sets the capacity of the history.
// r_state: non-empty list of 3d double arrays, all with the
//          dimensions (n_state, n_particles, n_groups) of the first
//          element. May be shorter than r_time, never longer.
// r_order: NULL, or a list at least as long as r_state whose elements
//          are each NULL (no order for that time) or an integer
//          vector of length n_particles * n_groups.
// reorder: passed straight through to history::export_state.
//
// Returns a list of the exported times and the exported state, the
// latter with dimensions (n_state, n_particles, n_groups, n_saved).
[[cpp11::register]]
cpp11::sexp test_history(cpp11::doubles r_time, cpp11::list r_state,
                         cpp11::sexp r_order, bool reorder) {
  const size_t n_times = r_time.size();
  const size_t n_entries = r_state.size();
  // The dimensions are read from the first state, so there must be one.
  if (n_entries == 0) {
    cpp11::stop("Expected at least one element in 'r_state'");
  }
  // The history only has room for n_times slots, and r_time[i] is read
  // for every state.
  if (n_entries > n_times) {
    cpp11::stop("'r_state' must not be longer than 'r_time'");
  }

  cpp11::sexp el0 = r_state[0];
  auto r_dim = cpp11::as_cpp<cpp11::integers>(el0.attr("dim"));
  if (r_dim.size() != 3) {
    cpp11::stop("Expected elements of 'r_state' to be 3d arrays");
  }
  const size_t n_state = r_dim[0];
  const size_t n_particles = r_dim[1];
  const size_t n_groups = r_dim[2];
  const size_t len_state = n_state * n_particles * n_groups;
  const size_t len_order = n_particles * n_groups;

  const bool has_order = r_order != R_NilValue;
  if (has_order && static_cast<size_t>(Rf_xlength(r_order)) < n_entries) {
    cpp11::stop("'r_order' must be at least as long as 'r_state'");
  }

  dust2::history<double> h(n_state, n_particles, n_groups, n_times);
  for (size_t i = 0; i < n_entries; ++i) {
    // history::add copies a fixed number of values through a raw
    // pointer, so check the lengths before handing the pointers over.
    if (static_cast<size_t>(Rf_xlength(r_state[i])) != len_state) {
      cpp11::stop("All elements of 'r_state' must have the same size");
    }
    cpp11::sexp el = R_NilValue;
    if (has_order) {
      el = cpp11::as_cpp<cpp11::list>(r_order)[i];
    }
    if (el == R_NilValue) {
      h.add(r_time[i], REAL(r_state[i]));
    } else {
      if (static_cast<size_t>(Rf_xlength(el)) != len_order) {
        cpp11::stop("Unexpected length for an element of 'r_order'");
      }
      h.add(r_time[i], REAL(r_state[i]), INTEGER(el));
    }
  }

  cpp11::writable::doubles ret_time(static_cast<int>(h.size_time()));
  cpp11::writable::doubles ret_state(static_cast<int>(h.size_state()));
  h.export_time(REAL(ret_time));
  h.export_state(REAL(ret_state), reorder);
  dust2::r::set_array_dims(ret_state, {n_state, n_particles, n_groups, h.size_time()});

  return cpp11::writable::list{ret_time, ret_state};
}
99 changes: 99 additions & 0 deletions tests/testthat/test-filter-details.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,102 @@ test_that("Resampling works as expected", {
test_resample_weight(w, u) + 1L,
ref_resample_weight(w, u))
})


## Round-trip check with no reordering information: whatever is passed
## in as a list of per-time state arrays should come back as the same
## values stacked along a fourth (time) dimension.
test_that("can use history", {
time <- seq(0, 10, length.out = 11)
n_state <- 6
n_particles <- 7
n_groups <- 3
n_time <- length(time)
## One random (n_state, n_particles, n_groups) array per time
s <- lapply(seq_along(time), function(i) {
array(runif(n_state * n_particles * n_groups),
c(n_state, n_particles, n_groups))
})
## Expected output: the same values as a single 4d array
s_arr <- array(unlist(s), c(n_state, n_particles, n_groups, n_time))
## No order supplied: the 'reorder' flag makes no difference
expect_equal(test_history(time, s, NULL, TRUE),
list(time, s_arr))
expect_equal(test_history(time, s, NULL, FALSE),
list(time, s_arr))
## Fewer states than times: only the states provided are returned
expect_equal(test_history(time, s[1:3], NULL, TRUE),
list(time[1:3], s_arr[, , , 1:3]))
## An order list made up entirely of NULL behaves like no order at all
expect_equal(test_history(time, s, vector("list", length(time)), TRUE),
list(time, s_arr))
})


## Reordering check with a single group (n_groups = 1). The expected
## result ('true') is built by simulating forward and rewriting the
## past every time the particles are resampled.
test_that("can reorder history with no groups", {
## This is really hard to get right so let's actually simulate forward:
time <- seq(0, 10, length.out = 11)
n_time <- length(time)
n_state <- 6
n_particles <- 7
n_groups <- 1
state <- vector("list", length(time))
order <- vector("list", length(time))
true <- array(NA_real_, c(n_state, n_particles, n_groups, n_time))
s <- array(0, c(n_state, n_particles, n_groups))
set.seed(1)
for (i in seq_along(time)) {
s <- s + runif(length(s))
## Resample on every even step; odd steps leave order[[i]] as NULL
if (i > 1 && i %% 2 == 0) {
k <- replicate(n_groups, sample(n_particles, replace = TRUE))
## Stored order is 0-based
order[[i]] <- as.integer(k - 1L)
for (j in seq_len(n_groups)) {
s[, , j] <- s[, k[, j], j] + 1
## Apply the same shuffle to all earlier times of the expected
## result, so each particle carries its ancestors' states
true[, , j, seq_len(i - 1)] <- true[, k[, j], j, seq_len(i - 1)]
}
}
state[[i]] <- s
true[, , , i] <- s
}
## The states exactly as they were saved, with no reordering applied
state_arr <- array(unlist(state), dim(true))

## Pass in, but ignore index
expect_equal(test_history(time, state, order, FALSE),
list(time, state_arr))

## Really simple, add an index that does not reorder anything:
expect_equal(test_history(time, state[1], order[1], TRUE),
list(time[1], state_arr[, , , 1, drop = FALSE]))
expect_equal(test_history(time, state[1:2], list(NULL, 0:6), TRUE),
list(time[1:2], state_arr[, , , 1:2, drop = FALSE]))

## Proper reordering with the full index:
expect_equal(test_history(time, state, order, TRUE),
list(time, true))
})


## Reordering check with several groups (n_groups = 3), each resampled
## independently. The expected result ('true') is built by simulating
## forward and rewriting the past every time the particles are
## resampled.
test_that("can reorder history on the way out", {
## This is really hard to get right so let's actually simulate forward:
time <- seq(0, 10, length.out = 11)
n_time <- length(time)
n_state <- 6
n_particles <- 7
n_groups <- 3
state <- vector("list", length(time))
order <- vector("list", length(time))
true <- array(NA_real_, c(n_state, n_particles, n_groups, n_time))
s <- array(0, c(n_state, n_particles, n_groups))
set.seed(1)
for (i in seq_along(time)) {
s <- s + runif(length(s))
## Resample on every even step; odd steps leave order[[i]] as NULL
if (i > 1 && i %% 2 == 0) {
## One column of sampled particle indices per group
k <- replicate(n_groups, sample(n_particles, replace = TRUE))
## Stored order is 0-based
order[[i]] <- as.integer(k - 1L)
for (j in seq_len(n_groups)) {
s[, , j] <- s[, k[, j], j] + 1
## Apply the same shuffle to all earlier times of the expected
## result, so each particle carries its ancestors' states
true[, , j, seq_len(i - 1)] <- true[, k[, j], j, seq_len(i - 1)]
}
}
state[[i]] <- s
true[, , , i] <- s
}

## The states exactly as they were saved, with no reordering applied
state_arr <- array(unlist(state), dim(true))
## Order supplied but reordering not requested: states come back as saved
expect_equal(test_history(time, state, order, FALSE),
list(time, state_arr))
## Reordering requested: states come back following each particle's ancestry
expect_equal(test_history(time, state, order, TRUE),
list(time, true))
})
Loading