Skip to content

Commit

Permalink
Merge pull request #110 from mrc-ide/mrc-5906
Browse files Browse the repository at this point in the history
Improve coverage
  • Loading branch information
weshinsley authored Oct 28, 2024
2 parents e5aed27 + a4a24c6 commit fc83953
Show file tree
Hide file tree
Showing 16 changed files with 268 additions and 55 deletions.
8 changes: 8 additions & 0 deletions .covrignore
Original file line number Diff line number Diff line change
@@ -1 +1,9 @@
R/import-*.R
R/cpp11.R
R/dust.R
src/cpp11.cpp
src/logistic.cpp
src/malaria.cpp
src/sir.cpp
src/sirode.cpp
src/walk.cpp
11 changes: 10 additions & 1 deletion R/browser.R
Original file line number Diff line number Diff line change
Expand Up @@ -105,12 +105,21 @@ browser_env <- function(env, phase, time) {
parent <- browser_find_parent_env(sys.frames())
if (is.null(parent) || !isTRUE(parent$.dust_browser_continue)) {
browser_welcome_message(env, phase, time)
with(env, browser())
browse_env(env)
}
}
}


browse_env <- function(env) {
## This can't be mocked out due to how R handles browser, and we
## might need something a bit better before final submission to
## CRAN.
with(env, browser()) # nocov
}



browser_find_parent_env <- function(frames, drop = 1) {
## Finds the outermost call to dust2, dropping the last frame(s),
## which we expect to be dust calls.
Expand Down
10 changes: 3 additions & 7 deletions R/interface-likelihood.R
Original file line number Diff line number Diff line change
Expand Up @@ -292,12 +292,8 @@ print.dust_likelihood <- function(x, ...) {
## TODO:
## * has gradient
## Link to docs
if (attr(x$generator, "properties")$time_type == "discrete") {
cli::cli_bullets(c(
i = "The system runs in discrete time with dt = {x$time_control$dt}"))
} else {
cli::cli_bullets(c(
i = "The system runs in continuous time"))
}
time_type <- attr(x$generator, "properties")$time_type
cli::cli_bullets(
c(i = describe_time(time_type, NULL, x$time_control$dt)))
invisible(x)
}
63 changes: 31 additions & 32 deletions R/interface.R
Original file line number Diff line number Diff line change
Expand Up @@ -533,22 +533,8 @@ print.dust_system <- function(x, ...) {
cli::cli_bullets(c(
i = "This system has 'adjoint' support, and can compute gradients"))
}
if (x$properties$time_type == "discrete") {
cli::cli_bullets(c(
i = "This system runs in discrete time with dt = {x$time_control$dt}"))
} else if (x$properties$time_type == "mixed") {
if (is.null(x$time_control$dt)) {
cli::cli_bullets(c(
i = "This system runs in continuous time (discrete time disabled)"))
} else {
cli::cli_bullets(c(
i = paste("This system runs in both continuous time and",
"discrete time with dt = {x$time_control$dt}")))
}
} else {
cli::cli_bullets(c(
i = "This system runs in continuous time"))
}
cli::cli_bullets(
c(i = describe_time(x$properties$time_type, NULL, x$time_control$dt)))
invisible(x)
}

