Skip to content

Commit 9d8b850

Browse files
authored
Issue 248: Add prediction and log-likelihood methods for the latent_gamma family (#273)
* Add quick go at these three functions for the latent_gamma family * Lint package * Rename predict tests as latent_lognormal tests * Make generate_examples reproducible and add gamma data * Add unit tests for the latent_gamma methods * Use the fit_gamma rather than fit * Upload gamma fit * Add named documentation to latent_lognormal * Reorder prep and i correctly * Use [s] not [i]! Bug * Rerun of generate test data required * Use scale parameterisation Former-commit-id: dd22e59 Former-commit-id: a2cd30a3f99b55572aeb286351078f133d2cf630 Former-commit-id: 3789a48cc40ec22f4128b208318686e7de538a27 [formerly 878955d] Former-commit-id: 42ab0638f573e74ff809fc46db450a3cd86b9b29
1 parent 863b656 commit 9d8b850

22 files changed

+324
-9
lines changed

NAMESPACE

+3
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ export(filter_obs_by_ptime)
4242
export(is_latent_individual)
4343
export(linelist_to_cases)
4444
export(linelist_to_counts)
45+
export(log_lik_latent_gamma)
4546
export(log_lik_latent_lognormal)
4647
export(make_relative_to_truth)
4748
export(observe_process)
@@ -53,7 +54,9 @@ export(plot_empirical_delay)
5354
export(plot_mean_posterior_pred)
5455
export(plot_recovery)
5556
export(plot_relative_recovery)
57+
export(posterior_epred_latent_gamma)
5658
export(posterior_epred_latent_lognormal)
59+
export(posterior_predict_latent_gamma)
5760
export(posterior_predict_latent_lognormal)
5861
export(predict_delay_parameters)
5962
export(predict_dpar)

R/latent_gamma.R

+85
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
#' Draws from the posterior predictive distribution of the `latent_gamma` family
2+
#'
3+
#' See [brms::posterior_predict()].
4+
#'
5+
#' @param i The index of the observation to predict
6+
#' @param prep The result of a call to [brms::posterior_predict()]
7+
#' @param ... Additional arguments
8+
#' @family postprocess
9+
#' @autoglobal
10+
#' @export
11+
posterior_predict_latent_gamma <- function(i, prep, ...) { # nolint: object_length_linter
12+
mu <- brms::get_dpar(prep, "mu", i = i)
13+
shape <- brms::get_dpar(prep, "shape", i = i)
14+
15+
obs_t <- prep$data$vreal1[i]
16+
pwindow_width <- prep$data$vreal2[i]
17+
swindow_width <- prep$data$vreal3[i]
18+
19+
.predict <- function(s) {
20+
d_censored <- obs_t + 1
21+
# while loop to impose the truncation
22+
while (d_censored > obs_t) {
23+
p_latent <- runif(1, 0, 1) * pwindow_width
24+
d_latent <- rgamma(1, shape = shape[s], scale = mu[s] / shape[s])
25+
s_latent <- p_latent + d_latent
26+
p_censored <- floor_mult(p_latent, pwindow_width)
27+
s_censored <- floor_mult(s_latent, swindow_width)
28+
d_censored <- s_censored - p_censored
29+
}
30+
return(d_censored)
31+
}
32+
33+
# Within brms this is a helper function called rblapply
34+
do.call(rbind, lapply(seq_len(prep$ndraws), .predict))
35+
}
36+
37+
#' Draws from the expected value of the posterior predictive distribution of the
38+
#' `latent_gamma` family
39+
#'
40+
#' See [brms::posterior_epred()].
41+
#'
42+
#' @param prep The result of a call to [`brms::prepare_predictions`]
43+
#' @family postprocess
44+
#' @autoglobal
45+
#' @export
46+
posterior_epred_latent_gamma <- function(prep) { # nolint: object_length_linter
47+
mu <- brms::get_dpar(prep, "mu")
48+
mu
49+
}
50+
51+
#' Calculate the pointwise log likelihood of the `latent_gamma` family
52+
#'
53+
#' See [brms::log_lik()].
54+
#'
55+
#' @param i The index of the observation to calculate the log likelihood of
56+
#' @param prep The result of a call to [brms::prepare_predictions()]
57+
#' @family postprocess
58+
#' @autoglobal
59+
#' @export
60+
log_lik_latent_gamma <- function(i, prep) {
61+
mu <- brms::get_dpar(prep, "mu", i = i)
62+
shape <- brms::get_dpar(prep, "shape", i = i)
63+
y <- prep$data$Y[i]
64+
obs_t <- prep$data$vreal1[i]
65+
pwindow_width <- prep$data$vreal2[i]
66+
swindow_width <- prep$data$vreal3[i]
67+
68+
swindow_raw <- runif(prep$ndraws)
69+
pwindow_raw <- runif(prep$ndraws)
70+
71+
swindow <- swindow_raw * swindow_width
72+
73+
# For no overlap calculate as usual, for overlap ensure pwindow < swindow
74+
if (i %in% prep$data$noverlap) {
75+
pwindow <- pwindow_raw * pwindow_width
76+
} else {
77+
pwindow <- pwindow_raw * swindow
78+
}
79+
80+
d <- y - pwindow + swindow
81+
obs_time <- obs_t - pwindow
82+
lpdf <- dgamma(d, shape = shape, scale = mu / shape, log = TRUE)
83+
lcdf <- pgamma(obs_time, shape = shape, scale = mu / shape, log.p = TRUE)
84+
return(lpdf - lcdf)
85+
}

R/latent_lognormal.R

+5-3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
#' Draws from the posterior predictive distribution
1+
#' Draws from the posterior predictive distribution of the `latent_lognormal`
2+
#' family
23
#'
34
#' See [brms::posterior_predict()].
45
#'
@@ -34,7 +35,8 @@ posterior_predict_latent_lognormal <- function(i, prep, ...) { # nolint: object_
3435
do.call(rbind, lapply(seq_len(prep$ndraws), .predict))
3536
}
3637

37-
#' Draws from the expected value of the posterior predictive distribution
38+
#' Draws from the expected value of the posterior predictive distribution of the
39+
#' `latent_gamma` family
3840
#'
3941
#' See [brms::posterior_epred()].
4042
#'
@@ -48,7 +50,7 @@ posterior_epred_latent_lognormal <- function(prep) { # nolint: object_length_lin
4850
exp(mu + sigma^2 / 2)
4951
}
5052

51-
#' Calculate the pointwise log likelihood
53+
#' Calculate the pointwise log likelihood of the `latent_gamma` family
5254
#'
5355
#' See [brms::log_lik()].
5456
#'
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
ed630da445eec989cc02d70f845a0ec5272a3d82

inst/generate_examples.R

+9-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,11 @@
11
source("tests/testthat/setup.R")
2+
set.seed(1)
23
prep_obs <- as_latent_individual(sim_obs)
3-
fit <- epidist(prep_obs)
4-
saveRDS(fit, "inst/extdata/fit.rds")
4+
fit <- epidist(prep_obs, seed = 1)
5+
prep_obs_gamma <- as_latent_individual(sim_obs_gamma)
6+
fit_gamma <- epidist(
7+
prep_obs_gamma,
8+
family = stats::Gamma(link = "log"),
9+
seed = 1
10+
)
11+
saveRDS(fit_gamma, "inst/extdata/fit_gamma.rds")

man/add_mean_sd.Rd

+3
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/add_mean_sd.default.Rd

+3
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/add_mean_sd.gamma_samples.Rd

+3
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/add_mean_sd.lognormal_samples.Rd

+3
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/draws_to_long.Rd

+3
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/log_lik_latent_gamma.Rd

+34
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/log_lik_latent_lognormal.Rd

+4-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/make_relative_to_truth.Rd

+3
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/posterior_epred_latent_gamma.Rd

+33
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/posterior_epred_latent_lognormal.Rd

+5-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/posterior_predict_latent_gamma.Rd

+36
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/posterior_predict_latent_lognormal.Rd

+5-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)