From d236192937b6293937abc43cd4cbe1b77ebadede Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Thu, 16 May 2024 08:34:21 +0100 Subject: [PATCH] Allow setting initial state into filter --- R/cpp11.R | 4 +-- inst/include/dust2/filter.hpp | 6 +++-- inst/include/dust2/r/cpu.hpp | 29 +------------------- inst/include/dust2/r/filter.hpp | 7 +++-- inst/include/dust2/r/helpers.hpp | 33 +++++++++++++++++++++++ src/cpp11.cpp | 8 +++--- src/sir.cpp | 4 +-- tests/testthat/test-filter.R | 46 ++++++++++++++++++++++++++------ 8 files changed, 89 insertions(+), 48 deletions(-) diff --git a/R/cpp11.R b/R/cpp11.R index f082dc1c..7f8f7ad3 100644 --- a/R/cpp11.R +++ b/R/cpp11.R @@ -40,8 +40,8 @@ dust2_cpu_sir_unfilter_alloc <- function(r_pars, r_time_start, r_time, r_dt, r_d .Call(`_dust2_dust2_cpu_sir_unfilter_alloc`, r_pars, r_time_start, r_time, r_dt, r_data, r_n_particles, r_n_groups) } -dust2_cpu_sir_unfilter_run <- function(ptr, r_pars, grouped) { - .Call(`_dust2_dust2_cpu_sir_unfilter_run`, ptr, r_pars, grouped) +dust2_cpu_sir_unfilter_run <- function(ptr, r_pars, r_initial, grouped) { + .Call(`_dust2_dust2_cpu_sir_unfilter_run`, ptr, r_pars, r_initial, grouped) } dust2_cpu_walk_alloc <- function(r_pars, r_time, r_dt, r_n_particles, r_n_groups, r_seed, r_deterministic) { diff --git a/inst/include/dust2/filter.hpp b/inst/include/dust2/filter.hpp index 35d7c178..c0d963cf 100644 --- a/inst/include/dust2/filter.hpp +++ b/inst/include/dust2/filter.hpp @@ -34,11 +34,13 @@ class unfilter { } } - void run() { + void run(bool set_initial) { const auto n_times = step_.size(); model.set_time(time_start_); - model.set_state_initial(); + if (set_initial) { + model.set_state_initial(); + } std::fill(ll_.begin(), ll_.end(), 0); auto it_data = data_.begin(); diff --git a/inst/include/dust2/r/cpu.hpp b/inst/include/dust2/r/cpu.hpp index e9cc9eb8..a0ec70cf 100644 --- a/inst/include/dust2/r/cpu.hpp +++ b/inst/include/dust2/r/cpu.hpp @@ -100,34 +100,7 @@ SEXP dust2_cpu_set_state_initial(cpp11::sexp ptr) { template SEXP dust2_cpu_set_state(cpp11::sexp ptr, cpp11::sexp r_state, bool grouped) { auto *obj = cpp11::as_cpp>>(ptr).get(); - // Suppose that we have a n_state x n_particles x n_groups grouped - // system, we then require that we have a state array with rank 3; - // for an ungrouped system this will be rank 2 array. - auto dim = cpp11::as_cpp(r_state.attr("dim")); - const auto rank = dim.size(); - const auto rank_expected = grouped ? 3 : 2; - if (rank != rank_expected) { - cpp11::stop("Expected 'state' to be a %dd array", rank_expected); - } - const int n_state = obj->n_state(); - const int n_particles = - grouped ? obj->n_particles() : obj->n_particles() * obj->n_groups(); - const int n_groups = grouped ? obj->n_groups() : 1; - if (dim[0] != n_state) { - cpp11::stop("Expected the first dimension of 'state' to have size %d", - n_state); - } - const auto recycle_particle = n_particles > 1 && dim[1] == 1; - if (dim[1] != n_particles && dim[1] != 1) { - cpp11::stop("Expected the second dimension of 'state' to have size %d or 1", - n_particles); - } - const auto recycle_group = !grouped || (n_groups > 1 && dim[2] == 1); - if (grouped && dim[2] != n_groups && dim[2] != 1) { - cpp11::stop("Expected the third dimension of 'state' to have size %d or 1", - n_groups); - } - obj->set_state(REAL(r_state), recycle_particle, recycle_group); + set_state(*obj, r_state, grouped); return R_NilValue; } diff --git a/inst/include/dust2/r/filter.hpp b/inst/include/dust2/r/filter.hpp index cd0139e2..efc24864 100644 --- a/inst/include/dust2/r/filter.hpp +++ b/inst/include/dust2/r/filter.hpp @@ -57,13 +57,16 @@ cpp11::sexp dust2_cpu_unfilter_alloc(cpp11::list r_pars, template cpp11::sexp dust2_cpu_unfilter_run(cpp11::sexp ptr, cpp11::sexp r_pars, - bool grouped) { + cpp11::sexp r_initial, bool grouped) { auto *obj = cpp11::as_cpp>>(ptr).get(); if (r_pars != R_NilValue) { update_pars(obj->model, cpp11::as_cpp(r_pars), grouped); } - obj->run(); + if (r_initial != R_NilValue) { + set_state(obj->model, r_initial, grouped); + } + obj->run(r_initial == R_NilValue); const auto n_groups = obj->model.n_groups(); const auto n_particles = obj->model.n_particles(); diff --git a/inst/include/dust2/r/helpers.hpp b/inst/include/dust2/r/helpers.hpp index d1868350..2b5e8057 100644 --- a/inst/include/dust2/r/helpers.hpp +++ b/inst/include/dust2/r/helpers.hpp @@ -244,5 +244,38 @@ std::vector check_data(cpp11::list r_data, return data; } +template +void set_state(dust_cpu& obj, cpp11::sexp r_state, bool grouped) { + // Suppose that we have a n_state x n_particles x n_groups grouped + // system, we then require that we have a state array with rank 3; + // for an ungrouped system this will be rank 2 array. + auto dim = cpp11::as_cpp(r_state.attr("dim")); + const auto rank = dim.size(); + const auto rank_expected = grouped ? 3 : 2; + if (rank != rank_expected) { + cpp11::stop("Expected 'state' to be a %dd array", rank_expected); + } + const int n_state = obj.n_state(); + const int n_particles = + grouped ? obj.n_particles() : obj.n_particles() * obj.n_groups(); + const int n_groups = grouped ? obj.n_groups() : 1; + if (dim[0] != n_state) { + cpp11::stop("Expected the first dimension of 'state' to have size %d", + n_state); + } + const auto recycle_particle = n_particles > 1 && dim[1] == 1; + if (dim[1] != n_particles && dim[1] != 1) { + cpp11::stop("Expected the second dimension of 'state' to have size %d or 1", + n_particles); + } + + const auto recycle_group = !grouped || (n_groups > 1 && dim[2] == 1); + if (grouped && dim[2] != n_groups && dim[2] != 1) { + cpp11::stop("Expected the third dimension of 'state' to have size %d or 1", + n_groups); + } + obj.set_state(REAL(r_state), recycle_particle, recycle_group); +} + } } diff --git a/src/cpp11.cpp b/src/cpp11.cpp index 8461fd30..4bf549c8 100644 --- a/src/cpp11.cpp +++ b/src/cpp11.cpp @@ -76,10 +76,10 @@ extern "C" SEXP _dust2_dust2_cpu_sir_unfilter_alloc(SEXP r_pars, SEXP r_time_sta END_CPP11 } // sir.cpp -SEXP dust2_cpu_sir_unfilter_run(cpp11::sexp ptr, cpp11::sexp r_pars, bool grouped); -extern "C" SEXP _dust2_dust2_cpu_sir_unfilter_run(SEXP ptr, SEXP r_pars, SEXP grouped) { +SEXP dust2_cpu_sir_unfilter_run(cpp11::sexp ptr, cpp11::sexp r_pars, cpp11::sexp r_initial, bool grouped); +extern "C" SEXP _dust2_dust2_cpu_sir_unfilter_run(SEXP ptr, SEXP r_pars, SEXP r_initial, SEXP grouped) { BEGIN_CPP11 - return cpp11::as_sexp(dust2_cpu_sir_unfilter_run(cpp11::as_cpp>(ptr), cpp11::as_cpp>(r_pars), cpp11::as_cpp>(grouped))); + return cpp11::as_sexp(dust2_cpu_sir_unfilter_run(cpp11::as_cpp>(ptr), cpp11::as_cpp>(r_pars), cpp11::as_cpp>(r_initial), cpp11::as_cpp>(grouped))); END_CPP11 } // walk.cpp @@ -164,7 +164,7 @@ static const R_CallMethodDef CallEntries[] = { {"_dust2_dust2_cpu_sir_state", (DL_FUNC) &_dust2_dust2_cpu_sir_state, 2}, {"_dust2_dust2_cpu_sir_time", (DL_FUNC) &_dust2_dust2_cpu_sir_time, 1}, {"_dust2_dust2_cpu_sir_unfilter_alloc", (DL_FUNC) &_dust2_dust2_cpu_sir_unfilter_alloc, 7}, - {"_dust2_dust2_cpu_sir_unfilter_run", (DL_FUNC) &_dust2_dust2_cpu_sir_unfilter_run, 3}, + {"_dust2_dust2_cpu_sir_unfilter_run", (DL_FUNC) &_dust2_dust2_cpu_sir_unfilter_run, 4}, {"_dust2_dust2_cpu_sir_update_pars", (DL_FUNC) &_dust2_dust2_cpu_sir_update_pars, 3}, {"_dust2_dust2_cpu_walk_alloc", (DL_FUNC) &_dust2_dust2_cpu_walk_alloc, 7}, {"_dust2_dust2_cpu_walk_reorder", (DL_FUNC) &_dust2_dust2_cpu_walk_reorder, 2}, diff --git a/src/sir.cpp b/src/sir.cpp index 531975e6..086c0a4a 100644 --- a/src/sir.cpp +++ b/src/sir.cpp @@ -79,6 +79,6 @@ SEXP dust2_cpu_sir_unfilter_alloc(cpp11::list r_pars, [[cpp11::register]] SEXP dust2_cpu_sir_unfilter_run(cpp11::sexp ptr, cpp11::sexp r_pars, - bool grouped) { - return dust2::r::dust2_cpu_unfilter_run(ptr, r_pars, grouped); + cpp11::sexp r_initial, bool grouped) { + return dust2::r::dust2_cpu_unfilter_run(ptr, r_pars, r_initial, grouped); } diff --git a/tests/testthat/test-filter.R b/tests/testthat/test-filter.R index d1fb89cd..546cb267 100644 --- a/tests/testthat/test-filter.R +++ b/tests/testthat/test-filter.R @@ -25,10 +25,39 @@ test_that("can run an unfilter", { obj <- dust2_cpu_sir_unfilter_alloc(base, time_start, time, dt, data, 1, 0) ptr <- obj[[1]] - expect_equal(dust2_cpu_sir_unfilter_run(ptr, NULL, FALSE), f(pars1)) + expect_equal(dust2_cpu_sir_unfilter_run(ptr, NULL, NULL, FALSE), f(pars1)) - expect_equal(dust2_cpu_sir_unfilter_run(ptr, pars1, FALSE), f(pars1)) - expect_equal(dust2_cpu_sir_unfilter_run(ptr, pars2, FALSE), f(pars2)) + expect_equal(dust2_cpu_sir_unfilter_run(ptr, pars1, NULL, FALSE), f(pars1)) + expect_equal(dust2_cpu_sir_unfilter_run(ptr, pars2, NULL, FALSE), f(pars2)) +}) + + +test_that("can run an unfilter with manually set state", { + pars <- list(beta = 0.1, gamma = 0.2, N = 1000, I0 = 10, exp_noise = 1e6) + state <- matrix(c(1000 - 17, 17, 0, 0, 0), ncol = 1) + + time_start <- 0 + time <- c(4, 8, 12, 16) + data <- lapply(1:4, function(i) list(incidence = i)) + dt <- 1 + + ## Manually compute likelihood: + f <- function(pars) { + obj <- dust2_cpu_sir_alloc(pars, time_start, dt, 1, 0, NULL, TRUE) + ptr <- obj[[1]] + dust2_cpu_sir_set_state(ptr, state, FALSE) + incidence <- numeric(length(time)) + time0 <- c(time_start, time) + for (i in seq_along(time)) { + dust2_cpu_sir_run_steps(ptr, round((time[i] - time0[i]) / dt)) + incidence[i] <- dust2_cpu_sir_state(ptr, FALSE)[5, , drop = TRUE] + } + sum(dpois(1:4, incidence + 1e-6, log = TRUE)) + } + + obj <- dust2_cpu_sir_unfilter_alloc(pars, time_start, time, dt, data, 1, 0) + ptr <- obj[[1]] + expect_equal(dust2_cpu_sir_unfilter_run(ptr, NULL, state, FALSE), f(pars)) }) @@ -62,7 +91,7 @@ test_that("can run run unfilter on structured model", { obj <- dust2_cpu_sir_unfilter_alloc(pars, time_start, time, dt, data, 1, 3) ptr <- obj[[1]] - expect_equal(dust2_cpu_sir_unfilter_run(ptr, NULL, TRUE), f(pars)) + expect_equal(dust2_cpu_sir_unfilter_run(ptr, NULL, NULL, TRUE), f(pars)) }) @@ -99,8 +128,8 @@ test_that("can run replicated unfilter", { obj2 <- dust2_cpu_sir_unfilter_alloc(pars, time_start, time, dt, data, 1, 0) expect_equal( - dust2_cpu_sir_unfilter_run(obj1[[1]], NULL, FALSE), - rep(dust2_cpu_sir_unfilter_run(obj2[[1]], NULL, FALSE), 5)) + dust2_cpu_sir_unfilter_run(obj1[[1]], NULL, NULL, FALSE), + rep(dust2_cpu_sir_unfilter_run(obj2[[1]], NULL, NULL, FALSE), 5)) }) @@ -117,7 +146,8 @@ test_that("can run replicated structured unfilter", { obj1 <- dust2_cpu_sir_unfilter_alloc(pars, time_start, time, dt, data, 5, 2) obj2 <- dust2_cpu_sir_unfilter_alloc(pars, time_start, time, dt, data, 1, 2) + cmp <- dust2_cpu_sir_unfilter_run(obj2[[1]], NULL, NULL, TRUE) expect_equal( - dust2_cpu_sir_unfilter_run(obj1[[1]], NULL, TRUE), - matrix(rep(dust2_cpu_sir_unfilter_run(obj2[[1]], NULL, TRUE), each = 5), 5)) + dust2_cpu_sir_unfilter_run(obj1[[1]], NULL, NULL, TRUE), + matrix(rep(cmp, each = 5), 5)) })