Expand All @@ -565,25 +551,38 @@ print.dust_system_generator <- function(x, ...) {
cli::cli_bullets(c(
i = "This system has 'compare_data' support"))
}
if (properties$time_type == "discrete") {
cli::cli_bullets(c(
i = paste("This system runs in discrete time",
"with a default dt of {default_dt}")))
} else if (properties$time_type == "mixed") {
if (is.null(default_dt)) {
cli::cli_bullets(c(
i = paste("This system runs in both continuous time",
"and discrete time with discrete time disabled by default")))
cli::cli_bullets(
c(i = describe_time(properties$time_type, default_dt)))
invisible(x)
}


describe_time <- function(time_type, default_dt, dt) {
if (time_type == "continuous") {
"This system runs in continuous time"
} else if (time_type == "discrete") {
prefix <- "This system runs in discrete time"
if (rlang::is_missing(dt)) {
cli::format_inline("{prefix} with a default dt of {default_dt}")
} else {
cli::cli_bullets(c(
i = paste("This system runs in both continuous time",
"and discrete time with a default dt of {default_dt}")))
cli::format_inline("{prefix} with dt = {dt}")
}
} else if (time_type == "mixed") {
prefix <- "This system runs in both continuous time and discrete time"
if (rlang::is_missing(dt)) {
if (is.null(default_dt)) {
cli::format_inline("{prefix} with discrete time disabled by default")
} else {
cli::format_inline("{prefix} with a default dt of {default_dt}")
}
} else {
if (is.null(dt)) {
cli::format_inline("{prefix} with discrete time disabled")
} else {
cli::format_inline("{prefix} with dt = {dt}")
}
}
} else {
cli::cli_bullets(c(
i = "This system runs in continuous time"))
}
invisible(x)
}


Expand Down
2 changes: 1 addition & 1 deletion R/metadata.R
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ parse_metadata_default_dt <- function(data, time_type, call = NULL) {
"Can't use '[[dust::default_dt()]]' with continuous-time systems",
call = call)
}
if (length(data) != 1 || nzchar(names(data))) {
if (length(data) != 1 || nzchar(rlang::names2(data))) {
cli::cli_abort(
"Expected a single unnamed argument to '[[dust2::default_dt()]]'",
call = call)
Expand Down
2 changes: 1 addition & 1 deletion inst/examples/malaria.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ class malaria {
if (std::isnan(data.positive) || std::isnan(data.tested)) {
return 0;
}
const real_type Ih = state[1]; // Ih
const real_type Ih = state[1];
return monty::density::binomial(data.positive, data.tested, Ih, true);
}
};
14 changes: 4 additions & 10 deletions inst/include/dust2/trajectories.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,16 +142,10 @@ class trajectories {
const auto iter_state = state_.begin() + i * len_state_;
for (size_t j = 0; j < n_groups_; ++j) {
const auto offset_state = j * n_state_ * n_particles_;
if (use_select_particle) {
const auto offset_j = select_particle[j] * n_state_;
iter = std::copy_n(iter_state + offset_state + offset_j,
n_state_,
iter);
} else {
iter = std::copy_n(iter_state + offset_state,
n_state_ * n_particles_,
iter);
}
const auto offset_j = select_particle[j] * n_state_;
iter = std::copy_n(iter_state + offset_state + offset_j,
n_state_,
iter);
}
}
}
Expand Down
15 changes: 12 additions & 3 deletions tests/testthat/helper-dust.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
## system; once we have a generic system interface we can make this more
## generic. It just redoes the same logic as in the C++ code but is
## easier to read (and quite a lot slower due to churn in state).
sir_filter_manual <- function(pars, time_start, data, dt, n_particles,
seed) {
filter_manual <- function(generator, pars, time_start, data, dt, n_particles,
seed) {
r <- monty::monty_rng$new(n_streams = 1, seed = seed)
seed <- monty::monty_rng$new(n_streams = 1, seed = seed)$jump()$state()

obj <- dust_system_create(sir(), pars, n_particles,
obj <- dust_system_create(generator, pars, n_particles,
time = time_start, dt = dt, seed = seed)
n_state <- nrow(dust_system_state(obj))
n_time <- nrow(data)
Expand Down Expand Up @@ -44,6 +44,15 @@ sir_filter_manual <- function(pars, time_start, data, dt, n_particles,
}
}

sir_filter_manual <- function(...) {
filter_manual(sir, ...)
}


malaria_filter_manual <- function(...) {
filter_manual(malaria, ...)
}


skip_for_compilation <- function() {
skip_on_cran()
Expand Down
21 changes: 21 additions & 0 deletions tests/testthat/test-browser.R
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,24 @@ test_that("set browser sentinal if requested", {
mockery::expect_called(mock_find_env, 1)
expect_true(e$.dust_browser_continue)
})


test_that("can browse environment", {
skip_if_not_installed("mockery")
mock_browse_env <- mockery::mock()
mockery::stub(browser_env, "browse_env", mock_browse_env)
env <- new.env()
env$a <- 1
withr::with_options(list(dust.browser_enabled = FALSE),
browser_env(env, "phase", 1))
mockery::expect_called(mock_browse_env, 0)

withr::with_options(
list(dust.browser_enabled = NULL),
expect_message(browser_env(env, "phase", 1),
"dust debug ('phase'; time = 1):",
fixed = TRUE))

mockery::expect_called(mock_browse_env, 1)
expect_equal(mockery::mock_args(mock_browse_env)[[1]], list(env))
})
38 changes: 38 additions & 0 deletions tests/testthat/test-filter.R
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,21 @@ test_that("can only use pars = NULL on initialised filter", {
})


test_that("can only use index_group on initialised filter", {
pars <- list(
list(beta = 0.1, gamma = 0.2, N = 1000, I0 = 10, exp_noise = 1e6),
list(beta = 0.2, gamma = 0.2, N = 1000, I0 = 10, exp_noise = 1e6))
time_start <- 0
data <- data.frame(time = rep(c(4, 8, 12, 16), 2),
group = rep(1:2, each = 4),
incidence = c(1:4, 2:5))
obj <- dust_filter_create(sir(), time_start, data,
n_particles = 10, n_groups = 2)
expect_error(dust_likelihood_run(obj, pars[1], index_group = 1),
"'index_group' must be NULL, as 'obj' is not initialised")
})


test_that("can run particle filter and save trajectories", {
pars <- list(beta = 0.1, gamma = 0.2, N = 1000, I0 = 10, exp_noise = 1e6)

Expand All @@ -49,6 +64,8 @@ test_that("can run particle filter and save trajectories", {
n_particles = n_particles, seed = seed)
expect_error(dust_likelihood_last_trajectories(obj),
"Trajectories are not current")
expect_error(dust_likelihood_last_state(obj),
"State is not current")
res1 <- dust_likelihood_run(obj, pars)
expect_error(dust_likelihood_last_trajectories(obj),
"Trajectories are not current")
Expand Down Expand Up @@ -648,3 +665,24 @@ test_that("can run continuous-time filter", {
"Can't use 'dust_filter_create()' with continuous-time models",
fixed = TRUE)
})


test_that("can run filter on mixed time models", {
pars <- list()

time_start <- 0
data <- data.frame(time = c(4, 8, 12, 16),
tested = c(2, 4, 6, 8),
positive = c(1, 2, 3, 3))
dt <- 1
n_particles <- 100
seed <- 42

obj <- dust_filter_create(malaria(), time_start, data, dt = dt,
n_particles = n_particles, seed = seed)
res <- replicate(5, dust_likelihood_run(obj, pars))

cmp_filter <- filter_manual(
malaria, pars, time_start, data, dt, n_particles, seed)
expect_equal(res, replicate(5, cmp_filter(NULL)$log_likelihood))
})
33 changes: 33 additions & 0 deletions tests/testthat/test-interface.R
Original file line number Diff line number Diff line change
Expand Up @@ -184,3 +184,36 @@ test_that("format dimensions as a string", {
n_groups = 1)),
"single particle")
})


test_that("can describe time", {
expect_equal(
describe_time("continuous", NULL),
"This system runs in continuous time")
expect_equal(
describe_time("discrete", 1),
"This system runs in discrete time with a default dt of 1")
expect_equal(
describe_time("mixed", NULL),
paste("This system runs in both continuous time and discrete time",
"with discrete time disabled by default"))
expect_equal(
describe_time("mixed", 1),
paste("This system runs in both continuous time and discrete time",
"with a default dt of 1"))

expect_equal(
describe_time("continuous", NULL, NULL),
"This system runs in continuous time")
expect_equal(
describe_time("discrete", 1, 0.5),
"This system runs in discrete time with dt = 0.5")
expect_equal(
describe_time("mixed", NULL, NULL),
paste("This system runs in both continuous time and discrete time",
"with discrete time disabled"))
expect_equal(
describe_time("mixed", NULL, 0.5),
paste("This system runs in both continuous time and discrete time",
"with dt = 0.5"))
})
23 changes: 23 additions & 0 deletions tests/testthat/test-malaria.R
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,26 @@ test_that("can apply stochastic updates by setting dt", {
expect_equal(y0[-i_beta, ], y[-i_beta, , i + 1])
}
})


test_that("can print a mixed generator", {
res <- evaluate_promise(withVisible(print(malaria)))
expect_mapequal(res$result, list(value = malaria, visible = FALSE))
expect_match(res$messages, "<dust_system_generator: malaria>",
fixed = TRUE, all = FALSE)
expect_match(
res$messages,
"This system runs in both continuous time and discrete time",
all = FALSE)
})


test_that("Can compare to data", {
sys <- dust_system_create(malaria(), list(), n_particles = 10, dt = 1)
dust_system_set_state_initial(sys)
dust_system_run_to_time(sys, 10)
s <- dust_unpack_state(sys, dust_system_state(sys))
d <- list(tested = 4, positive = 2)
expect_equal(dust_system_compare_data(sys, d),
dbinom(d$positive, d$tested, s$Ih, log = TRUE))
})
32 changes: 32 additions & 0 deletions tests/testthat/test-metadata.R
Original file line number Diff line number Diff line change
Expand Up @@ -155,3 +155,35 @@ test_that("can validate default dt in discrete time models", {
"Expected '[[dust2::default_dt()]]' to be the inverse of an integer",
fixed = TRUE)
})


test_that("validate default dt", {
d <- data.frame(decoration = "dust2::default_dt",
params = I(list(alist(0.5))))

expect_equal(parse_metadata_default_dt(d[-1, ], "discrete"), 1)
expect_equal(parse_metadata_default_dt(d, "discrete"), 0.5)

expect_error(
parse_metadata_default_dt(d, "continuous"),
"Can't use '[[dust::default_dt()]]' with continuous-time systems",
fixed = TRUE)

d$params[[1]] <- list(quote(a))
expect_error(
parse_metadata_default_dt(d, "discrete"),
"Expected a numerical argument to '[[dust2::default_dt()]]'",
fixed = TRUE)

d$params[[1]] <- list(dt = 1)
expect_error(
parse_metadata_default_dt(d, "discrete"),
"Expected a single unnamed argument to '[[dust2::default_dt()]]'",
fixed = TRUE)

d$params[[1]] <- list(0.37)
expect_error(
parse_metadata_default_dt(d, "discrete"),
"Expected '[[dust2::default_dt()]]' to be the inverse of an integer",
fixed = TRUE)
})
Loading

0 comments on commit fc83953

Please sign in to comment.