From a4a24c6d45fff63e992ad0dfa8472b70e2e931fb Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Mon, 28 Oct 2024 07:25:05 +0000 Subject: [PATCH] Improve coverage --- .covrignore | 8 ++++ R/browser.R | 11 ++++- R/interface-likelihood.R | 10 ++--- R/interface.R | 63 ++++++++++++++--------------- R/metadata.R | 2 +- inst/examples/malaria.cpp | 2 +- inst/include/dust2/trajectories.hpp | 14 ++----- tests/testthat/helper-dust.R | 15 +++++-- tests/testthat/test-browser.R | 21 ++++++++++ tests/testthat/test-filter.R | 38 +++++++++++++++++ tests/testthat/test-interface.R | 33 +++++++++++++++ tests/testthat/test-malaria.R | 23 +++++++++++ tests/testthat/test-metadata.R | 32 +++++++++++++++ tests/testthat/test-monty.R | 36 +++++++++++++++++ tests/testthat/test-util.R | 9 +++++ tests/testthat/test-walk.R | 6 +++ 16 files changed, 268 insertions(+), 55 deletions(-) diff --git a/.covrignore b/.covrignore index d90e23c9..635cf5df 100644 --- a/.covrignore +++ b/.covrignore @@ -1 +1,9 @@ R/import-*.R +R/cpp11.R +R/dust.R +src/cpp11.cpp +src/logistic.cpp +src/malaria.cpp +src/sir.cpp +src/sirode.cpp +src/walk.cpp diff --git a/R/browser.R b/R/browser.R index a9152c13..4784a4ac 100644 --- a/R/browser.R +++ b/R/browser.R @@ -105,12 +105,21 @@ browser_env <- function(env, phase, time) { parent <- browser_find_parent_env(sys.frames()) if (is.null(parent) || !isTRUE(parent$.dust_browser_continue)) { browser_welcome_message(env, phase, time) - with(env, browser()) + browse_env(env) } } } +browse_env <- function(env) { + ## This can't be mocked out due to how R handles browser, and we + ## might need something a bit better before final submission to + ## CRAN. + with(env, browser()) # nocov +} + + + browser_find_parent_env <- function(frames, drop = 1) { ## Finds the outermost call to dust2, dropping the last frame(s), ## which we expect to be dust calls. diff --git a/R/interface-likelihood.R b/R/interface-likelihood.R index 06017b77..e3395bee 100644 --- a/R/interface-likelihood.R +++ b/R/interface-likelihood.R @@ -292,12 +292,8 @@ print.dust_likelihood <- function(x, ...) { ## TODO: ## * has gradient ## Link to docs - if (attr(x$generator, "properties")$time_type == "discrete") { - cli::cli_bullets(c( - i = "The system runs in discrete time with dt = {x$time_control$dt}")) - } else { - cli::cli_bullets(c( - i = "The system runs in continuous time")) - } + time_type <- attr(x$generator, "properties")$time_type + cli::cli_bullets( + c(i = describe_time(time_type, NULL, x$time_control$dt))) invisible(x) } diff --git a/R/interface.R b/R/interface.R index 97609b52..e413412d 100644 --- a/R/interface.R +++ b/R/interface.R @@ -533,22 +533,8 @@ print.dust_system <- function(x, ...) { cli::cli_bullets(c( i = "This system has 'adjoint' support, and can compute gradients")) } - if (x$properties$time_type == "discrete") { - cli::cli_bullets(c( - i = "This system runs in discrete time with dt = {x$time_control$dt}")) - } else if (x$properties$time_type == "mixed") { - if (is.null(x$time_control$dt)) { - cli::cli_bullets(c( - i = "This system runs in continuous time (discrete time disabled)")) - } else { - cli::cli_bullets(c( - i = paste("This system runs in both continuous time and", - "discrete time with dt = {x$time_control$dt}"))) - } - } else { - cli::cli_bullets(c( - i = "This system runs in continuous time")) - } + cli::cli_bullets( + c(i = describe_time(x$properties$time_type, NULL, x$time_control$dt))) invisible(x) } @@ -565,25 +551,38 @@ print.dust_system_generator <- function(x, ...) { cli::cli_bullets(c( i = "This system has 'compare_data' support")) } - if (properties$time_type == "discrete") { - cli::cli_bullets(c( - i = paste("This system runs in discrete time", - "with a default dt of {default_dt}"))) - } else if (properties$time_type == "mixed") { - if (is.null(default_dt)) { - cli::cli_bullets(c( - i = paste("This system runs in both continuous time", - "and discrete time with discrete time disabled by default"))) + cli::cli_bullets( + c(i = describe_time(properties$time_type, default_dt))) + invisible(x) +} + + +describe_time <- function(time_type, default_dt, dt) { + if (time_type == "continuous") { + "This system runs in continuous time" + } else if (time_type == "discrete") { + prefix <- "This system runs in discrete time" + if (rlang::is_missing(dt)) { + cli::format_inline("{prefix} with a default dt of {default_dt}") } else { - cli::cli_bullets(c( - i = paste("This system runs in both continuous time", - "and discrete time with a default dt of {default_dt}"))) + cli::format_inline("{prefix} with dt = {dt}") + } + } else if (time_type == "mixed") { + prefix <- "This system runs in both continuous time and discrete time" + if (rlang::is_missing(dt)) { + if (is.null(default_dt)) { + cli::format_inline("{prefix} with discrete time disabled by default") + } else { + cli::format_inline("{prefix} with a default dt of {default_dt}") + } + } else { + if (is.null(dt)) { + cli::format_inline("{prefix} with discrete time disabled") + } else { + cli::format_inline("{prefix} with dt = {dt}") + } } - } else { - cli::cli_bullets(c( - i = "This system runs in continuous time")) } - invisible(x) } diff --git a/R/metadata.R b/R/metadata.R index f8b439a6..f9a934f8 100644 --- a/R/metadata.R +++ b/R/metadata.R @@ -109,7 +109,7 @@ parse_metadata_default_dt <- function(data, time_type, call = NULL) { "Can't use '[[dust::default_dt()]]' with continuous-time systems", call = call) } - if (length(data) != 1 || nzchar(names(data))) { + if (length(data) != 1 || nzchar(rlang::names2(data))) { cli::cli_abort( "Expected a single unnamed argument to '[[dust2::default_dt()]]'", call = call) diff --git a/inst/examples/malaria.cpp b/inst/examples/malaria.cpp index 5cc5006c..e1804fb9 100644 --- a/inst/examples/malaria.cpp +++ b/inst/examples/malaria.cpp @@ -131,7 +131,7 @@ class malaria { if (std::isnan(data.positive) || std::isnan(data.tested)) { return 0; } - const real_type Ih = state[1]; // Ih + const real_type Ih = state[1]; return monty::density::binomial(data.positive, data.tested, Ih, true); } }; diff --git a/inst/include/dust2/trajectories.hpp b/inst/include/dust2/trajectories.hpp index 89371f8d..9cecc0df 100644 --- a/inst/include/dust2/trajectories.hpp +++ b/inst/include/dust2/trajectories.hpp @@ -142,16 +142,10 @@ class trajectories { const auto iter_state = state_.begin() + i * len_state_; for (size_t j = 0; j < n_groups_; ++j) { const auto offset_state = j * n_state_ * n_particles_; - if (use_select_particle) { - const auto offset_j = select_particle[j] * n_state_; - iter = std::copy_n(iter_state + offset_state + offset_j, - n_state_, - iter); - } else { - iter = std::copy_n(iter_state + offset_state, - n_state_ * n_particles_, - iter); - } + const auto offset_j = select_particle[j] * n_state_; + iter = std::copy_n(iter_state + offset_state + offset_j, + n_state_, + iter); } } } diff --git a/tests/testthat/helper-dust.R b/tests/testthat/helper-dust.R index 79a2d389..7e53d51b 100644 --- a/tests/testthat/helper-dust.R +++ b/tests/testthat/helper-dust.R @@ -2,12 +2,12 @@ ## system; once we have a generic system interface we can make this more ## generic. It just redoes the same logic as in the C++ code but is ## easier to read (and quite a lot slower due to churn in state). -sir_filter_manual <- function(pars, time_start, data, dt, n_particles, - seed) { +filter_manual <- function(generator, pars, time_start, data, dt, n_particles, + seed) { r <- monty::monty_rng$new(n_streams = 1, seed = seed) seed <- monty::monty_rng$new(n_streams = 1, seed = seed)$jump()$state() - obj <- dust_system_create(sir(), pars, n_particles, + obj <- dust_system_create(generator, pars, n_particles, time = time_start, dt = dt, seed = seed) n_state <- nrow(dust_system_state(obj)) n_time <- nrow(data) @@ -44,6 +44,15 @@ sir_filter_manual <- function(pars, time_start, data, dt, n_particles, } } +sir_filter_manual <- function(...) { + filter_manual(sir, ...) +} + + +malaria_filter_manual <- function(...) { + filter_manual(malaria, ...) +} + skip_for_compilation <- function() { skip_on_cran() diff --git a/tests/testthat/test-browser.R b/tests/testthat/test-browser.R index 5f262ef3..59a2c1d8 100644 --- a/tests/testthat/test-browser.R +++ b/tests/testthat/test-browser.R @@ -72,3 +72,24 @@ test_that("set browser sentinal if requested", { mockery::expect_called(mock_find_env, 1) expect_true(e$.dust_browser_continue) }) + + +test_that("can browse environment", { + skip_if_not_installed("mockery") + mock_browse_env <- mockery::mock() + mockery::stub(browser_env, "browse_env", mock_browse_env) + env <- new.env() + env$a <- 1 + withr::with_options(list(dust.browser_enabled = FALSE), + browser_env(env, "phase", 1)) + mockery::expect_called(mock_browse_env, 0) + + withr::with_options( + list(dust.browser_enabled = NULL), + expect_message(browser_env(env, "phase", 1), + "dust debug ('phase'; time = 1):", + fixed = TRUE)) + + mockery::expect_called(mock_browse_env, 1) + expect_equal(mockery::mock_args(mock_browse_env)[[1]], list(env)) +}) diff --git a/tests/testthat/test-filter.R b/tests/testthat/test-filter.R index 1d582ca4..1aeb4a75 100644 --- a/tests/testthat/test-filter.R +++ b/tests/testthat/test-filter.R @@ -36,6 +36,21 @@ test_that("can only use pars = NULL on initialised filter", { }) +test_that("can only use index_group on initialised filter", { + pars <- list( + list(beta = 0.1, gamma = 0.2, N = 1000, I0 = 10, exp_noise = 1e6), + list(beta = 0.2, gamma = 0.2, N = 1000, I0 = 10, exp_noise = 1e6)) + time_start <- 0 + data <- data.frame(time = rep(c(4, 8, 12, 16), 2), + group = rep(1:2, each = 4), + incidence = c(1:4, 2:5)) + obj <- dust_filter_create(sir(), time_start, data, + n_particles = 10, n_groups = 2) + expect_error(dust_likelihood_run(obj, pars[1], index_group = 1), + "'index_group' must be NULL, as 'obj' is not initialised") +}) + + test_that("can run particle filter and save trajectories", { pars <- list(beta = 0.1, gamma = 0.2, N = 1000, I0 = 10, exp_noise = 1e6) @@ -49,6 +64,8 @@ test_that("can run particle filter and save trajectories", { n_particles = n_particles, seed = seed) expect_error(dust_likelihood_last_trajectories(obj), "Trajectories are not current") + expect_error(dust_likelihood_last_state(obj), + "State is not current") res1 <- dust_likelihood_run(obj, pars) expect_error(dust_likelihood_last_trajectories(obj), "Trajectories are not current") @@ -648,3 +665,24 @@ test_that("can run continuous-time filter", { "Can't use 'dust_filter_create()' with continuous-time models", fixed = TRUE) }) + + +test_that("can run filter on mixed time models", { + pars <- list() + + time_start <- 0 + data <- data.frame(time = c(4, 8, 12, 16), + tested = c(2, 4, 6, 8), + positive = c(1, 2, 3, 3)) + dt <- 1 + n_particles <- 100 + seed <- 42 + + obj <- dust_filter_create(malaria(), time_start, data, dt = dt, + n_particles = n_particles, seed = seed) + res <- replicate(5, dust_likelihood_run(obj, pars)) + + cmp_filter <- filter_manual( + malaria, pars, time_start, data, dt, n_particles, seed) + expect_equal(res, replicate(5, cmp_filter(NULL)$log_likelihood)) +}) diff --git a/tests/testthat/test-interface.R b/tests/testthat/test-interface.R index ef1a43d0..36372b4a 100644 --- a/tests/testthat/test-interface.R +++ b/tests/testthat/test-interface.R @@ -184,3 +184,36 @@ test_that("format dimensions as a string", { n_groups = 1)), "single particle") }) + + +test_that("can describe time", { + expect_equal( + describe_time("continuous", NULL), + "This system runs in continuous time") + expect_equal( + describe_time("discrete", 1), + "This system runs in discrete time with a default dt of 1") + expect_equal( + describe_time("mixed", NULL), + paste("This system runs in both continuous time and discrete time", + "with discrete time disabled by default")) + expect_equal( + describe_time("mixed", 1), + paste("This system runs in both continuous time and discrete time", + "with a default dt of 1")) + + expect_equal( + describe_time("continuous", NULL, NULL), + "This system runs in continuous time") + expect_equal( + describe_time("discrete", 1, 0.5), + "This system runs in discrete time with dt = 0.5") + expect_equal( + describe_time("mixed", NULL, NULL), + paste("This system runs in both continuous time and discrete time", + "with discrete time disabled")) + expect_equal( + describe_time("mixed", NULL, 0.5), + paste("This system runs in both continuous time and discrete time", + "with dt = 0.5")) +}) diff --git a/tests/testthat/test-malaria.R b/tests/testthat/test-malaria.R index 6afdf869..05444d12 100644 --- a/tests/testthat/test-malaria.R +++ b/tests/testthat/test-malaria.R @@ -26,3 +26,26 @@ test_that("can apply stochastic updates by setting dt", { expect_equal(y0[-i_beta, ], y[-i_beta, , i + 1]) } }) + + +test_that("can print a mixed generator", { + res <- evaluate_promise(withVisible(print(malaria))) + expect_mapequal(res$result, list(value = malaria, visible = FALSE)) + expect_match(res$messages, "", + fixed = TRUE, all = FALSE) + expect_match( + res$messages, + "This system runs in both continuous time and discrete time", + all = FALSE) +}) + + +test_that("Can compare to data", { + sys <- dust_system_create(malaria(), list(), n_particles = 10, dt = 1) + dust_system_set_state_initial(sys) + dust_system_run_to_time(sys, 10) + s <- dust_unpack_state(sys, dust_system_state(sys)) + d <- list(tested = 4, positive = 2) + expect_equal(dust_system_compare_data(sys, d), + dbinom(d$positive, d$tested, s$Ih, log = TRUE)) +}) diff --git a/tests/testthat/test-metadata.R b/tests/testthat/test-metadata.R index f08b8c51..518ba2d0 100644 --- a/tests/testthat/test-metadata.R +++ b/tests/testthat/test-metadata.R @@ -155,3 +155,35 @@ test_that("can validate default dt in discrete time models", { "Expected '[[dust2::default_dt()]]' to be the inverse of an integer", fixed = TRUE) }) + + +test_that("validate default dt", { + d <- data.frame(decoration = "dust2::default_dt", + params = I(list(alist(0.5)))) + + expect_equal(parse_metadata_default_dt(d[-1, ], "discrete"), 1) + expect_equal(parse_metadata_default_dt(d, "discrete"), 0.5) + + expect_error( + parse_metadata_default_dt(d, "continuous"), + "Can't use '[[dust::default_dt()]]' with continuous-time systems", + fixed = TRUE) + + d$params[[1]] <- list(quote(a)) + expect_error( + parse_metadata_default_dt(d, "discrete"), + "Expected a numerical argument to '[[dust2::default_dt()]]'", + fixed = TRUE) + + d$params[[1]] <- list(dt = 1) + expect_error( + parse_metadata_default_dt(d, "discrete"), + "Expected a single unnamed argument to '[[dust2::default_dt()]]'", + fixed = TRUE) + + d$params[[1]] <- list(0.37) + expect_error( + parse_metadata_default_dt(d, "discrete"), + "Expected '[[dust2::default_dt()]]' to be the inverse of an integer", + fixed = TRUE) +}) diff --git a/tests/testthat/test-monty.R b/tests/testthat/test-monty.R index b54ef3d8..b8aa1826 100644 --- a/tests/testthat/test-monty.R +++ b/tests/testthat/test-monty.R @@ -175,6 +175,42 @@ test_that("can subset trajectories from model", { }) +test_that("can subset trajectories from deterministic model", { + pars <- list(beta = 0.1, gamma = 0.2, N = 1000, I0 = 10, exp_noise = 1e6) + time_start <- 0 + data <- data.frame(time = c(4, 8, 12, 16), incidence = 1:4) + obj <- dust_unfilter_create(sir(), time_start, data) + packer <- monty::monty_packer( + c("beta", "gamma"), + fixed = list(N = 1000, I0 = 10, exp_noise = 1e6)) + + prior <- monty::monty_dsl({ + beta ~ Exponential(mean = 0.5) + gamma ~ Exponential(mean = 0.5) + }) + + sampler <- monty::monty_sampler_random_walk(diag(2) * c(0.02, 0.02)) + + set.seed(1) + m1 <- dust_likelihood_monty(obj, packer, + save_trajectories = c("I", "cases_inc")) + p1 <- m1 + prior + res1 <- monty::monty_sample(p1, sampler, 13, initial = c(.2, .1), + n_chains = 3) + + set.seed(1) + m2 <- dust_likelihood_monty(obj, packer, save_trajectories = TRUE) + p2 <- m2 + prior + res2 <- monty::monty_sample(p2, sampler, 13, initial = c(.2, .1), + n_chains = 3) + + expect_equal(dim(res1$observations$trajectories), c(2, 4, 13, 3)) + expect_equal(res1$observations$trajectories, + res2$observations$trajectories[c(2, 5), , , ]) + expect_equal(names(res1$observations), "trajectories") +}) + + test_that("can record final state", { pars <- list(beta = 0.1, gamma = 0.2, N = 1000, I0 = 10, exp_noise = 1e6) time_start <- 0 diff --git a/tests/testthat/test-util.R b/tests/testthat/test-util.R index 46ff0911..0c0e210f 100644 --- a/tests/testthat/test-util.R +++ b/tests/testthat/test-util.R @@ -80,3 +80,12 @@ test_that("fmod works", { expect_equal(fmod(10, 0.1), 0) expect_equal(fmod(10, 0.01), 0) }) + + +test_that("describe ranks", { + expect_equal(rank_description(0), "scalar") + expect_equal(rank_description(1), "vector") + expect_equal(rank_description(2), "matrix") + expect_equal(rank_description(3), "3-dimensional array") + expect_equal(rank_description(300), "300-dimensional array") +}) diff --git a/tests/testthat/test-walk.R b/tests/testthat/test-walk.R index b98b17af..256d20c5 100644 --- a/tests/testthat/test-walk.R +++ b/tests/testthat/test-walk.R @@ -460,6 +460,12 @@ test_that("can validate times", { dust_system_simulate(obj, c(1, 2, 1)), "Values in 'times' must be increasing", fixed = TRUE) + err <- expect_error( + dust_system_simulate(obj, c(10:1)), + "Values in 'times' must be increasing", + fixed = TRUE) + expect_match(conditionMessage(err), + "...and 5 other errors", fixed = TRUE) expect_error( dust_system_simulate(obj, NULL), "Expected 'times' to be numeric")