Skip to content

Commit

Permalink
Add simulate to the interface
Browse files Browse the repository at this point in the history
  • Loading branch information
richfitz committed May 22, 2024
1 parent afbc083 commit f8d3268
Show file tree
Hide file tree
Showing 6 changed files with 122 additions and 45 deletions.
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ export(dust_model_run_to_time)
export(dust_model_set_state)
export(dust_model_set_state_initial)
export(dust_model_set_time)
export(dust_model_simulate)
export(dust_model_state)
export(dust_model_time)
export(dust_model_update_pars)
Expand Down
35 changes: 34 additions & 1 deletion R/interface.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ dust_model <- function(name, env = parent.env(parent.frame())) {
"time", "set_time",
"rng_state",
"update_pars",
"run_steps", "run_to_time",
"run_steps", "run_to_time", "simulate",
"reorder")
methods_compare <- "compare_data"
methods <- get_methods(c(methods_core, methods_compare), name)
Expand Down Expand Up @@ -261,6 +261,39 @@ dust_model_run_to_time <- function(model, time) {
}


##' Simulate a model over a series of times, returning an array of
##' output. This output can be quite large, so you may filter states
##' according to some index.
##'
##' @title Simulate model
##'
##' @inheritParams dust_model_state
##'
##' @param times A vector of times. They must be increasing, and the
##' first time must be no greater than the current model time
##' (as reported by [dust_model_time])
##'
##' @param index An optional index of states to extract. If given,
##' then we subset the model state on return. You can use this to
##' return fewer model states than the model ran with, to reorder
##' states, or to name them on exit (names present on the index will
##' be copied into the rownames of the returned array).
##'
##' @return An array with 3 dimensions (state x particle x time) or 4
##' dimensions (state x particle x group x time) for a grouped
##' model.
##'
##' @export
dust_model_simulate <- function(model, times, index = NULL) {
check_is_dust_model(model)
ret <- model$methods$simulate(model$ptr, times, index, model$grouped)
if (!is.null(index) && !is.null(names(index))) {
rownames(ret) <- names(index)
}
ret
}


##' Reorder states within a model. This function is primarily used
##' for debugging and may be removed from the interface if it is not
##' generally useful.
Expand Down
1 change: 1 addition & 0 deletions _pkgdown.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ reference:
contents:
- dust_model_run_steps
- dust_model_run_to_time
- dust_model_simulate
- subtitle: Other
contents:
- dust_model_reorder
Expand Down
31 changes: 31 additions & 0 deletions man/dust_model_simulate.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

