Skip to content

Commit

Permalink
Merge pull request #84 from mrc-ide/mrc-5821
Browse files Browse the repository at this point in the history
Save packer information into likelihood (filter and unfilter)
  • Loading branch information
weshinsley authored Sep 30, 2024
2 parents 92f9a36 + 420c892 commit 76f45d6
Show file tree
Hide file tree
Showing 9 changed files with 48 additions and 7 deletions.
1 change: 1 addition & 0 deletions R/interface-filter.R
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ filter_create <- function(obj, pars) {
obj$initial_rng_state),
obj)
obj$initial_rng_state <- NULL
obj$packer_state <- monty::monty_packer(array = obj$packing_state)
}


Expand Down
5 changes: 5 additions & 0 deletions R/interface-unfilter.R
Original file line number Diff line number Diff line change
Expand Up @@ -82,4 +82,9 @@ unfilter_create <- function(unfilter, pars) {
inputs$n_threads,
inputs$index_state),
unfilter)
unfilter$packer_state <- monty::monty_packer(array = unfilter$packing_state)
if (!is.null(unfilter$packing_gradient)) {
unfilter$packer_gradient <-
monty::monty_packer(array = unfilter$packing_gradient)
}
}
14 changes: 11 additions & 3 deletions R/tools.R
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,16 @@ dust_unpack_index <- function(obj) {


get_unpacker <- function(obj, call = parent.frame()) {
## Once mrc-5806 is merged, we can add this into the filter/unfilter
## too and do the same basic idea.
assert_is(obj, "dust_system", call = call)
if (inherits(obj, "dust_likelihood")) {
if (is.null(obj$packer_state)) {
cli::cli_abort(
c("Packer is not yet ready",
i = "Likelihood has not yet been run"))
}
} else if (!inherits(obj, "dust_system")) {
cli::cli_abort(
"Expected 'obj' to be a 'dust_system' or a 'dust_likelihood'",
arg = "obj", call = call)
}
obj$packer_state
}
3 changes: 2 additions & 1 deletion inst/include/dust2/r/continuous/filter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,10 @@ cpp11::sexp dust2_continuous_filter_alloc(cpp11::list r_pars,
cpp11::external_pointer<filter<dust_continuous<T>>> ptr(obj, true, false);

cpp11::sexp r_n_state = cpp11::as_sexp(obj->sys.n_state());
cpp11::sexp r_packing_state = packing_to_r(obj->sys.packing_state());

using namespace cpp11::literals;
return cpp11::writable::list{"ptr"_nm = ptr, "n_state"_nm = r_n_state};
return cpp11::writable::list{"ptr"_nm = ptr, "n_state"_nm = r_n_state, "packing_state"_nm = r_packing_state};
}

}
Expand Down
4 changes: 3 additions & 1 deletion inst/include/dust2/r/continuous/unfilter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,11 @@ cpp11::sexp dust2_continuous_unfilter_alloc(cpp11::list r_pars,
cpp11::external_pointer<unfilter<dust_continuous<T>>> ptr(obj, true, false);

cpp11::sexp r_n_state = cpp11::as_sexp(obj->sys.n_state());
cpp11::sexp r_packing_state = packing_to_r(obj->sys.packing_state());
cpp11::sexp r_packing_gradient = packing_to_r(obj->sys.packing_gradient());

using namespace cpp11::literals;
return cpp11::writable::list{"ptr"_nm = ptr, "n_state"_nm = r_n_state};
return cpp11::writable::list{"ptr"_nm = ptr, "n_state"_nm = r_n_state, "packing_state"_nm = r_packing_state, "packing_gradient"_nm = r_packing_gradient};
}

}
Expand Down
3 changes: 2 additions & 1 deletion inst/include/dust2/r/discrete/filter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,10 @@ cpp11::sexp dust2_discrete_filter_alloc(cpp11::list r_pars,
cpp11::external_pointer<filter<dust_discrete<T>>> ptr(obj, true, false);

cpp11::sexp r_n_state = cpp11::as_sexp(obj->sys.n_state());
cpp11::sexp r_packing_state = packing_to_r(obj->sys.packing_state());

using namespace cpp11::literals;
return cpp11::writable::list{"ptr"_nm = ptr, "n_state"_nm = r_n_state};
return cpp11::writable::list{"ptr"_nm = ptr, "n_state"_nm = r_n_state, "packing_state"_nm = r_packing_state};
}

}
Expand Down
4 changes: 3 additions & 1 deletion inst/include/dust2/r/discrete/unfilter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,11 @@ cpp11::sexp dust2_discrete_unfilter_alloc(cpp11::list r_pars,
cpp11::external_pointer<unfilter<dust_discrete<T>>> ptr(obj, true, false);

cpp11::sexp r_n_state = cpp11::as_sexp(obj->sys.n_state());
cpp11::sexp r_packing_state = packing_to_r(obj->sys.packing_state());
cpp11::sexp r_packing_gradient = packing_to_r(obj->sys.packing_gradient());

using namespace cpp11::literals;
return cpp11::writable::list{"ptr"_nm = ptr, "n_state"_nm = r_n_state};
return cpp11::writable::list{"ptr"_nm = ptr, "n_state"_nm = r_n_state, "packing_state"_nm = r_packing_state, "packing_gradient"_nm = r_packing_gradient};
}

}
Expand Down
3 changes: 3 additions & 0 deletions tests/testthat/test-filter.R
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,9 @@ test_that("can extract final state from a filter", {
expect_equal(dim(s), c(5, 10))
expect_equal(s, h[, , 4])

expect_equal(names(dust_unpack_state(obj, s)),
c("S", "I", "R", "cases_cumul", "cases_inc"))

dust_likelihood_run(obj, pars, save_history = FALSE)
expect_error(dust_likelihood_last_history(obj), "History is not current")
expect_no_error(dust_likelihood_last_state(obj))
Expand Down
18 changes: 18 additions & 0 deletions tests/testthat/test-tools.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,21 @@ test_that("can unpack state from systems with several particles", {
expect_equal(s2, sys$packer_state$unpack(s))
expect_equal(lengths(s2, FALSE), rep(10, 5))
})


test_that("can't get unpacker from filter before running", {
pars <- list(beta = 0.1, gamma = 0.2, N = 1000, I0 = 10, exp_noise = 1e6)
time_start <- 0
data <- data.frame(time = c(4, 8, 12, 16), incidence = 1:4)
obj <- dust_filter_create(sir(), time_start, data, n_particles = 10)
expect_error(
dust_unpack_state(obj, numeric()),
"Packer is not yet ready")
})


test_that("can't get unpacker from unknown object", {
expect_error(
dust_unpack_state(NULL, numeric()),
"Expected 'obj' to be a 'dust_system' or a 'dust_likelihood'")
})

0 comments on commit 76f45d6

Please sign in to comment.