Skip to content

Commit

Permalink
Add tests for grouped history saving
Browse files Browse the repository at this point in the history
  • Loading branch information
richfitz committed May 30, 2024
1 parent 630c864 commit 62dc692
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 3 deletions.
7 changes: 7 additions & 0 deletions inst/include/dust2/filter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ class unfilter {

void run(bool set_initial, bool save_history) {
history_is_current_ = false;
if (save_history) {
history_.reset();
}
const auto n_times = step_.size();

model.set_time(time_start_);
Expand Down Expand Up @@ -145,6 +148,10 @@ class filter {
}

void run(bool set_initial, bool save_history) {
history_is_current_ = false;
if (save_history) {
history_.reset();
}
const auto n_times = step_.size();

model.set_time(time_start_);
Expand Down
45 changes: 42 additions & 3 deletions tests/testthat/test-filter.R
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,40 @@ test_that("can run replicated structured unfilter", {
})


test_that("can save history from structured unfilter", {
pars <- list(
list(beta = 0.1, gamma = 0.2, N = 1000, I0 = 10, exp_noise = 1e6),
list(beta = 0.2, gamma = 0.2, N = 1000, I0 = 10, exp_noise = 1e6))
time_start <- 0
time <- c(4, 8, 12, 16)
data <- lapply(1:4, function(i) {
lapply(seq_len(2), function(j) list(incidence = 2 * (i - 1) + j))
})
data1 <- lapply(data, function(x) x[[1]])
data2 <- lapply(data, function(x) x[[2]])
dt <- 1
obj <- dust_unfilter_create(sir(), pars, time_start, time, data,
n_groups = 2)
obj1 <- dust_unfilter_create(sir(), pars[[1]], time_start, time, data1)
obj2 <- dust_unfilter_create(sir(), pars[[2]], time_start, time, data2)

ll <- dust_unfilter_run(obj, save_history = TRUE)
ll1 <- dust_unfilter_run(obj1, save_history = TRUE)
ll2 <- dust_unfilter_run(obj2, save_history = TRUE)

expect_equal(ll, c(ll1, ll2))
h <- dust_unfilter_last_history(obj)
h1 <- dust_unfilter_last_history(obj1)
h2 <- dust_unfilter_last_history(obj2)

expect_equal(dim(h), c(5, 1, 2, 4))
expect_equal(dim(h1), c(5, 1, 4))
expect_equal(dim(h2), c(5, 1, 4))
expect_equal(array(h[, , 1, ], dim(h1)), h1)
expect_equal(array(h[, , 2, ], dim(h2)), h2)
})


test_that("can run particle filter", {
pars <- list(beta = 0.1, gamma = 0.2, N = 1000, I0 = 10, exp_noise = 1e6)

Expand Down Expand Up @@ -309,14 +343,14 @@ test_that("can run a nested particle filter and get the same result", {
s <- dust_filter_rng_state(obj)
expect_equal(s, r)

res <- replicate(20, dust_filter_run(obj))
res <- replicate(20, dust_filter_run(obj, save_history = TRUE))

## now compare:
data1 <- lapply(data, "[[", 1)
obj1 <- dust_filter_create(sir(), pars[[1]], time_start, time, data1,
n_particles = n_particles, seed = seed)
s1 <- dust_filter_rng_state(obj1)
res1 <- replicate(20, dust_filter_run(obj1))
res1 <- replicate(20, dust_filter_run(obj1, save_history = TRUE))
expect_equal(res1, res[1, ])
expect_equal(s1, s[1:3232])

Expand All @@ -325,9 +359,14 @@ test_that("can run a nested particle filter and get the same result", {
obj2 <- dust_filter_create(sir(), pars[[2]], time_start, time, data2,
n_particles = n_particles, seed = seed2)
s2 <- dust_filter_rng_state(obj2)
res2 <- replicate(20, dust_filter_run(obj2))
res2 <- replicate(20, dust_filter_run(obj2, save_history = TRUE))
expect_equal(res2, res[2, ])
expect_equal(s2, s[3233:6464])

h <- dust_filter_last_history(obj)
expect_equal(dim(h), c(5, 100, 2, 4))
expect_equal(h[, , 1, ], dust_filter_last_history(obj1))
expect_equal(h[, , 2, ], dust_filter_last_history(obj2))
})


Expand Down

0 comments on commit 62dc692

Please sign in to comment.