54 changes: 33 additions & 21 deletions tests/testthat/test-sir.R
Original file line number Diff line number Diff line change
Expand Up @@ -150,37 +150,49 @@ test_that("can update parameters", {

test_that("can run simulation", {
pars <- list(beta = 0.1, gamma = 0.2, N = 1000, I0 = 10, exp_noise = 1e6)
obj1 <- dust2_cpu_sir_alloc(pars, 0, 1, 10, 0, 42, FALSE)
obj2 <- dust2_cpu_sir_alloc(pars, 0, 1, 10, 0, 42, FALSE)
ptr1 <- obj1[[1]]
ptr2 <- obj2[[1]]
dust2_cpu_sir_set_state_initial(ptr1)
dust2_cpu_sir_set_state_initial(ptr2)

res <- dust2_cpu_sir_simulate(ptr1, 0:20, NULL, FALSE)
obj1 <- dust_model_create(sir(), pars, n_particles = 10, seed = 42)
obj2 <- dust_model_create(sir(), pars, n_particles = 10, seed = 42)
dust_model_set_state_initial(obj1)
dust_model_set_state_initial(obj2)

res <- dust_model_simulate(obj1, 0:20)
expect_equal(dim(res), c(5, 10, 21))

expect_equal(res[, , 21], dust2_cpu_sir_state(ptr1, FALSE))
expect_equal(dust2_cpu_sir_time(ptr1), 20)
expect_equal(res[, , 21], dust_model_state(obj1))
expect_equal(dust_model_time(obj1), 20)

expect_equal(res[, , 1], dust2_cpu_sir_state(ptr2, FALSE))
dust2_cpu_sir_run_to_time(ptr2, 10)
expect_equal(res[, , 11], dust2_cpu_sir_state(ptr2, FALSE))
expect_equal(res[, , 1], dust_model_state(obj2))
dust_model_run_to_time(obj2, 10)
expect_equal(res[, , 11], dust_model_state(obj2))
})


test_that("can run simulation with index", {
pars <- list(beta = 0.1, gamma = 0.2, N = 1000, I0 = 10, exp_noise = 1e6)
obj1 <- dust2_cpu_sir_alloc(pars, 0, 1, 10, 0, 42, FALSE)
obj2 <- dust2_cpu_sir_alloc(pars, 0, 1, 10, 0, 42, FALSE)
ptr1 <- obj1[[1]]
ptr2 <- obj2[[1]]
dust2_cpu_sir_set_state_initial(ptr1)
dust2_cpu_sir_set_state_initial(ptr2)
obj1 <- dust_model_create(sir(), pars, n_particles = 10, seed = 42)
obj2 <- dust_model_create(sir(), pars, n_particles = 10, seed = 42)
dust_model_set_state_initial(obj1)
dust_model_set_state_initial(obj2)

index <- c(2L, 4L)
res1 <- dust2_cpu_sir_simulate(ptr1, 0:20, index, FALSE)
res2 <- dust2_cpu_sir_simulate(ptr2, 0:20, NULL, FALSE)
res1 <- dust_model_simulate(obj1, 0:20, index)
res2 <- dust_model_simulate(obj2, 0:20)

expect_equal(res1, res2[index, , ])
})


test_that("copy names with index", {
pars <- list(beta = 0.1, gamma = 0.2, N = 1000, I0 = 10, exp_noise = 1e6)
obj1 <- dust_model_create(sir(), pars, n_particles = 10, seed = 42)
obj2 <- dust_model_create(sir(), pars, n_particles = 10, seed = 42)
dust_model_set_state_initial(obj1)
dust_model_set_state_initial(obj2)

index <- c(I = 2L, cases = 4L)
res1 <- dust_model_simulate(obj1, 0:20, index)
res2 <- dust_model_simulate(obj2, 0:20, unname(index))
expect_equal(dimnames(res1), list(c("I", "cases"), NULL, NULL))

expect_equal(unname(res1), res2)
})
45 changes: 22 additions & 23 deletions tests/testthat/test-walk.R
Original file line number Diff line number Diff line change
Expand Up @@ -420,56 +420,55 @@ test_that("time must not be in the past", {

test_that("can simulate walk model", {
pars <- lapply(1:4, function(sd) list(len = 3, sd = sd))
obj1 <- dust2_cpu_walk_alloc(pars, 0, 1, 10, 4, 42, FALSE)
obj2 <- dust2_cpu_walk_alloc(pars, 0, 1, 10, 4, 42, FALSE)
ptr1 <- obj1[[1]]
ptr2 <- obj2[[1]]
obj1 <- dust_model_create(walk(), pars, n_particles = 10, n_groups = 4,
seed = 42)
obj2 <- dust_model_create(walk(), pars, n_particles = 10, n_groups = 4,
seed = 42)

res <- dust2_cpu_walk_simulate(ptr1, 0:20, NULL, TRUE)
res <- dust_model_simulate(obj1, 0:20)
expect_equal(dim(res), c(3, 10, 4, 21))

expect_equal(res[, , , 21], dust2_cpu_walk_state(ptr1, TRUE))
expect_equal(dust2_cpu_walk_time(ptr1), 20)
expect_equal(res[, , , 21], dust_model_state(obj1))
expect_equal(dust_model_time(obj1), 20)

expect_equal(res[, , , 1], dust2_cpu_walk_state(ptr2, TRUE))
dust2_cpu_walk_run_to_time(ptr2, 10)
expect_equal(res[, , , 11], dust2_cpu_walk_state(ptr2, TRUE))
expect_equal(res[, , , 1], dust_model_state(obj2))
dust_model_run_to_time(obj2, 10)
expect_equal(res[, , , 11], dust_model_state(obj2))
})


test_that("can simulate walk model with index", {
pars <- lapply(1:4, function(sd) list(len = 6, sd = sd))
obj1 <- dust2_cpu_walk_alloc(pars, 0, 1, 10, 4, 42, FALSE)
obj2 <- dust2_cpu_walk_alloc(pars, 0, 1, 10, 4, 42, FALSE)
ptr1 <- obj1[[1]]
ptr2 <- obj2[[1]]
obj1 <- dust_model_create(walk(), pars, n_particles = 10, n_groups = 4,
seed = 42)
obj2 <- dust_model_create(walk(), pars, n_particles = 10, n_groups = 4,
seed = 42)

res1 <- dust2_cpu_walk_simulate(ptr1, 0:20, c(2, 4), TRUE)
res2 <- dust2_cpu_walk_simulate(ptr2, 0:20, NULL, TRUE)
res1 <- dust_model_simulate(obj1, 0:20, c(2, 4))
res2 <- dust_model_simulate(obj2, 0:20)
expect_equal(res1, res2[c(2, 4), , , ])
})


test_that("can validate index values", {
pars <- list(len = 3, sd = 1)
obj <- dust2_cpu_walk_alloc(pars, 0, 1, 10, 0, 42, FALSE)
ptr <- obj[[1]]
obj <- dust_model_create(walk(), pars, n_particles = 10, seed = 42)
expect_error(
dust2_cpu_walk_simulate(ptr, 0:20, c("2", "4"), TRUE),
dust_model_simulate(obj, 0:20, c("2", "4")),
"Expected an integer vector for 'index'")
expect_error(
dust2_cpu_walk_simulate(ptr, 0:20, c(2, 2.9), TRUE),
dust_model_simulate(obj, 0:20, c(2, 2.9)),
"All values of 'index' must be integer-like, but 'index[2]' was not",
fixed = TRUE)
expect_error(
dust2_cpu_walk_simulate(ptr, 0:20, c(2, 20), TRUE),
dust_model_simulate(obj, 0:20, c(2, 20)),
"All values of 'index' must be in [1, 3], but 'index[2]' was 20",
fixed = TRUE)
expect_error(
dust2_cpu_walk_simulate(ptr, 0:20, c(2, 3, -4), TRUE),
dust_model_simulate(obj, 0:20, c(2, 3, -4)),
"All values of 'index' must be in [1, 3], but 'index[3]' was -4",
fixed = TRUE)
expect_error(
dust2_cpu_walk_simulate(ptr, 0:20, integer(), TRUE),
dust_model_simulate(obj, 0:20, integer()),
"'index' must have nonzero length")
})

0 comments on commit f8d3268

Please sign in to comment.