Skip to content

Commit

Permalink
Allow setting initial state into filter
Browse files Browse the repository at this point in the history
  • Loading branch information
richfitz committed May 16, 2024
1 parent fe39957 commit d236192
Show file tree
Hide file tree
Showing 8 changed files with 89 additions and 48 deletions.
4 changes: 2 additions & 2 deletions R/cpp11.R
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ dust2_cpu_sir_unfilter_alloc <- function(r_pars, r_time_start, r_time, r_dt, r_d
.Call(`_dust2_dust2_cpu_sir_unfilter_alloc`, r_pars, r_time_start, r_time, r_dt, r_data, r_n_particles, r_n_groups)
}

dust2_cpu_sir_unfilter_run <- function(ptr, r_pars, grouped) {
.Call(`_dust2_dust2_cpu_sir_unfilter_run`, ptr, r_pars, grouped)
dust2_cpu_sir_unfilter_run <- function(ptr, r_pars, r_initial, grouped) {
.Call(`_dust2_dust2_cpu_sir_unfilter_run`, ptr, r_pars, r_initial, grouped)
}

dust2_cpu_walk_alloc <- function(r_pars, r_time, r_dt, r_n_particles, r_n_groups, r_seed, r_deterministic) {
Expand Down
6 changes: 4 additions & 2 deletions inst/include/dust2/filter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,13 @@ class unfilter {
}
}

void run() {
void run(bool set_initial) {
const auto n_times = step_.size();

model.set_time(time_start_);
model.set_state_initial();
if (set_initial) {
model.set_state_initial();
}
std::fill(ll_.begin(), ll_.end(), 0);

auto it_data = data_.begin();
Expand Down
29 changes: 1 addition & 28 deletions inst/include/dust2/r/cpu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,34 +100,7 @@ SEXP dust2_cpu_set_state_initial(cpp11::sexp ptr) {
template <typename T>
SEXP dust2_cpu_set_state(cpp11::sexp ptr, cpp11::sexp r_state, bool grouped) {
auto *obj = cpp11::as_cpp<cpp11::external_pointer<dust_cpu<T>>>(ptr).get();
// Suppose that we have a n_state x n_particles x n_groups grouped
// system, we then require that we have a state array with rank 3;
// for an ungrouped system this will be rank 2 array.
auto dim = cpp11::as_cpp<cpp11::integers>(r_state.attr("dim"));
const auto rank = dim.size();
const auto rank_expected = grouped ? 3 : 2;
if (rank != rank_expected) {
cpp11::stop("Expected 'state' to be a %dd array", rank_expected);
}
const int n_state = obj->n_state();
const int n_particles =
grouped ? obj->n_particles() : obj->n_particles() * obj->n_groups();
const int n_groups = grouped ? obj->n_groups() : 1;
if (dim[0] != n_state) {
cpp11::stop("Expected the first dimension of 'state' to have size %d",
n_state);
}
const auto recycle_particle = n_particles > 1 && dim[1] == 1;
if (dim[1] != n_particles && dim[1] != 1) {
cpp11::stop("Expected the second dimension of 'state' to have size %d or 1",
n_particles);
}
const auto recycle_group = !grouped || (n_groups > 1 && dim[2] == 1);
if (grouped && dim[2] != n_groups && dim[2] != 1) {
cpp11::stop("Expected the third dimension of 'state' to have size %d or 1",
n_groups);
}
obj->set_state(REAL(r_state), recycle_particle, recycle_group);
set_state(*obj, r_state, grouped);
return R_NilValue;
}

Expand Down
7 changes: 5 additions & 2 deletions inst/include/dust2/r/filter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,16 @@ cpp11::sexp dust2_cpu_unfilter_alloc(cpp11::list r_pars,

template <typename T>
cpp11::sexp dust2_cpu_unfilter_run(cpp11::sexp ptr, cpp11::sexp r_pars,
bool grouped) {
cpp11::sexp r_initial, bool grouped) {
auto *obj =
cpp11::as_cpp<cpp11::external_pointer<unfilter<T>>>(ptr).get();
if (r_pars != R_NilValue) {
update_pars(obj->model, cpp11::as_cpp<cpp11::list>(r_pars), grouped);
}
obj->run();
if (r_initial != R_NilValue) {
set_state(obj->model, r_initial, grouped);
}
obj->run(r_initial == R_NilValue);

const auto n_groups = obj->model.n_groups();
const auto n_particles = obj->model.n_particles();
Expand Down
33 changes: 33 additions & 0 deletions inst/include/dust2/r/helpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -244,5 +244,38 @@ std::vector<typename T::data_type> check_data(cpp11::list r_data,
return data;
}

template <typename T>
void set_state(dust_cpu<T>& obj, cpp11::sexp r_state, bool grouped) {
// Suppose that we have a n_state x n_particles x n_groups grouped
// system, we then require that we have a state array with rank 3;
// for an ungrouped system this will be rank 2 array.
auto dim = cpp11::as_cpp<cpp11::integers>(r_state.attr("dim"));
const auto rank = dim.size();
const auto rank_expected = grouped ? 3 : 2;
if (rank != rank_expected) {
cpp11::stop("Expected 'state' to be a %dd array", rank_expected);
}
const int n_state = obj.n_state();
const int n_particles =
grouped ? obj.n_particles() : obj.n_particles() * obj.n_groups();
const int n_groups = grouped ? obj.n_groups() : 1;
if (dim[0] != n_state) {
cpp11::stop("Expected the first dimension of 'state' to have size %d",
n_state);
}
const auto recycle_particle = n_particles > 1 && dim[1] == 1;
if (dim[1] != n_particles && dim[1] != 1) {
cpp11::stop("Expected the second dimension of 'state' to have size %d or 1",
n_particles);
}

const auto recycle_group = !grouped || (n_groups > 1 && dim[2] == 1);
if (grouped && dim[2] != n_groups && dim[2] != 1) {
cpp11::stop("Expected the third dimension of 'state' to have size %d or 1",
n_groups);
}
obj.set_state(REAL(r_state), recycle_particle, recycle_group);
}

}
}
8 changes: 4 additions & 4 deletions src/cpp11.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,10 @@ extern "C" SEXP _dust2_dust2_cpu_sir_unfilter_alloc(SEXP r_pars, SEXP r_time_sta
END_CPP11
}
// sir.cpp
SEXP dust2_cpu_sir_unfilter_run(cpp11::sexp ptr, cpp11::sexp r_pars, bool grouped);
extern "C" SEXP _dust2_dust2_cpu_sir_unfilter_run(SEXP ptr, SEXP r_pars, SEXP grouped) {
SEXP dust2_cpu_sir_unfilter_run(cpp11::sexp ptr, cpp11::sexp r_pars, cpp11::sexp r_initial, bool grouped);
extern "C" SEXP _dust2_dust2_cpu_sir_unfilter_run(SEXP ptr, SEXP r_pars, SEXP r_initial, SEXP grouped) {
BEGIN_CPP11
return cpp11::as_sexp(dust2_cpu_sir_unfilter_run(cpp11::as_cpp<cpp11::decay_t<cpp11::sexp>>(ptr), cpp11::as_cpp<cpp11::decay_t<cpp11::sexp>>(r_pars), cpp11::as_cpp<cpp11::decay_t<bool>>(grouped)));
return cpp11::as_sexp(dust2_cpu_sir_unfilter_run(cpp11::as_cpp<cpp11::decay_t<cpp11::sexp>>(ptr), cpp11::as_cpp<cpp11::decay_t<cpp11::sexp>>(r_pars), cpp11::as_cpp<cpp11::decay_t<cpp11::sexp>>(r_initial), cpp11::as_cpp<cpp11::decay_t<bool>>(grouped)));
END_CPP11
}
// walk.cpp
Expand Down Expand Up @@ -164,7 +164,7 @@ static const R_CallMethodDef CallEntries[] = {
{"_dust2_dust2_cpu_sir_state", (DL_FUNC) &_dust2_dust2_cpu_sir_state, 2},
{"_dust2_dust2_cpu_sir_time", (DL_FUNC) &_dust2_dust2_cpu_sir_time, 1},
{"_dust2_dust2_cpu_sir_unfilter_alloc", (DL_FUNC) &_dust2_dust2_cpu_sir_unfilter_alloc, 7},
{"_dust2_dust2_cpu_sir_unfilter_run", (DL_FUNC) &_dust2_dust2_cpu_sir_unfilter_run, 3},
{"_dust2_dust2_cpu_sir_unfilter_run", (DL_FUNC) &_dust2_dust2_cpu_sir_unfilter_run, 4},
{"_dust2_dust2_cpu_sir_update_pars", (DL_FUNC) &_dust2_dust2_cpu_sir_update_pars, 3},
{"_dust2_dust2_cpu_walk_alloc", (DL_FUNC) &_dust2_dust2_cpu_walk_alloc, 7},
{"_dust2_dust2_cpu_walk_reorder", (DL_FUNC) &_dust2_dust2_cpu_walk_reorder, 2},
Expand Down
4 changes: 2 additions & 2 deletions src/sir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,6 @@ SEXP dust2_cpu_sir_unfilter_alloc(cpp11::list r_pars,

[[cpp11::register]]
SEXP dust2_cpu_sir_unfilter_run(cpp11::sexp ptr, cpp11::sexp r_pars,
bool grouped) {
return dust2::r::dust2_cpu_unfilter_run<sir>(ptr, r_pars, grouped);
cpp11::sexp r_initial, bool grouped) {
return dust2::r::dust2_cpu_unfilter_run<sir>(ptr, r_pars, r_initial, grouped);
}
46 changes: 38 additions & 8 deletions tests/testthat/test-filter.R
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,39 @@ test_that("can run an unfilter", {

obj <- dust2_cpu_sir_unfilter_alloc(base, time_start, time, dt, data, 1, 0)
ptr <- obj[[1]]
expect_equal(dust2_cpu_sir_unfilter_run(ptr, NULL, FALSE), f(pars1))
expect_equal(dust2_cpu_sir_unfilter_run(ptr, NULL, NULL, FALSE), f(pars1))

expect_equal(dust2_cpu_sir_unfilter_run(ptr, pars1, FALSE), f(pars1))
expect_equal(dust2_cpu_sir_unfilter_run(ptr, pars2, FALSE), f(pars2))
expect_equal(dust2_cpu_sir_unfilter_run(ptr, pars1, NULL, FALSE), f(pars1))
expect_equal(dust2_cpu_sir_unfilter_run(ptr, pars2, NULL, FALSE), f(pars2))
})


test_that("can run an unfilter with manually set state", {
pars <- list(beta = 0.1, gamma = 0.2, N = 1000, I0 = 10, exp_noise = 1e6)
state <- matrix(c(1000 - 17, 17, 0, 0, 0), ncol = 1)

time_start <- 0
time <- c(4, 8, 12, 16)
data <- lapply(1:4, function(i) list(incidence = i))
dt <- 1

## Manually compute likelihood:
f <- function(pars) {
obj <- dust2_cpu_sir_alloc(pars, time_start, dt, 1, 0, NULL, TRUE)
ptr <- obj[[1]]
dust2_cpu_sir_set_state(ptr, state, FALSE)
incidence <- numeric(length(time))
time0 <- c(time_start, time)
for (i in seq_along(time)) {
dust2_cpu_sir_run_steps(ptr, round((time[i] - time0[i]) / dt))
incidence[i] <- dust2_cpu_sir_state(ptr, FALSE)[5, , drop = TRUE]
}
sum(dpois(1:4, incidence + 1e-6, log = TRUE))
}

obj <- dust2_cpu_sir_unfilter_alloc(pars, time_start, time, dt, data, 1, 0)
ptr <- obj[[1]]
expect_equal(dust2_cpu_sir_unfilter_run(ptr, NULL, state, FALSE), f(pars))
})


Expand Down Expand Up @@ -62,7 +91,7 @@ test_that("can run run unfilter on structured model", {

obj <- dust2_cpu_sir_unfilter_alloc(pars, time_start, time, dt, data, 1, 3)
ptr <- obj[[1]]
expect_equal(dust2_cpu_sir_unfilter_run(ptr, NULL, TRUE), f(pars))
expect_equal(dust2_cpu_sir_unfilter_run(ptr, NULL, NULL, TRUE), f(pars))
})


Expand Down Expand Up @@ -99,8 +128,8 @@ test_that("can run replicated unfilter", {
obj2 <- dust2_cpu_sir_unfilter_alloc(pars, time_start, time, dt, data, 1, 0)

expect_equal(
dust2_cpu_sir_unfilter_run(obj1[[1]], NULL, FALSE),
rep(dust2_cpu_sir_unfilter_run(obj2[[1]], NULL, FALSE), 5))
dust2_cpu_sir_unfilter_run(obj1[[1]], NULL, NULL, FALSE),
rep(dust2_cpu_sir_unfilter_run(obj2[[1]], NULL, NULL, FALSE), 5))
})


Expand All @@ -117,7 +146,8 @@ test_that("can run replicated structured unfilter", {
obj1 <- dust2_cpu_sir_unfilter_alloc(pars, time_start, time, dt, data, 5, 2)
obj2 <- dust2_cpu_sir_unfilter_alloc(pars, time_start, time, dt, data, 1, 2)

cmp <- dust2_cpu_sir_unfilter_run(obj2[[1]], NULL, NULL, TRUE)
expect_equal(
dust2_cpu_sir_unfilter_run(obj1[[1]], NULL, TRUE),
matrix(rep(dust2_cpu_sir_unfilter_run(obj2[[1]], NULL, TRUE), each = 5), 5))
dust2_cpu_sir_unfilter_run(obj1[[1]], NULL, NULL, TRUE),
matrix(rep(cmp, each = 5), 5))
})

0 comments on commit d236192

Please sign in to comment.