Skip to content

Commit 6df9220

Browse files
authored
Issue 270: Correct bug with predict_delay_parameters (#278)
* Fix bug with predict_delay_parameters and add test to check * Test bug fix * Widen tolerance Former-commit-id: 588b55b Former-commit-id: ec78a6a5552aa4848ac1713c8677ac26cc1a38cc Former-commit-id: 207ac8a43228e4629d4124df543b6db3c286bc69 [formerly 8fbb4bf] Former-commit-id: 860cd5b9986b84cc0932d7b1ed350cc20f21d239
1 parent 9d8b850 commit 6df9220

File tree

3 files changed

+75
-17
lines changed

3 files changed

+75
-17
lines changed

R/postprocess.R

+2-2
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ predict_delay_parameters <- function(fit, newdata = NULL, ...) {
1414
# Every brms model has the parameter mu
1515
lp_mu <- brms::get_dpar(pp, dpar = "mu", inv_link = TRUE)
1616
df <- expand.grid(
17-
"index" = seq_len(ncol(lp_mu)),
18-
"draw" = seq_len(nrow(lp_mu))
17+
"draw" = seq_len(nrow(lp_mu)),
18+
"index" = seq_len(ncol(lp_mu))
1919
)
2020
df[["mu"]] <- as.vector(lp_mu)
2121
for (dpar in setdiff(names(pp$dpars), "mu")) {

tests/testthat/setup.R

+34
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,37 @@ sim_obs_gamma <- simulate_gillespie() |>
3838

3939
sim_obs_gamma <-
4040
sim_obs_gamma[sample(seq_len(.N), sample_size, replace = FALSE)]
41+
42+
# Data with a sex difference
43+
44+
meanlog_m <- 2.0
45+
sdlog_m <- 0.3
46+
47+
meanlog_f <- 1.3
48+
sdlog_f <- 0.7
49+
50+
sim_obs_sex <- simulate_gillespie()
51+
sim_obs_sex$sex <- rbinom(n = nrow(sim_obs_sex), size = 1, prob = 0.5)
52+
53+
sim_obs_sex_m <- dplyr::filter(sim_obs_sex, sex == 0) |>
54+
simulate_secondary(
55+
dist = rlnorm,
56+
meanlog = meanlog_m,
57+
sdlog = sdlog_m
58+
)
59+
60+
sim_obs_sex_f <- dplyr::filter(sim_obs_sex, sex == 1) |>
61+
simulate_secondary(
62+
dist = rlnorm,
63+
meanlog = meanlog_f,
64+
sdlog = sdlog_f
65+
)
66+
67+
sim_obs_sex <- dplyr::bind_rows(sim_obs_sex_m, sim_obs_sex_f) |>
68+
dplyr::arrange(case)
69+
70+
sim_obs_sex <- sim_obs_sex |>
71+
observe_process() |>
72+
filter_obs_by_obs_time(obs_time = obs_time)
73+
74+
sim_obs_sex <- sim_obs_sex[sample(seq_len(.N), sample_size, replace = FALSE)]

tests/testthat/test-unit-postprocess.R

+39-15
Original file line numberDiff line numberDiff line change
@@ -10,31 +10,55 @@ test_that("predict_delay_parameters works with NULL newdata and the latent logno
1010
)
1111
pred <- predict_delay_parameters(fit)
1212
expect_s3_class(pred, "data.table")
13-
expect_named(pred, c("index", "draw", "mu", "sigma", "mean", "sd"))
13+
expect_named(pred, c("draw", "index", "mu", "sigma", "mean", "sd"))
1414
expect_true(all(pred$mean > 0))
1515
expect_true(all(pred$sd > 0))
1616
expect_equal(length(unique(pred$index)), nrow(prep_obs))
1717
expect_equal(length(unique(pred$draw)), summary(fit)$total_ndraws)
1818
})
1919

20-
test_that("predict_delay_parameters accepts newdata arguments", { # nolint: line_length_linter.
20+
test_that("predict_delay_parameters accepts newdata arguments and prediction by sex recovers underlying parameters", { # nolint: line_length_linter.
2121
skip_on_cran()
2222
set.seed(1)
23-
prep_obs <- as_latent_individual(sim_obs)
24-
fit <- epidist(
25-
data = prep_obs,
23+
prep_obs_sex <- as_latent_individual(sim_obs_sex)
24+
fit_sex <- epidist(
25+
data = prep_obs_sex,
26+
formula = brms::bf(mu ~ 1 + sex, sigma ~ 1 + sex),
2627
seed = 1,
27-
silent = 2,
28-
output_dir = fs::dir_create(tempfile())
28+
silent = 2
29+
)
30+
pred_sex <- predict_delay_parameters(fit_sex, prep_obs_sex)
31+
expect_s3_class(pred_sex, "data.table")
32+
expect_named(pred_sex, c("draw", "index", "mu", "sigma", "mean", "sd"))
33+
expect_true(all(pred_sex$mean > 0))
34+
expect_true(all(pred_sex$sd > 0))
35+
expect_equal(length(unique(pred_sex$index)), nrow(prep_obs_sex))
36+
expect_equal(length(unique(pred_sex$draw)), summary(fit_sex)$total_ndraws)
37+
38+
pred_sex_summary <- pred_sex |>
39+
dplyr::left_join(
40+
dplyr::select(data.frame(prep_obs_sex), index = row_id, sex),
41+
by = "index"
42+
) |>
43+
dplyr::group_by(sex) |>
44+
dplyr::summarise(
45+
mu = mean(mu),
46+
sigma = mean(sigma)
47+
)
48+
49+
# Correct predictions of M
50+
expect_equal(
51+
as.numeric(pred_sex_summary[1, c("mu", "sigma")]),
52+
c(meanlog_m, sdlog_m),
53+
tolerance = 0.1
54+
)
55+
56+
# Correction predictions of F
57+
expect_equal(
58+
as.numeric(pred_sex_summary[2, c("mu", "sigma")]),
59+
c(meanlog_f, sdlog_f),
60+
tolerance = 0.1
2961
)
30-
n <- 5
31-
pred <- predict_delay_parameters(fit, newdata = prep_obs[1:n, ])
32-
expect_s3_class(pred, "data.table")
33-
expect_named(pred, c("index", "draw", "mu", "sigma", "mean", "sd"))
34-
expect_true(all(pred$mean > 0))
35-
expect_true(all(pred$sd > 0))
36-
expect_equal(length(unique(pred$index)), 5)
37-
expect_equal(length(unique(pred$draw)), summary(fit)$total_ndraws)
3862
})
3963

4064
test_that("add_mean_sd.lognormal_samples works with simulated lognormal distribution parameter data", { # nolint: line_length_linter.

0 commit comments

Comments
 (0)