Skip to content

Commit

Permalink
Merge pull request #113 from mrc-ide/mrc-5922
Browse files Browse the repository at this point in the history
Allow running simple multiregion model
  • Loading branch information
weshinsley authored Oct 30, 2024
2 parents 7aeba02 + a2a18cd commit 4c20c66
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 45 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ URL: https://github.com/mrc-ide/dust2, https://mrc-ide.github.io/dust2
BugReports: https://github.com/mrc-ide/dust2/issues
Imports:
cli,
monty (>= 0.2.20),
monty (>= 0.2.26),
rlang
LinkingTo:
cpp11,
Expand Down
2 changes: 2 additions & 0 deletions R/interface-filter.R
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,14 @@ dust_filter_create <- function(generator, time_start, data,
n_threads = n_threads,
preserve_group_dimension = preserve_group_dimension)

groups <- if (preserve_group_dimension) data$groups else NULL
res <- list2env(
list(inputs = inputs,
initialise = filter_create,
initial_rng_state = filter_rng_state(n_particles, n_groups, seed),
n_particles = n_particles,
n_groups = n_groups,
groups = groups,
deterministic = FALSE,
has_adjoint = FALSE,
generator = generator,
Expand Down
22 changes: 18 additions & 4 deletions R/monty.R
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,20 @@ dust_likelihood_monty <- function(obj, packer, initial = NULL, domain = NULL,
save_state = FALSE,
save_trajectories = FALSE) {
assert_is(obj, "dust_likelihood")
assert_is(packer, "monty_packer")

is_grouped <- !is.null(obj$groups)
if (is_grouped) {
assert_is(packer, "monty_packer_grouped")
if (!identical(packer$groups(), obj$groups)) {
cli::cli_abort(
c("Groups for 'packer' do not match those of 'obj'",
i = "'obj' has: {squote(obj$groups)}",
x = "'packer' has: {squote(packer$groups())}"),
arg = "packer")
}
} else {
assert_is(packer, "monty_packer")
}

domain <- monty::monty_domain_expand(domain, packer)
save_trajectories <- validate_save_trajectories(save_trajectories)
Expand All @@ -105,7 +118,7 @@ dust_likelihood_monty <- function(obj, packer, initial = NULL, domain = NULL,
is_stochastic = !obj$deterministic,
has_direct_sample = FALSE,
has_gradient = obj$deterministic && obj$has_adjoint,
allow_multiple_parameters = obj$preserve_group_dimension,
allow_multiple_parameters = FALSE,
has_observer = !is.null(observer),
has_parameter_groups = FALSE)

Expand Down Expand Up @@ -144,7 +157,7 @@ dust_likelihood_monty <- function(obj, packer, initial = NULL, domain = NULL,
initial = if (is.null(initial)) NULL else initial(pars),
save_trajectories = save_trajectories$enabled,
index_state = save_trajectories$index)
attr(obj$ptr, "last_density") <- ret
attr(obj$ptr, "last_density") <- if (is_grouped) sum(ret) else ret
attr(obj$ptr, "last_gradient") <- NULL
}
attr(obj$ptr, "last_density")
Expand All @@ -170,12 +183,13 @@ dust_likelihood_monty <- function(obj, packer, initial = NULL, domain = NULL,
env$save_trajectories$index <-
obj$packer_state$subset(save_trajectories$subset)$index
}
dust_likelihood_run(
ll <- dust_likelihood_run(
obj,
pars,
initial = if (is.null(initial)) NULL else initial(pars),
save_trajectories = save_trajectories$enabled,
index_state = save_trajectories$index)
if (is_grouped) sum(ll) else ll
}
}

Expand Down
47 changes: 7 additions & 40 deletions tests/testthat/test-monty.R
Original file line number Diff line number Diff line change
Expand Up @@ -78,38 +78,6 @@ test_that("can avoid errors by converting to impossible density", {
})


test_that("can create wrapper around filter with multiple pars", {
pars <- list(beta = 0.1, gamma = 0.2, N = 1000, I0 = 10, exp_noise = 1e6)

time_start <- 0
data <- data.frame(group = rep(1:2, each = 4),
time = rep(c(4, 8, 12, 16), 2),
incidence = c(1:4, 2:5))

obj <- dust_filter_create(sir(), time_start, data, n_particles = 100,
seed = 42)
obj1 <- dust_filter_create(sir(), time_start, data[data$group == 1, ],
n_particles = 100, seed = 42)
packer <- monty::monty_packer(
c("beta", "gamma"),
fixed = list(N = 1000, I0 = 10, exp_noise = 1e6))

m <- dust_likelihood_monty(obj, packer)
expect_true(m$properties$allow_multiple_parameters)
expect_true(m$properties$is_stochastic)

m1 <- dust_likelihood_monty(obj1, packer)
expect_false(m1$properties$allow_multiple_parameters)

p <- cbind(c(0.2, 0.1), c(0.25, 0.1))

ll <- monty::monty_model_density(m, p)
ll1 <- monty::monty_model_density(m1, p[, 1])
expect_length(ll, 2)
expect_equal(ll[[1]], ll1)
})


test_that("can get trajectories from model", {
pars <- list(beta = 0.1, gamma = 0.2, N = 1000, I0 = 10, exp_noise = 1e6)
time_start <- 0
Expand Down Expand Up @@ -292,18 +260,17 @@ test_that("can use names for groups", {
obj1 <- dust_filter_create(sir(), time_start, data1, n_particles = 100,
seed = 42)
obj2 <- dust_filter_create(sir(), time_start, data2, n_particles = 100,
seed = 42)
seed = 42, n_groups = 2)

packer <- monty::monty_packer(
packer <- monty::monty_packer_grouped(
c("a", "b"),
c("beta", "gamma"),
fixed = list(N = 1000, I0 = 10, exp_noise = 1e6))

m1 <- dust_likelihood_monty(obj1, packer)
m2 <- dust_likelihood_monty(obj2, packer)
p2 <- c(0.2, 0.1, 0.25, 0.1)

p <- cbind(c(0.2, 0.1), c(0.25, 0.1))

ll1 <- monty::monty_model_density(m1, p)
ll2 <- monty::monty_model_density(m2, p)
expect_equal(ll1, ll2)
ll1 <- dust_likelihood_run(obj1, unname(packer$unpack(p2)))
ll2 <- monty::monty_model_density(m2, p2)
expect_equal(ll2, sum(ll1))
})

0 comments on commit 4c20c66

Please sign in to comment.