Skip to content

Commit

Permalink
Issue 270: Correct bug with predict_delay_parameters (#278)
Browse files Browse the repository at this point in the history
* Fix bug with predict_delay_parameters and add test to check

* Test bug fix

* Widen tolerance
  • Loading branch information
athowes authored Sep 4, 2024
1 parent dd22e59 commit 588b55b
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 17 deletions.
4 changes: 2 additions & 2 deletions R/postprocess.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ predict_delay_parameters <- function(fit, newdata = NULL, ...) {
# Every brms model has the parameter mu
lp_mu <- brms::get_dpar(pp, dpar = "mu", inv_link = TRUE)
df <- expand.grid(
"index" = seq_len(ncol(lp_mu)),
"draw" = seq_len(nrow(lp_mu))
"draw" = seq_len(nrow(lp_mu)),
"index" = seq_len(ncol(lp_mu))
)
df[["mu"]] <- as.vector(lp_mu)
for (dpar in setdiff(names(pp$dpars), "mu")) {
Expand Down
34 changes: 34 additions & 0 deletions tests/testthat/setup.R
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,37 @@ sim_obs_gamma <- simulate_gillespie() |>

sim_obs_gamma <-
sim_obs_gamma[sample(seq_len(.N), sample_size, replace = FALSE)]

# Data with a sex difference

meanlog_m <- 2.0
sdlog_m <- 0.3

meanlog_f <- 1.3
sdlog_f <- 0.7

sim_obs_sex <- simulate_gillespie()
sim_obs_sex$sex <- rbinom(n = nrow(sim_obs_sex), size = 1, prob = 0.5)

sim_obs_sex_m <- dplyr::filter(sim_obs_sex, sex == 0) |>
simulate_secondary(
dist = rlnorm,
meanlog = meanlog_m,
sdlog = sdlog_m
)

sim_obs_sex_f <- dplyr::filter(sim_obs_sex, sex == 1) |>
simulate_secondary(
dist = rlnorm,
meanlog = meanlog_f,
sdlog = sdlog_f
)

sim_obs_sex <- dplyr::bind_rows(sim_obs_sex_m, sim_obs_sex_f) |>
dplyr::arrange(case)

sim_obs_sex <- sim_obs_sex |>
observe_process() |>
filter_obs_by_obs_time(obs_time = obs_time)

sim_obs_sex <- sim_obs_sex[sample(seq_len(.N), sample_size, replace = FALSE)]
54 changes: 39 additions & 15 deletions tests/testthat/test-unit-postprocess.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,31 +10,55 @@ test_that("predict_delay_parameters works with NULL newdata and the latent logno
)
pred <- predict_delay_parameters(fit)
expect_s3_class(pred, "data.table")
expect_named(pred, c("index", "draw", "mu", "sigma", "mean", "sd"))
expect_named(pred, c("draw", "index", "mu", "sigma", "mean", "sd"))
expect_true(all(pred$mean > 0))
expect_true(all(pred$sd > 0))
expect_equal(length(unique(pred$index)), nrow(prep_obs))
expect_equal(length(unique(pred$draw)), summary(fit)$total_ndraws)
})

test_that("predict_delay_parameters accepts newdata arguments", { # nolint: line_length_linter.
test_that("predict_delay_parameters accepts newdata arguments and prediction by sex recovers underlying parameters", { # nolint: line_length_linter.
skip_on_cran()
set.seed(1)
prep_obs <- as_latent_individual(sim_obs)
fit <- epidist(
data = prep_obs,
prep_obs_sex <- as_latent_individual(sim_obs_sex)
fit_sex <- epidist(
data = prep_obs_sex,
formula = brms::bf(mu ~ 1 + sex, sigma ~ 1 + sex),
seed = 1,
silent = 2,
output_dir = fs::dir_create(tempfile())
silent = 2
)
pred_sex <- predict_delay_parameters(fit_sex, prep_obs_sex)
expect_s3_class(pred_sex, "data.table")
expect_named(pred_sex, c("draw", "index", "mu", "sigma", "mean", "sd"))
expect_true(all(pred_sex$mean > 0))
expect_true(all(pred_sex$sd > 0))
expect_equal(length(unique(pred_sex$index)), nrow(prep_obs_sex))
expect_equal(length(unique(pred_sex$draw)), summary(fit_sex)$total_ndraws)

pred_sex_summary <- pred_sex |>
dplyr::left_join(
dplyr::select(data.frame(prep_obs_sex), index = row_id, sex),
by = "index"
) |>
dplyr::group_by(sex) |>
dplyr::summarise(
mu = mean(mu),
sigma = mean(sigma)
)

# Correct predictions of M
expect_equal(
as.numeric(pred_sex_summary[1, c("mu", "sigma")]),
c(meanlog_m, sdlog_m),
tolerance = 0.1
)

# Correction predictions of F
expect_equal(
as.numeric(pred_sex_summary[2, c("mu", "sigma")]),
c(meanlog_f, sdlog_f),
tolerance = 0.1
)
n <- 5
pred <- predict_delay_parameters(fit, newdata = prep_obs[1:n, ])
expect_s3_class(pred, "data.table")
expect_named(pred, c("index", "draw", "mu", "sigma", "mean", "sd"))
expect_true(all(pred$mean > 0))
expect_true(all(pred$sd > 0))
expect_equal(length(unique(pred$index)), 5)
expect_equal(length(unique(pred$draw)), summary(fit)$total_ndraws)
})

test_that("add_mean_sd.lognormal_samples works with simulated lognormal distribution parameter data", { # nolint: line_length_linter.
Expand Down

0 comments on commit 588b55b

Please sign in to comment.