diff --git a/NAMESPACE b/NAMESPACE index c5acbbbc..0de8edc8 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -4,6 +4,7 @@ S3method(dim,dust_model) S3method(print,dust_model) S3method(print,dust_model_generator) export(dust_filter_create) +export(dust_filter_last_history) export(dust_filter_rng_state) export(dust_filter_run) export(dust_model_compare_data) @@ -20,5 +21,6 @@ export(dust_model_state) export(dust_model_time) export(dust_model_update_pars) export(dust_unfilter_create) +export(dust_unfilter_last_history) export(dust_unfilter_run) useDynLib(dust2, .registration = TRUE) diff --git a/R/cpp11.R b/R/cpp11.R index b4257e2e..e860a201 100644 --- a/R/cpp11.R +++ b/R/cpp11.R @@ -52,20 +52,28 @@ dust2_cpu_sir_simulate <- function(ptr, r_times, r_index, grouped) { .Call(`_dust2_dust2_cpu_sir_simulate`, ptr, r_times, r_index, grouped) } -dust2_cpu_sir_unfilter_alloc <- function(r_pars, r_time_start, r_time, r_dt, r_data, r_n_particles, r_n_groups) { - .Call(`_dust2_dust2_cpu_sir_unfilter_alloc`, r_pars, r_time_start, r_time, r_dt, r_data, r_n_particles, r_n_groups) +dust2_cpu_sir_unfilter_alloc <- function(r_pars, r_time_start, r_time, r_dt, r_data, r_n_particles, r_n_groups, r_index) { + .Call(`_dust2_dust2_cpu_sir_unfilter_alloc`, r_pars, r_time_start, r_time, r_dt, r_data, r_n_particles, r_n_groups, r_index) } -dust2_cpu_sir_unfilter_run <- function(ptr, r_pars, r_initial, grouped) { - .Call(`_dust2_dust2_cpu_sir_unfilter_run`, ptr, r_pars, r_initial, grouped) +dust2_cpu_sir_unfilter_run <- function(ptr, r_pars, r_initial, save_history, grouped) { + .Call(`_dust2_dust2_cpu_sir_unfilter_run`, ptr, r_pars, r_initial, save_history, grouped) } -dust2_cpu_sir_filter_alloc <- function(r_pars, r_time_start, r_time, r_dt, r_data, r_n_particles, r_n_groups, r_seed) { - .Call(`_dust2_dust2_cpu_sir_filter_alloc`, r_pars, r_time_start, r_time, r_dt, r_data, r_n_particles, r_n_groups, r_seed) +dust2_cpu_sir_unfilter_last_history <- function(ptr, grouped) { + .Call(`_dust2_dust2_cpu_sir_unfilter_last_history`, ptr, grouped) } -dust2_cpu_sir_filter_run <- function(ptr, r_pars, r_initial, grouped) { - .Call(`_dust2_dust2_cpu_sir_filter_run`, ptr, r_pars, r_initial, grouped) +dust2_cpu_sir_filter_alloc <- function(r_pars, r_time_start, r_time, r_dt, r_data, r_n_particles, r_n_groups, r_index, r_seed) { + .Call(`_dust2_dust2_cpu_sir_filter_alloc`, r_pars, r_time_start, r_time, r_dt, r_data, r_n_particles, r_n_groups, r_index, r_seed) +} + +dust2_cpu_sir_filter_run <- function(ptr, r_pars, r_initial, save_history, grouped) { + .Call(`_dust2_dust2_cpu_sir_filter_run`, ptr, r_pars, r_initial, save_history, grouped) +} + +dust2_cpu_sir_filter_last_history <- function(ptr, grouped) { + .Call(`_dust2_dust2_cpu_sir_filter_last_history`, ptr, grouped) } dust2_cpu_sir_filter_rng_state <- function(ptr) { diff --git a/R/interface.R b/R/interface.R index b05cf701..a5fb7a24 100644 --- a/R/interface.R +++ b/R/interface.R @@ -27,10 +27,10 @@ dust_model <- function(name, env = parent.env(parent.frame())) { has_compare = !is.null(methods$compare_data)) if (properties$has_compare) { - methods_unfilter <- c("alloc", "run") + methods_unfilter <- c("alloc", "run", "last_history") methods$unfilter <- get_methods(methods_unfilter, sprintf("%s_unfilter", name)) - methods_filter <- c("alloc", "run", "rng_state") + methods_filter <- c("alloc", "run", "last_history", "rng_state") methods$filter <- get_methods(methods_filter, sprintf("%s_filter", name)) } @@ -376,7 +376,7 @@ dust_model_compare_data <- function(model, data) { ##' @export dust_unfilter_create <- function(generator, pars, time_start, time, data, n_particles = 1, n_groups = 0, - dt = 1) { + dt = 1, index = NULL) { check_is_dust_model_generator(generator) if (!generator$properties$has_compare) { ## This moves into something general soon? @@ -386,12 +386,13 @@ dust_unfilter_create <- function(generator, pars, time_start, time, data, arg = "generator") } res <- generator$methods$unfilter$alloc(pars, time_start, time, dt, data, - n_particles, n_groups) + n_particles, n_groups, index) res$name <- generator$name res$n_particles <- as.integer(n_particles) res$n_groups <- as.integer(max(n_groups), 1) res$deterministic <- TRUE res$methods <- generator$methods$unfilter + res$index <- index class(res) <- "dust_unfilter" res } @@ -401,6 +402,8 @@ dust_unfilter_create <- function(generator, pars, time_start, time, data, ##' ##' @title Run unfilter ##' +##' @inheritParams dust_filter_run +##' ##' @param unfilter A `dust_unfilter` object, created by ##' [dust_unfilter_create] ##' @@ -410,9 +413,28 @@ dust_unfilter_create <- function(generator, pars, time_start, time, data, ##' there are groups. ##' ##' @export -dust_unfilter_run <- function(unfilter, pars = NULL, initial = NULL) { +dust_unfilter_run <- function(unfilter, pars = NULL, initial = NULL, + save_history = FALSE) { + check_is_dust_unfilter(unfilter) + unfilter$methods$run(unfilter$ptr, pars, initial, save_history, + unfilter$grouped) +} + + +##' Fetch the last history created by running an unfilter. This +##' errors if the last call to [dust_unfilter_run] did not use +##' `save_history = TRUE`. +##' +##' @title Fetch last unfilter history +##' +##' @inheritParams dust_unfilter_run +##' +##' @return An array +##' +##' @export +dust_unfilter_last_history <- function(unfilter) { check_is_dust_unfilter(unfilter) - unfilter$methods$run(unfilter$ptr, pars, initial, unfilter$grouped) + unfilter$methods$last_history(unfilter$ptr, unfilter$grouped) } @@ -451,6 +473,7 @@ dust_unfilter_run <- function(unfilter, pars = NULL, initial = NULL) { ##' slowly. ##' ##' @inheritParams dust_model_create +##' @inheritParams dust_model_simulate ##' ##' @return A `dust_unfilter` object, which can be used with ##' [dust_unfilter_run] @@ -458,7 +481,7 @@ dust_unfilter_run <- function(unfilter, pars = NULL, initial = NULL) { ##' @export dust_filter_create <- function(generator, pars, time_start, time, data, n_particles, n_groups = 0, dt = 1, - seed = NULL) { + index = NULL, seed = NULL) { check_is_dust_model_generator(generator) if (!generator$properties$has_compare) { ## This moves into something general soon? @@ -468,7 +491,7 @@ dust_filter_create <- function(generator, pars, time_start, time, data, arg = "generator") } res <- generator$methods$filter$alloc(pars, time_start, time, dt, data, - n_particles, n_groups, seed) + n_particles, n_groups, index, seed) res$name <- generator$name res$n_particles <- as.integer(n_particles) res$n_groups <- as.integer(max(n_groups), 1) @@ -493,13 +516,39 @@ dust_filter_create <- function(generator, pars, time_start, time, data, ##' particle) or 3d array (state x particle x group). If not ##' provided, the model initial conditions are used. ##' +##' @param save_history Logical, indicating if the simulation history +##' should be saved while the simulation runs; this has a small +##' overhead in runtime and in memory. History (particle +##' trajectories) will be saved at each time in the filter. If the +##' filter was constructed using a non-`NULL` `index` parameter, +##' the history is restricted to these states. +##' ##' @return A vector of likelihood values, with as many elements as ##' there are groups. ##' ##' @export -dust_filter_run <- function(filter, pars = NULL, initial = NULL) { +dust_filter_run <- function(filter, pars = NULL, initial = NULL, + save_history = FALSE) { + check_is_dust_filter(filter) + filter$methods$run(filter$ptr, pars, initial, save_history, + filter$grouped) +} + + +##' Fetch the last history created by running a filter. This +##' errors if the last call to [dust_filter_run] did not use +##' `save_history = TRUE`. +##' +##' @title Fetch last filter history +##' +##' @inheritParams dust_filter_run +##' +##' @return An array +##' +##' @export +dust_filter_last_history <- function(filter) { check_is_dust_filter(filter) - filter$methods$run(filter$ptr, pars, initial, filter$grouped) + filter$methods$last_history(filter$ptr, filter$grouped) } diff --git a/_pkgdown.yml b/_pkgdown.yml index 8130aa73..f917c3f5 100644 --- a/_pkgdown.yml +++ b/_pkgdown.yml @@ -29,10 +29,12 @@ reference: contents: - dust_unfilter_create - dust_unfilter_run + - dust_unfilter_last_history - subtitle: Particle filter contents: - dust_filter_create - dust_filter_run + - dust_filter_last_history - dust_filter_rng_state diff --git a/inst/include/dust2/filter.hpp b/inst/include/dust2/filter.hpp index a5250db4..17afef35 100644 --- a/inst/include/dust2/filter.hpp +++ b/inst/include/dust2/filter.hpp @@ -2,6 +2,7 @@ #include #include +#include #include namespace dust2 { @@ -19,15 +20,21 @@ class unfilter { unfilter(dust_cpu model_, real_type time_start, std::vector time, - std::vector data) : + std::vector data, + std::vector history_index) : model(model_), time_start_(time_start), time_(time), data_(data), + n_state_(model.n_state()), n_particles_(model.n_particles()), n_groups_(model.n_groups()), ll_(n_particles_ * n_groups_, 0), - ll_step_(n_particles_ * n_groups_, 0) { + ll_step_(n_particles_ * n_groups_, 0), + history_index_(history_index), + history_(history_index_.size() > 0 ? history_index_.size() : n_state_, + n_particles_, n_groups_, time_.size()), + history_is_current_(false) { const auto dt = model_.dt(); for (size_t i = 0; i < time_.size(); i++) { const auto t0 = i == 0 ? time_start_ : time_[i - 1]; @@ -36,7 +43,11 @@ class unfilter { } } - void run(bool set_initial) { + void run(bool set_initial, bool save_history) { + history_is_current_ = false; + if (save_history) { + history_.reset(); + } const auto n_times = step_.size(); model.set_time(time_start_); @@ -45,6 +56,8 @@ class unfilter { } std::fill(ll_.begin(), ll_.end(), 0); + const bool use_index = history_index_.size() > 0; + auto it_data = data_.begin(); for (size_t i = 0; i < n_times; ++i, it_data += n_groups_) { model.run_steps(step_[i]); // just compute this at point of use? @@ -52,7 +65,17 @@ class unfilter { for (size_t j = 0; j < ll_.size(); ++j) { ll_[j] += ll_step_[j]; } + if (save_history) { + if (use_index) { + history_.add_with_index(time_[i], model.state().begin(), + history_index_.begin(), n_state_); + } else { + history_.add(time_[i], model.state().begin()); + } + } } + + history_is_current_ = save_history; } template @@ -60,15 +83,28 @@ class unfilter { std::copy(ll_.begin(), ll_.end(), iter); } + + auto& last_history() const { + return history_; + } + + bool last_history_is_current() const { + return history_is_current_; + } + private: real_type time_start_; std::vector time_; std::vector step_; std::vector data_; + size_t n_state_; size_t n_particles_; size_t n_groups_; std::vector ll_; std::vector ll_step_; + std::vector history_index_; + history history_; + bool history_is_current_; }; template @@ -85,16 +121,22 @@ class filter { real_type time_start, std::vector time, std::vector data, + std::vector history_index, const std::vector& seed) : model(model_), time_start_(time_start), time_(time), data_(data), + n_state_(model.n_state()), n_particles_(model.n_particles()), n_groups_(model.n_groups()), rng_(n_groups_, seed, false), ll_(n_groups_ * n_particles_, 0), - ll_step_(n_groups_ * n_particles_, 0) { + ll_step_(n_groups_ * n_particles_, 0), + history_index_(history_index), + history_(history_index_.size() > 0 ? history_index_.size() : n_state_, + n_particles_, n_groups_, time_.size()), + history_is_current_(false) { // TODO: duplicated with the above, can be done generically though // it's not a lot of code. const auto dt = model_.dt(); @@ -105,7 +147,11 @@ class filter { } } - void run(bool set_initial) { + void run(bool set_initial, bool save_history) { + history_is_current_ = false; + if (save_history) { + history_.reset(); + } const auto n_times = step_.size(); model.set_time(time_start_); @@ -118,6 +164,8 @@ class filter { // probably use that vector instead. std::vector index(n_particles_ * n_groups_); + const bool use_index = history_index_.size() > 0; + auto it_data = data_.begin(); for (size_t i = 0; i < n_times; ++i, it_data += n_groups_) { model.run_steps(step_[i]); @@ -144,9 +192,18 @@ class filter { model.reorder(index.begin()); - // save trajectories (perhaps) + if (save_history) { + if (use_index) { + history_.add_with_index(time_[i], model.state().begin(), index.begin(), + history_index_.begin(), n_state_); + } else { + history_.add(time_[i], model.state().begin(), index.begin()); + } + } // save snapshots (perhaps) } + + history_is_current_ = save_history; } template @@ -154,6 +211,14 @@ class filter { std::copy_n(ll_.begin(), n_groups_, it); } + auto& last_history() const { + return history_; + } + + bool last_history_is_current() const { + return history_is_current_; + } + auto rng_state() { // TODO: should be const, error in mcstate2 return rng_.export_state(); } @@ -163,11 +228,15 @@ class filter { std::vector time_; std::vector step_; std::vector data_; + size_t n_state_; size_t n_particles_; size_t n_groups_; mcstate::random::prng rng_; std::vector ll_; std::vector ll_step_; + std::vector history_index_; + history history_; + bool history_is_current_; }; } diff --git a/inst/include/dust2/history.hpp b/inst/include/dust2/history.hpp index 5f5cc5f4..6b0e71de 100644 --- a/inst/include/dust2/history.hpp +++ b/inst/include/dust2/history.hpp @@ -1,6 +1,7 @@ #pragma once #include +#include #include namespace dust2 { @@ -22,13 +23,25 @@ class history { times_(n_times_), state_(len_state_ * n_times_), order_(len_order_ * n_times_), - reorder_(n_times) { + reorder_(n_times), + dims_({n_state_, n_particles_, n_groups_, position_}) { } - void resize(size_t n_times) { - n_times_ = n_times; - state_.resize(len_state_ * n_times_); - order_.resize(len_order_ * n_times_); + void resize_state(size_t n_state) { + if (n_state_ != n_state) { + n_state_ = n_state; + len_state_ = (n_state_ * n_particles_ * n_groups_); + state_.resize(len_state_ * n_times_); + } + reset(); + } + + void resize_time(size_t n_times) { + if (n_times_ != n_times) { + n_times_ = n_times; + state_.resize(len_state_ * n_times_); + order_.resize(len_order_ * n_times_); + } reset(); } @@ -39,37 +52,29 @@ class history { template void add(real_type time, IterReal iter_state) { copy_state_(iter_state); - times_[position_] = time; - reorder_[position_] = false; - position_++; + update_position(time, false); } template void add(real_type time, IterReal iter_state, IterSize iter_order) { copy_state_(iter_state); copy_order_(iter_order); - times_[position_] = time; - reorder_[position_] = true; - position_++; + update_position(time, true); } template void add_with_index(real_type time, IterReal iter_state, IterSize iter_index, size_t n_state_total) { copy_state_with_index_(iter_state, iter_index, n_state_total); - times_[position_] = time; - reorder_[position_] = false; - position_++; + update_position(time, false); } template void add_with_index(real_type time, IterReal iter_state, IterSize iter_order, IterSize iter_index, size_t n_state_total) { - copy_state_with_index(iter_state, iter_index, n_state_total); + copy_state_with_index_(iter_state, iter_index, n_state_total); copy_order_(iter_order); - times_[position_] = time; - reorder_[position_] = true; - position_++; + update_position(time, true); } // These allow a consumer to allocate the right size structures for @@ -82,13 +87,21 @@ class history { return position_ * len_state_; } + auto size_order() const { + return position_ * len_order_; + } + + auto& dims() const { + return dims_; + } + template - void export_time(Iter iter) { + void export_time(Iter iter) const { std::copy_n(times_.begin(), position_, iter); } template - void export_state(Iter iter, bool reorder) { + void export_state(Iter iter, bool reorder) const { reorder = reorder && n_particles_ > 1 && position_ > 0 && std::any_of(reorder_.begin(), reorder_.end(), [](auto v) { return v; }); if (reorder) { @@ -121,6 +134,11 @@ class history { } } + template + void export_order(Iter iter) const { + std::copy_n(order_.begin(), position_ * len_order_, iter); + } + private: size_t n_state_; size_t n_particles_; @@ -133,6 +151,7 @@ class history { std::vector state_; std::vector order_; std::vector reorder_; + std::array dims_; // Reference implementation for this is mcstate:::history_single and // mcstate::history_multiple @@ -141,7 +160,7 @@ class history { typename std::vector::const_iterator iter_order, bool reorder, Iter iter_dest, - typename std::vector::iterator index) { + typename std::vector::iterator index) const { for (size_t i = 0; i < n_particles_; ++i) { std::copy_n(iter_state + *(index + i) * n_state_, n_state_, @@ -153,6 +172,13 @@ class history { } } + void update_position(real_type time, bool reorder) { + times_[position_] = time; + reorder_[position_] = reorder; + position_++; + dims_[3] = position_; + } + template void copy_state_(IterReal iter) { std::copy_n(iter, len_state_, state_.begin() + len_state_ * position_); diff --git a/inst/include/dust2/r/filter.hpp b/inst/include/dust2/r/filter.hpp index d45ff8bf..a9106ea1 100644 --- a/inst/include/dust2/r/filter.hpp +++ b/inst/include/dust2/r/filter.hpp @@ -14,7 +14,8 @@ cpp11::sexp dust2_cpu_unfilter_alloc(cpp11::list r_pars, cpp11::sexp r_dt, cpp11::list r_data, cpp11::sexp r_n_particles, - cpp11::sexp r_n_groups) { + cpp11::sexp r_n_groups, + cpp11::sexp r_index) { using rng_state_type = typename T::rng_state_type; auto n_particles = to_size(r_n_particles, "n_particles"); @@ -40,8 +41,9 @@ cpp11::sexp dust2_cpu_unfilter_alloc(cpp11::list r_pars, // going to feel weirder overall. const auto model = dust2::dust_cpu(shared, internal, time_start, dt, n_particles, seed, deterministic); + const auto index = check_index(r_index, model.n_state(), "index"); - auto obj = new unfilter(model, time_start, time, data); + auto obj = new unfilter(model, time_start, time, data, index); cpp11::external_pointer> ptr(obj, true, false); cpp11::sexp r_n_state = cpp11::as_sexp(obj->model.n_state()); @@ -61,7 +63,8 @@ cpp11::sexp dust2_cpu_unfilter_alloc(cpp11::list r_pars, template cpp11::sexp dust2_cpu_unfilter_run(cpp11::sexp ptr, cpp11::sexp r_pars, - cpp11::sexp r_initial, bool grouped) { + cpp11::sexp r_initial, bool save_history, + bool grouped) { auto *obj = cpp11::as_cpp>>(ptr).get(); if (r_pars != R_NilValue) { @@ -70,7 +73,7 @@ cpp11::sexp dust2_cpu_unfilter_run(cpp11::sexp ptr, cpp11::sexp r_pars, if (r_initial != R_NilValue) { set_state(obj->model, r_initial, grouped); } - obj->run(r_initial == R_NilValue); + obj->run(r_initial == R_NilValue, save_history); const auto n_groups = obj->model.n_groups(); const auto n_particles = obj->model.n_particles(); @@ -82,6 +85,34 @@ cpp11::sexp dust2_cpu_unfilter_run(cpp11::sexp ptr, cpp11::sexp r_pars, return ret; } +template +cpp11::sexp dust2_cpu_unfilter_last_history(cpp11::sexp ptr, bool grouped) { + auto *obj = + cpp11::as_cpp>>(ptr).get(); + if (!obj->last_history_is_current()) { + cpp11::stop("History is not current"); + } + + constexpr bool reorder = false; // never needed + + const auto& history = obj->last_history(); + const auto& dims = history.dims(); + // Could use destructured bind here in recent C++? + const auto n_state = dims[0]; + const auto n_particles = dims[1]; + const auto n_groups = dims[2]; + const auto n_times = dims[3]; + const auto len = n_state * n_particles * n_groups * n_times; + cpp11::sexp ret = cpp11::writable::doubles(len); + history.export_state(REAL(ret), reorder); + if (grouped) { + set_array_dims(ret, {n_state, n_particles, n_groups, n_times}); + } else { + set_array_dims(ret, {n_state, n_particles * n_groups, n_times}); + } + return ret; +} + template cpp11::sexp dust2_cpu_filter_alloc(cpp11::list r_pars, cpp11::sexp r_time_start, @@ -90,6 +121,7 @@ cpp11::sexp dust2_cpu_filter_alloc(cpp11::list r_pars, cpp11::list r_data, cpp11::sexp r_n_particles, cpp11::sexp r_n_groups, + cpp11::sexp r_index, cpp11::sexp r_seed) { using rng_state_type = typename T::rng_state_type; using rng_seed_type = std::vector; @@ -149,7 +181,9 @@ cpp11::sexp dust2_cpu_filter_alloc(cpp11::list r_pars, const auto model = dust2::dust_cpu(shared, internal, time_start, dt, n_particles, seed_model, deterministic); - auto obj = new filter(model, time_start, time, data, seed_filter); + const auto index = check_index(r_index, model.n_state(), "index"); + + auto obj = new filter(model, time_start, time, data, index, seed_filter); cpp11::external_pointer> ptr(obj, true, false); cpp11::sexp r_n_state = cpp11::as_sexp(obj->model.n_state()); @@ -169,7 +203,8 @@ cpp11::sexp dust2_cpu_filter_alloc(cpp11::list r_pars, template cpp11::sexp dust2_cpu_filter_run(cpp11::sexp ptr, cpp11::sexp r_pars, - cpp11::sexp r_initial, bool grouped) { + cpp11::sexp r_initial, bool save_history, + bool grouped) { auto *obj = cpp11::as_cpp>>(ptr).get(); if (r_pars != R_NilValue) { @@ -178,13 +213,43 @@ cpp11::sexp dust2_cpu_filter_run(cpp11::sexp ptr, cpp11::sexp r_pars, if (r_initial != R_NilValue) { set_state(obj->model, r_initial, grouped); } - obj->run(r_initial == R_NilValue); + obj->run(r_initial == R_NilValue, save_history); cpp11::writable::doubles ret(obj->model.n_groups()); obj->last_log_likelihood(REAL(ret)); return ret; } +// Can collapse with above +template +cpp11::sexp dust2_cpu_filter_last_history(cpp11::sexp ptr, bool grouped) { + auto *obj = + cpp11::as_cpp>>(ptr).get(); + if (!obj->last_history_is_current()) { + cpp11::stop("History is not current"); + } + // We might relax this later, but will require some tools to work + // with the output, really. + constexpr bool reorder = true; + + const auto& history = obj->last_history(); + const auto& dims = history.dims(); + // Could use destructured bind here in recent C++? + const auto n_state = dims[0]; + const auto n_particles = dims[1]; + const auto n_groups = dims[2]; + const auto n_times = dims[3]; + const auto len = n_state * n_particles * n_groups * n_times; + cpp11::sexp ret = cpp11::writable::doubles(len); + history.export_state(REAL(ret), reorder); + if (grouped) { + set_array_dims(ret, {n_state, n_particles, n_groups, n_times}); + } else { + set_array_dims(ret, {n_state, n_particles * n_groups, n_times}); + } + return ret; +} + template cpp11::sexp dust2_cpu_filter_rng_state(cpp11::sexp ptr) { auto *obj = cpp11::as_cpp>>(ptr).get(); diff --git a/man/dust_filter_create.Rd b/man/dust_filter_create.Rd index f6fe4082..30778e45 100644 --- a/man/dust_filter_create.Rd +++ b/man/dust_filter_create.Rd @@ -13,6 +13,7 @@ dust_filter_create( n_particles, n_groups = 0, dt = 1, + index = NULL, seed = NULL ) } @@ -51,6 +52,12 @@ slowly.} \item{dt}{The time step for the model, defaults to 1} +\item{index}{An optional index of states to extract. If given, +then we subset the model state on return. You can use this to +return fewer model states than the model ran with, to reorder +states, or to name them on exit (names present on the index will +be copied into the rownames of the returned array).} + \item{seed}{Optionally, a seed. Otherwise we respond to R's RNG seed on initialisation.} } diff --git a/man/dust_filter_last_history.Rd b/man/dust_filter_last_history.Rd new file mode 100644 index 00000000..cecffbd7 --- /dev/null +++ b/man/dust_filter_last_history.Rd @@ -0,0 +1,20 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/interface.R +\name{dust_filter_last_history} +\alias{dust_filter_last_history} +\title{Fetch last filter history} +\usage{ +dust_filter_last_history(filter) +} +\arguments{ +\item{filter}{A \code{dust_filter} object, created by +\link{dust_filter_create}} +} +\value{ +An array +} +\description{ +Fetch the last history created by running a filter. This +errors if the last call to \link{dust_filter_run} did not use +\code{save_history = TRUE}. +} diff --git a/man/dust_filter_run.Rd b/man/dust_filter_run.Rd index 66478762..e66f41fb 100644 --- a/man/dust_filter_run.Rd +++ b/man/dust_filter_run.Rd @@ -4,7 +4,7 @@ \alias{dust_filter_run} \title{Run particle filter} \usage{ -dust_filter_run(filter, pars = NULL, initial = NULL) +dust_filter_run(filter, pars = NULL, initial = NULL, save_history = FALSE) } \arguments{ \item{filter}{A \code{dust_filter} object, created by @@ -16,6 +16,13 @@ provided, parameters are not updated} \item{initial}{Optional initial conditions, as a matrix (state x particle) or 3d array (state x particle x group). If not provided, the model initial conditions are used.} + +\item{save_history}{Logical, indicating if the simulation history +should be saved while the simulation runs; this has a small +overhead in runtime and in memory. History (particle +trajectories) will be saved at each time in the filter. If the +filter was constructed using a non-\code{NULL} \code{index} parameter, +the history is restricted to these states.} } \value{ A vector of likelihood values, with as many elements as diff --git a/man/dust_unfilter_create.Rd b/man/dust_unfilter_create.Rd index e1886719..711f9860 100644 --- a/man/dust_unfilter_create.Rd +++ b/man/dust_unfilter_create.Rd @@ -12,7 +12,8 @@ dust_unfilter_create( data, n_particles = 1, n_groups = 0, - dt = 1 + dt = 1, + index = NULL ) } \arguments{ @@ -50,6 +51,12 @@ initial conditions then you would see different likelihoods.} \item{n_groups}{Optionally, the number of parameter groups} \item{dt}{The time step for the model, defaults to 1} + +\item{index}{An optional index of states to extract. If given, +then we subset the model state on return. You can use this to +return fewer model states than the model ran with, to reorder +states, or to name them on exit (names present on the index will +be copied into the rownames of the returned array).} } \value{ A \code{dust_unfilter} object, which can be used with diff --git a/man/dust_unfilter_last_history.Rd b/man/dust_unfilter_last_history.Rd new file mode 100644 index 00000000..4280f737 --- /dev/null +++ b/man/dust_unfilter_last_history.Rd @@ -0,0 +1,20 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/interface.R +\name{dust_unfilter_last_history} +\alias{dust_unfilter_last_history} +\title{Fetch last unfilter history} +\usage{ +dust_unfilter_last_history(unfilter) +} +\arguments{ +\item{unfilter}{A \code{dust_unfilter} object, created by +\link{dust_unfilter_create}} +} +\value{ +An array +} +\description{ +Fetch the last history created by running an unfilter. This +errors if the last call to \link{dust_unfilter_run} did not use +\code{save_history = TRUE}. +} diff --git a/man/dust_unfilter_run.Rd b/man/dust_unfilter_run.Rd index 928c287d..1f3d8d57 100644 --- a/man/dust_unfilter_run.Rd +++ b/man/dust_unfilter_run.Rd @@ -4,7 +4,7 @@ \alias{dust_unfilter_run} \title{Run unfilter} \usage{ -dust_unfilter_run(unfilter, pars = NULL, initial = NULL) +dust_unfilter_run(unfilter, pars = NULL, initial = NULL, save_history = FALSE) } \arguments{ \item{unfilter}{A \code{dust_unfilter} object, created by @@ -16,6 +16,13 @@ provided, parameters are not updated} \item{initial}{Optional initial conditions, as a matrix (state x particle) or 3d array (state x particle x group). If not provided, the model initial conditions are used.} + +\item{save_history}{Logical, indicating if the simulation history +should be saved while the simulation runs; this has a small +overhead in runtime and in memory. History (particle +trajectories) will be saved at each time in the filter. If the +filter was constructed using a non-\code{NULL} \code{index} parameter, +the history is restricted to these states.} } \value{ A vector of likelihood values, with as many elements as diff --git a/src/cpp11.cpp b/src/cpp11.cpp index 8ebbeb0f..7cf81195 100644 --- a/src/cpp11.cpp +++ b/src/cpp11.cpp @@ -97,31 +97,45 @@ extern "C" SEXP _dust2_dust2_cpu_sir_simulate(SEXP ptr, SEXP r_times, SEXP r_ind END_CPP11 } // sir.cpp -SEXP dust2_cpu_sir_unfilter_alloc(cpp11::list r_pars, cpp11::sexp r_time_start, cpp11::sexp r_time, cpp11::sexp r_dt, cpp11::list r_data, cpp11::sexp r_n_particles, cpp11::sexp r_n_groups); -extern "C" SEXP _dust2_dust2_cpu_sir_unfilter_alloc(SEXP r_pars, SEXP r_time_start, SEXP r_time, SEXP r_dt, SEXP r_data, SEXP r_n_particles, SEXP r_n_groups) { +SEXP dust2_cpu_sir_unfilter_alloc(cpp11::list r_pars, cpp11::sexp r_time_start, cpp11::sexp r_time, cpp11::sexp r_dt, cpp11::list r_data, cpp11::sexp r_n_particles, cpp11::sexp r_n_groups, cpp11::sexp r_index); +extern "C" SEXP _dust2_dust2_cpu_sir_unfilter_alloc(SEXP r_pars, SEXP r_time_start, SEXP r_time, SEXP r_dt, SEXP r_data, SEXP r_n_particles, SEXP r_n_groups, SEXP r_index) { BEGIN_CPP11 - return cpp11::as_sexp(dust2_cpu_sir_unfilter_alloc(cpp11::as_cpp>(r_pars), cpp11::as_cpp>(r_time_start), cpp11::as_cpp>(r_time), cpp11::as_cpp>(r_dt), cpp11::as_cpp>(r_data), cpp11::as_cpp>(r_n_particles), cpp11::as_cpp>(r_n_groups))); + return cpp11::as_sexp(dust2_cpu_sir_unfilter_alloc(cpp11::as_cpp>(r_pars), cpp11::as_cpp>(r_time_start), cpp11::as_cpp>(r_time), cpp11::as_cpp>(r_dt), cpp11::as_cpp>(r_data), cpp11::as_cpp>(r_n_particles), cpp11::as_cpp>(r_n_groups), cpp11::as_cpp>(r_index))); END_CPP11 } // sir.cpp -SEXP dust2_cpu_sir_unfilter_run(cpp11::sexp ptr, cpp11::sexp r_pars, cpp11::sexp r_initial, bool grouped); -extern "C" SEXP _dust2_dust2_cpu_sir_unfilter_run(SEXP ptr, SEXP r_pars, SEXP r_initial, SEXP grouped) { +SEXP dust2_cpu_sir_unfilter_run(cpp11::sexp ptr, cpp11::sexp r_pars, cpp11::sexp r_initial, bool save_history, bool grouped); +extern "C" SEXP _dust2_dust2_cpu_sir_unfilter_run(SEXP ptr, SEXP r_pars, SEXP r_initial, SEXP save_history, SEXP grouped) { BEGIN_CPP11 - return cpp11::as_sexp(dust2_cpu_sir_unfilter_run(cpp11::as_cpp>(ptr), cpp11::as_cpp>(r_pars), cpp11::as_cpp>(r_initial), cpp11::as_cpp>(grouped))); + return cpp11::as_sexp(dust2_cpu_sir_unfilter_run(cpp11::as_cpp>(ptr), cpp11::as_cpp>(r_pars), cpp11::as_cpp>(r_initial), cpp11::as_cpp>(save_history), cpp11::as_cpp>(grouped))); END_CPP11 } // sir.cpp -SEXP dust2_cpu_sir_filter_alloc(cpp11::list r_pars, cpp11::sexp r_time_start, cpp11::sexp r_time, cpp11::sexp r_dt, cpp11::list r_data, cpp11::sexp r_n_particles, cpp11::sexp r_n_groups, cpp11::sexp r_seed); -extern "C" SEXP _dust2_dust2_cpu_sir_filter_alloc(SEXP r_pars, SEXP r_time_start, SEXP r_time, SEXP r_dt, SEXP r_data, SEXP r_n_particles, SEXP r_n_groups, SEXP r_seed) { +SEXP dust2_cpu_sir_unfilter_last_history(cpp11::sexp ptr, bool grouped); +extern "C" SEXP _dust2_dust2_cpu_sir_unfilter_last_history(SEXP ptr, SEXP grouped) { BEGIN_CPP11 - return cpp11::as_sexp(dust2_cpu_sir_filter_alloc(cpp11::as_cpp>(r_pars), cpp11::as_cpp>(r_time_start), cpp11::as_cpp>(r_time), cpp11::as_cpp>(r_dt), cpp11::as_cpp>(r_data), cpp11::as_cpp>(r_n_particles), cpp11::as_cpp>(r_n_groups), cpp11::as_cpp>(r_seed))); + return cpp11::as_sexp(dust2_cpu_sir_unfilter_last_history(cpp11::as_cpp>(ptr), cpp11::as_cpp>(grouped))); END_CPP11 } // sir.cpp -SEXP dust2_cpu_sir_filter_run(cpp11::sexp ptr, cpp11::sexp r_pars, cpp11::sexp r_initial, bool grouped); -extern "C" SEXP _dust2_dust2_cpu_sir_filter_run(SEXP ptr, SEXP r_pars, SEXP r_initial, SEXP grouped) { +SEXP dust2_cpu_sir_filter_alloc(cpp11::list r_pars, cpp11::sexp r_time_start, cpp11::sexp r_time, cpp11::sexp r_dt, cpp11::list r_data, cpp11::sexp r_n_particles, cpp11::sexp r_n_groups, cpp11::sexp r_index, cpp11::sexp r_seed); +extern "C" SEXP _dust2_dust2_cpu_sir_filter_alloc(SEXP r_pars, SEXP r_time_start, SEXP r_time, SEXP r_dt, SEXP r_data, SEXP r_n_particles, SEXP r_n_groups, SEXP r_index, SEXP r_seed) { BEGIN_CPP11 - return cpp11::as_sexp(dust2_cpu_sir_filter_run(cpp11::as_cpp>(ptr), cpp11::as_cpp>(r_pars), cpp11::as_cpp>(r_initial), cpp11::as_cpp>(grouped))); + return cpp11::as_sexp(dust2_cpu_sir_filter_alloc(cpp11::as_cpp>(r_pars), cpp11::as_cpp>(r_time_start), cpp11::as_cpp>(r_time), cpp11::as_cpp>(r_dt), cpp11::as_cpp>(r_data), cpp11::as_cpp>(r_n_particles), cpp11::as_cpp>(r_n_groups), cpp11::as_cpp>(r_index), cpp11::as_cpp>(r_seed))); + END_CPP11 +} +// sir.cpp +SEXP dust2_cpu_sir_filter_run(cpp11::sexp ptr, cpp11::sexp r_pars, cpp11::sexp r_initial, bool save_history, bool grouped); +extern "C" SEXP _dust2_dust2_cpu_sir_filter_run(SEXP ptr, SEXP r_pars, SEXP r_initial, SEXP save_history, SEXP grouped) { + BEGIN_CPP11 + return cpp11::as_sexp(dust2_cpu_sir_filter_run(cpp11::as_cpp>(ptr), cpp11::as_cpp>(r_pars), cpp11::as_cpp>(r_initial), cpp11::as_cpp>(save_history), cpp11::as_cpp>(grouped))); + END_CPP11 +} +// sir.cpp +SEXP dust2_cpu_sir_filter_last_history(cpp11::sexp ptr, bool grouped); +extern "C" SEXP _dust2_dust2_cpu_sir_filter_last_history(SEXP ptr, SEXP grouped) { + BEGIN_CPP11 + return cpp11::as_sexp(dust2_cpu_sir_filter_last_history(cpp11::as_cpp>(ptr), cpp11::as_cpp>(grouped))); END_CPP11 } // sir.cpp @@ -239,39 +253,41 @@ extern "C" SEXP _dust2_dust2_cpu_walk_simulate(SEXP ptr, SEXP r_times, SEXP r_in extern "C" { static const R_CallMethodDef CallEntries[] = { - {"_dust2_dust2_cpu_sir_alloc", (DL_FUNC) &_dust2_dust2_cpu_sir_alloc, 7}, - {"_dust2_dust2_cpu_sir_compare_data", (DL_FUNC) &_dust2_dust2_cpu_sir_compare_data, 3}, - {"_dust2_dust2_cpu_sir_filter_alloc", (DL_FUNC) &_dust2_dust2_cpu_sir_filter_alloc, 8}, - {"_dust2_dust2_cpu_sir_filter_rng_state", (DL_FUNC) &_dust2_dust2_cpu_sir_filter_rng_state, 1}, - {"_dust2_dust2_cpu_sir_filter_run", (DL_FUNC) &_dust2_dust2_cpu_sir_filter_run, 4}, - {"_dust2_dust2_cpu_sir_reorder", (DL_FUNC) &_dust2_dust2_cpu_sir_reorder, 2}, - {"_dust2_dust2_cpu_sir_rng_state", (DL_FUNC) &_dust2_dust2_cpu_sir_rng_state, 1}, - {"_dust2_dust2_cpu_sir_run_steps", (DL_FUNC) &_dust2_dust2_cpu_sir_run_steps, 2}, - {"_dust2_dust2_cpu_sir_run_to_time", (DL_FUNC) &_dust2_dust2_cpu_sir_run_to_time, 2}, - {"_dust2_dust2_cpu_sir_set_state", (DL_FUNC) &_dust2_dust2_cpu_sir_set_state, 3}, - {"_dust2_dust2_cpu_sir_set_state_initial", (DL_FUNC) &_dust2_dust2_cpu_sir_set_state_initial, 1}, - {"_dust2_dust2_cpu_sir_set_time", (DL_FUNC) &_dust2_dust2_cpu_sir_set_time, 2}, - {"_dust2_dust2_cpu_sir_simulate", (DL_FUNC) &_dust2_dust2_cpu_sir_simulate, 4}, - {"_dust2_dust2_cpu_sir_state", (DL_FUNC) &_dust2_dust2_cpu_sir_state, 2}, - {"_dust2_dust2_cpu_sir_time", (DL_FUNC) &_dust2_dust2_cpu_sir_time, 1}, - {"_dust2_dust2_cpu_sir_unfilter_alloc", (DL_FUNC) &_dust2_dust2_cpu_sir_unfilter_alloc, 7}, - {"_dust2_dust2_cpu_sir_unfilter_run", (DL_FUNC) &_dust2_dust2_cpu_sir_unfilter_run, 4}, - {"_dust2_dust2_cpu_sir_update_pars", (DL_FUNC) &_dust2_dust2_cpu_sir_update_pars, 3}, - {"_dust2_dust2_cpu_walk_alloc", (DL_FUNC) &_dust2_dust2_cpu_walk_alloc, 7}, - {"_dust2_dust2_cpu_walk_reorder", (DL_FUNC) &_dust2_dust2_cpu_walk_reorder, 2}, - {"_dust2_dust2_cpu_walk_rng_state", (DL_FUNC) &_dust2_dust2_cpu_walk_rng_state, 1}, - {"_dust2_dust2_cpu_walk_run_steps", (DL_FUNC) &_dust2_dust2_cpu_walk_run_steps, 2}, - {"_dust2_dust2_cpu_walk_run_to_time", (DL_FUNC) &_dust2_dust2_cpu_walk_run_to_time, 2}, - {"_dust2_dust2_cpu_walk_set_state", (DL_FUNC) &_dust2_dust2_cpu_walk_set_state, 3}, - {"_dust2_dust2_cpu_walk_set_state_initial", (DL_FUNC) &_dust2_dust2_cpu_walk_set_state_initial, 1}, - {"_dust2_dust2_cpu_walk_set_time", (DL_FUNC) &_dust2_dust2_cpu_walk_set_time, 2}, - {"_dust2_dust2_cpu_walk_simulate", (DL_FUNC) &_dust2_dust2_cpu_walk_simulate, 4}, - {"_dust2_dust2_cpu_walk_state", (DL_FUNC) &_dust2_dust2_cpu_walk_state, 2}, - {"_dust2_dust2_cpu_walk_time", (DL_FUNC) &_dust2_dust2_cpu_walk_time, 1}, - {"_dust2_dust2_cpu_walk_update_pars", (DL_FUNC) &_dust2_dust2_cpu_walk_update_pars, 3}, - {"_dust2_test_history", (DL_FUNC) &_dust2_test_history, 4}, - {"_dust2_test_resample_weight", (DL_FUNC) &_dust2_test_resample_weight, 2}, - {"_dust2_test_scale_log_weights", (DL_FUNC) &_dust2_test_scale_log_weights, 1}, + {"_dust2_dust2_cpu_sir_alloc", (DL_FUNC) &_dust2_dust2_cpu_sir_alloc, 7}, + {"_dust2_dust2_cpu_sir_compare_data", (DL_FUNC) &_dust2_dust2_cpu_sir_compare_data, 3}, + {"_dust2_dust2_cpu_sir_filter_alloc", (DL_FUNC) &_dust2_dust2_cpu_sir_filter_alloc, 9}, + {"_dust2_dust2_cpu_sir_filter_last_history", (DL_FUNC) &_dust2_dust2_cpu_sir_filter_last_history, 2}, + {"_dust2_dust2_cpu_sir_filter_rng_state", (DL_FUNC) &_dust2_dust2_cpu_sir_filter_rng_state, 1}, + {"_dust2_dust2_cpu_sir_filter_run", (DL_FUNC) &_dust2_dust2_cpu_sir_filter_run, 5}, + {"_dust2_dust2_cpu_sir_reorder", (DL_FUNC) &_dust2_dust2_cpu_sir_reorder, 2}, + {"_dust2_dust2_cpu_sir_rng_state", (DL_FUNC) &_dust2_dust2_cpu_sir_rng_state, 1}, + {"_dust2_dust2_cpu_sir_run_steps", (DL_FUNC) &_dust2_dust2_cpu_sir_run_steps, 2}, + {"_dust2_dust2_cpu_sir_run_to_time", (DL_FUNC) &_dust2_dust2_cpu_sir_run_to_time, 2}, + {"_dust2_dust2_cpu_sir_set_state", (DL_FUNC) &_dust2_dust2_cpu_sir_set_state, 3}, + {"_dust2_dust2_cpu_sir_set_state_initial", (DL_FUNC) &_dust2_dust2_cpu_sir_set_state_initial, 1}, + {"_dust2_dust2_cpu_sir_set_time", (DL_FUNC) &_dust2_dust2_cpu_sir_set_time, 2}, + {"_dust2_dust2_cpu_sir_simulate", (DL_FUNC) &_dust2_dust2_cpu_sir_simulate, 4}, + {"_dust2_dust2_cpu_sir_state", (DL_FUNC) &_dust2_dust2_cpu_sir_state, 2}, + {"_dust2_dust2_cpu_sir_time", (DL_FUNC) &_dust2_dust2_cpu_sir_time, 1}, + {"_dust2_dust2_cpu_sir_unfilter_alloc", (DL_FUNC) &_dust2_dust2_cpu_sir_unfilter_alloc, 8}, + {"_dust2_dust2_cpu_sir_unfilter_last_history", (DL_FUNC) &_dust2_dust2_cpu_sir_unfilter_last_history, 2}, + {"_dust2_dust2_cpu_sir_unfilter_run", (DL_FUNC) &_dust2_dust2_cpu_sir_unfilter_run, 5}, + {"_dust2_dust2_cpu_sir_update_pars", (DL_FUNC) &_dust2_dust2_cpu_sir_update_pars, 3}, + {"_dust2_dust2_cpu_walk_alloc", (DL_FUNC) &_dust2_dust2_cpu_walk_alloc, 7}, + {"_dust2_dust2_cpu_walk_reorder", (DL_FUNC) &_dust2_dust2_cpu_walk_reorder, 2}, + {"_dust2_dust2_cpu_walk_rng_state", (DL_FUNC) &_dust2_dust2_cpu_walk_rng_state, 1}, + {"_dust2_dust2_cpu_walk_run_steps", (DL_FUNC) &_dust2_dust2_cpu_walk_run_steps, 2}, + {"_dust2_dust2_cpu_walk_run_to_time", (DL_FUNC) &_dust2_dust2_cpu_walk_run_to_time, 2}, + {"_dust2_dust2_cpu_walk_set_state", (DL_FUNC) &_dust2_dust2_cpu_walk_set_state, 3}, + {"_dust2_dust2_cpu_walk_set_state_initial", (DL_FUNC) &_dust2_dust2_cpu_walk_set_state_initial, 1}, + {"_dust2_dust2_cpu_walk_set_time", (DL_FUNC) &_dust2_dust2_cpu_walk_set_time, 2}, + {"_dust2_dust2_cpu_walk_simulate", (DL_FUNC) &_dust2_dust2_cpu_walk_simulate, 4}, + {"_dust2_dust2_cpu_walk_state", (DL_FUNC) &_dust2_dust2_cpu_walk_state, 2}, + {"_dust2_dust2_cpu_walk_time", (DL_FUNC) &_dust2_dust2_cpu_walk_time, 1}, + {"_dust2_dust2_cpu_walk_update_pars", (DL_FUNC) &_dust2_dust2_cpu_walk_update_pars, 3}, + {"_dust2_test_history", (DL_FUNC) &_dust2_test_history, 4}, + {"_dust2_test_resample_weight", (DL_FUNC) &_dust2_test_resample_weight, 2}, + {"_dust2_test_scale_log_weights", (DL_FUNC) &_dust2_test_scale_log_weights, 1}, {NULL, NULL, 0} }; } diff --git a/src/sir.cpp b/src/sir.cpp index 7c69660c..c3d02375 100644 --- a/src/sir.cpp +++ b/src/sir.cpp @@ -92,16 +92,24 @@ SEXP dust2_cpu_sir_unfilter_alloc(cpp11::list r_pars, cpp11::sexp r_dt, cpp11::list r_data, cpp11::sexp r_n_particles, - cpp11::sexp r_n_groups) { + cpp11::sexp r_n_groups, + cpp11::sexp r_index) { return dust2::r::dust2_cpu_unfilter_alloc(r_pars, r_time_start, r_time, r_dt, r_data, r_n_particles, - r_n_groups); + r_n_groups, r_index); } [[cpp11::register]] SEXP dust2_cpu_sir_unfilter_run(cpp11::sexp ptr, cpp11::sexp r_pars, - cpp11::sexp r_initial, bool grouped) { - return dust2::r::dust2_cpu_unfilter_run(ptr, r_pars, r_initial, grouped); + cpp11::sexp r_initial, bool save_history, + bool grouped) { + return dust2::r::dust2_cpu_unfilter_run(ptr, r_pars, r_initial, + save_history, grouped); +} + +[[cpp11::register]] +SEXP dust2_cpu_sir_unfilter_last_history(cpp11::sexp ptr, bool grouped) { + return dust2::r::dust2_cpu_unfilter_last_history(ptr, grouped); } [[cpp11::register]] @@ -112,17 +120,25 @@ SEXP dust2_cpu_sir_filter_alloc(cpp11::list r_pars, cpp11::list r_data, cpp11::sexp r_n_particles, cpp11::sexp r_n_groups, + cpp11::sexp r_index, cpp11::sexp r_seed) { return dust2::r::dust2_cpu_filter_alloc(r_pars, r_time_start, r_time, r_dt, r_data, r_n_particles, r_n_groups, - r_seed); + r_index, r_seed); } [[cpp11::register]] SEXP dust2_cpu_sir_filter_run(cpp11::sexp ptr, cpp11::sexp r_pars, - cpp11::sexp r_initial, bool grouped) { - return dust2::r::dust2_cpu_filter_run(ptr, r_pars, r_initial, grouped); + cpp11::sexp r_initial, bool save_history, + bool grouped) { + return dust2::r::dust2_cpu_filter_run(ptr, r_pars, r_initial, + save_history, grouped); +} + +[[cpp11::register]] +SEXP dust2_cpu_sir_filter_last_history(cpp11::sexp ptr, bool grouped) { + return dust2::r::dust2_cpu_filter_last_history(ptr, grouped); } [[cpp11::register]] diff --git a/tests/testthat/helper-dust.R b/tests/testthat/helper-dust.R index b6e7cf66..ed93d3eb 100644 --- a/tests/testthat/helper-dust.R +++ b/tests/testthat/helper-dust.R @@ -10,8 +10,10 @@ sir_filter_manual <- function(pars, time_start, time, dt, data, n_particles, obj <- dust_model_create(sir(), pars, n_particles, time = time_start, dt = dt, seed = seed) n_steps <- round((time - c(time_start, time[-length(time)])) / dt) + n_state <- nrow(dust_model_state(obj)) + n_time <- length(time) - function(pars, initial = NULL) { + function(pars, initial = NULL, save_history = FALSE) { if (!is.null(pars)) { dust_model_update_pars(obj, pars) } @@ -22,6 +24,7 @@ sir_filter_manual <- function(pars, time_start, time, dt, data, n_particles, dust_model_set_state(obj, initial) } ll <- 0 + history <- array(NA_real_, c(n_state, n_particles, n_time)) for (i in seq_along(time)) { dust_model_run_steps(obj, n_steps[[i]]) tmp <- dust_model_compare_data(obj, data[[i]]) @@ -30,8 +33,12 @@ sir_filter_manual <- function(pars, time_start, time, dt, data, n_particles, u <- r$random_real(1) k <- test_resample_weight(w, u) + 1L state <- dust_model_state(obj) + if (save_history) { + history[, , i] <- state + history <- history[, k, , drop = FALSE] + } dust_model_set_state(obj, state[, k]) } - ll + list(log_likelihood = ll, history = if (save_history) history) } } diff --git a/tests/testthat/test-filter.R b/tests/testthat/test-filter.R index 659a492b..f0289cb2 100644 --- a/tests/testthat/test-filter.R +++ b/tests/testthat/test-filter.R @@ -31,6 +31,57 @@ test_that("can run an unfilter", { }) +test_that("can get unfilter history", { + pars <- list(beta = 0.1, gamma = 0.2, N = 1000, I0 = 10, exp_noise = 1e6) + + time_start <- 0 + time <- c(4, 8, 12, 16) + data <- lapply(1:4, function(i) list(incidence = i)) + dt <- 1 + + obj <- dust_unfilter_create(sir(), pars, time_start, time, data) + dust_unfilter_run(obj) + expect_error( + dust_unfilter_last_history(obj), + "History is not current") + dust_unfilter_run(obj, save_history = TRUE) + h <- dust_unfilter_last_history(obj) + expect_equal(dust_unfilter_last_history(obj), h) + dust_unfilter_run(obj, save_history = FALSE) + expect_error( + dust_unfilter_last_history(obj), + "History is not current") + + m <- dust_model_create(sir(), pars, time = time_start, n_particles = 1, + deterministic = TRUE) + dust_model_set_state_initial(m) + cmp <- dust_model_simulate(m, time) + expect_equal(h, cmp) +}) + + +test_that("can get partial unfilter history", { + pars <- list(beta = 0.1, gamma = 0.2, N = 1000, I0 = 10, exp_noise = 1e6) + + time_start <- 0 + time <- c(4, 8, 12, 16) + data <- lapply(1:4, function(i) list(incidence = i)) + dt <- 1 + + obj1 <- dust_unfilter_create(sir(), pars, time_start, time, data) + obj2 <- dust_unfilter_create(sir(), pars, time_start, time, data, + index = c(2, 4)) + expect_equal(dust_unfilter_run(obj1, save_history = TRUE), + dust_unfilter_run(obj2, save_history = TRUE)) + + h1 <- dust_unfilter_last_history(obj1) + h2 <- dust_unfilter_last_history(obj2) + expect_equal(dim(h1), c(5, 1, 4)) + expect_equal(dim(h2), c(2, 1, 4)) + expect_equal(h2, h1[c(2, 4), , , drop = FALSE]) +}) + + test_that("can run an unfilter with manually set state", { pars <- list(beta = 0.1, gamma = 0.2, N = 1000, I0 = 10, exp_noise = 1e6) state <- matrix(c(1000 - 17, 17, 0, 0, 0), ncol = 1) @@ -154,6 +205,40 @@ test_that("can run replicated structured unfilter", { }) +test_that("can save history from structured unfilter", { + pars <- list( + list(beta = 0.1, gamma = 0.2, N = 1000, I0 = 10, exp_noise = 1e6), + list(beta = 0.2, gamma = 0.2, N = 1000, I0 = 10, exp_noise = 1e6)) + time_start <- 0 + time <- c(4, 8, 12, 16) + data <- lapply(1:4, function(i) { + lapply(seq_len(2), function(j) list(incidence = 2 * (i - 1) + j)) + }) + data1 <- lapply(data, function(x) x[[1]]) + data2 <- lapply(data, function(x) x[[2]]) + dt <- 1 + obj <- dust_unfilter_create(sir(), pars, time_start, time, data, + n_groups = 2) + obj1 <- dust_unfilter_create(sir(), pars[[1]], time_start, time, data1) + obj2 <- dust_unfilter_create(sir(), pars[[2]], time_start, time, data2) + + ll <- dust_unfilter_run(obj, save_history = TRUE) + ll1 <- dust_unfilter_run(obj1, save_history = TRUE) + ll2 <- dust_unfilter_run(obj2, save_history = TRUE) + + expect_equal(ll, c(ll1, ll2)) + h <- dust_unfilter_last_history(obj) + h1 <- dust_unfilter_last_history(obj1) + h2 <- dust_unfilter_last_history(obj2) + + expect_equal(dim(h), c(5, 1, 2, 4)) + expect_equal(dim(h1), c(5, 1, 4)) + expect_equal(dim(h2), c(5, 1, 4)) + expect_equal(array(h[, , 1, ], dim(h1)), h1) + expect_equal(array(h[, , 2, ], dim(h2)), h2) +}) + + test_that("can run particle filter", { pars <- list(beta = 0.1, gamma = 0.2, N = 1000, I0 = 10, exp_noise = 1e6) @@ -170,7 +255,67 @@ test_that("can run particle filter", { cmp_filter <- sir_filter_manual( pars, time_start, time, dt, data, n_particles, seed) - expect_equal(res, replicate(20, cmp_filter(NULL))) + expect_equal(res, replicate(20, cmp_filter(NULL)$log_likelihood)) +}) + + +test_that("can run particle filter and save history", { + pars <- list(beta = 0.1, gamma = 0.2, N = 1000, I0 = 10, exp_noise = 1e6) + + time_start <- 0 + time <- c(4, 8, 12, 16) + data <- lapply(1:4, function(i) list(incidence = i)) + dt <- 1 + n_particles <- 100 + seed <- 42 + + obj <- dust_filter_create(sir(), pars, time_start, time, data, + n_particles = n_particles, seed = seed) + res1 <- dust_filter_run(obj) + expect_error(dust_filter_last_history(obj), "History is not current") + res2 <- dust_filter_run(obj, save_history = TRUE) + h2 <- dust_filter_last_history(obj) + expect_equal(dim(h2), c(5, 100, 4)) + res3 <- dust_filter_run(obj) + expect_error(dust_filter_last_history(obj), "History is not current") + + cmp_filter <- sir_filter_manual( + pars, time_start, time, dt, data, n_particles, seed) + cmp1 <- cmp_filter(NULL) + cmp2 <- cmp_filter(NULL, save_history = TRUE) + cmp3 <- cmp_filter(NULL) + + expect_equal(res1, cmp1$log_likelihood) + expect_equal(res2, cmp2$log_likelihood) + expect_equal(res3, cmp3$log_likelihood) + expect_equal(h2, cmp2$history) +}) + + +test_that("can get partial unfilter history", { + pars <- list(beta = 0.1, gamma = 0.2, N = 1000, I0 = 10, exp_noise = 1e6) + + time_start <- 0 + time <- c(4, 8, 12, 16) + data <- lapply(1:4, function(i) list(incidence = i)) + dt <- 1 + n_particles <- 100 + seed <- 42 + + obj1 <- dust_filter_create(sir(), pars, time_start, time, data, + n_particles = n_particles, seed = seed) + obj2 <- dust_filter_create(sir(), pars, time_start, time, data, + n_particles = n_particles, seed = seed, + index = c(2, 4)) + + expect_equal(dust_filter_run(obj1, save_history = TRUE), + dust_filter_run(obj2, save_history = TRUE)) + + h1 <- dust_filter_last_history(obj1) + h2 <- dust_filter_last_history(obj2) + expect_equal(dim(h1), c(5, 100, 4)) + expect_equal(dim(h2), c(2, 100, 4)) + expect_equal(h2, h1[c(2, 4), , , drop = FALSE]) }) @@ -198,14 +343,14 @@ test_that("can run a nested particle filter and get the same result", { s <- dust_filter_rng_state(obj) expect_equal(s, r) - res <- replicate(20, dust_filter_run(obj)) + res <- replicate(20, dust_filter_run(obj, save_history = TRUE)) ## now compare: data1 <- lapply(data, "[[", 1) obj1 <- dust_filter_create(sir(), pars[[1]], time_start, time, data1, n_particles = n_particles, seed = seed) s1 <- dust_filter_rng_state(obj1) - res1 <- replicate(20, dust_filter_run(obj1)) + res1 <- replicate(20, dust_filter_run(obj1, save_history = TRUE)) expect_equal(res1, res[1, ]) expect_equal(s1, s[1:3232]) @@ -214,9 +359,14 @@ test_that("can run a nested particle filter and get the same result", { obj2 <- dust_filter_create(sir(), pars[[2]], time_start, time, data2, n_particles = n_particles, seed = seed2) s2 <- dust_filter_rng_state(obj2) - res2 <- replicate(20, dust_filter_run(obj2)) + res2 <- replicate(20, dust_filter_run(obj2, save_history = TRUE)) expect_equal(res2, res[2, ]) expect_equal(s2, s[3233:6464]) + + h <- dust_filter_last_history(obj) + expect_equal(dim(h), c(5, 100, 2, 4)) + expect_equal(h[, , 1, ], dust_filter_last_history(obj1)) + expect_equal(h[, , 2, ], dust_filter_last_history(obj2)) }) @@ -259,5 +409,5 @@ test_that("can run particle filter with manual initial state", { cmp_filter <- sir_filter_manual( pars, time_start, time, dt, data, n_particles, seed) - expect_equal(res, replicate(20, cmp_filter(NULL, state))) + expect_equal(res, replicate(20, cmp_filter(NULL, state)$log_likelihood)) })