From 7db53d61f41a2770fc87b2d2b6cff0713f9a3c26 Mon Sep 17 00:00:00 2001 From: Sang Woo Park Date: Thu, 4 Apr 2024 05:33:58 -0400 Subject: [PATCH] Issue 3: Not using DT. (#4) Former-commit-id: 26eb0ceec6e44cce1a0a8243482f174e64afcb24 [formerly cfa856cceeb0f22d605a3130e90096d9853ed6b3] Former-commit-id: b27ad83c48ed189f3829cac561cd29b4f96dfeda --- R/fitting-and-postprocessing.R | 56 +++++++++--------- R/models.R | 69 ++++++++++------------ R/observe.R | 101 +++++++++++++++------------------ R/plot-helpers.R | 47 +++++++-------- R/plot.R | 3 +- R/preprocess.R | 101 +++++++++++++++++---------------- R/simulate.R | 10 ++-- vignettes/epidist.Rmd | 25 ++++---- 8 files changed, 196 insertions(+), 216 deletions(-) diff --git a/R/fitting-and-postprocessing.R b/R/fitting-and-postprocessing.R index bd26eff72..800c7d04d 100644 --- a/R/fitting-and-postprocessing.R +++ b/R/fitting-and-postprocessing.R @@ -3,8 +3,7 @@ sample_model <- function(model, data, scenario = data.table::data.table(id = 1), diagnostics = TRUE, ...) { - out <- scenario |> - copy() + out <- data.table::copy(scenario) # Setup failure tolerant model fitting fit_model <- function(model, data, ...) { @@ -16,12 +15,10 @@ sample_model <- function(model, data, scenario = data.table::data.table(id = 1), fit <- safe_fit_model(model, data, ...) if (!is.null(fit$error)) { - out <- out |> - DT(, error := list(fit$error[[1]])) + out[, error := list(fit$error[[1]])] diagnostics <- FALSE }else { - out <- out |> - DT(, fit := list(fit$result)) + out[, fit := list(fit$result)] fit <- fit$result } @@ -57,8 +54,7 @@ sample_epinowcast_model <- function( diagnostics = TRUE, ... ) { - out <- scenario |> - copy() + out <- data.table::copy(scenario) # Setup failure tolerant model fitting fit_model <- function(model, data, ...) { @@ -72,12 +68,10 @@ sample_epinowcast_model <- function( fit <- safe_fit_model(model, data, ...) 
if (!is.null(fit$error)) { - out <- out |> - DT(, error := list(fit$error[[1]])) + out[, error := list(fit$error[[1]])] diagnostics <- FALSE }else { - out <- out |> - DT(, fit := list(fit$result)) + out[, fit := list(fit$result)] fit <- fit$result } @@ -123,11 +117,12 @@ sample_epinowcast_model( #' Add natural scale summary parameters for a lognormal distribution #' @export add_natural_scale_mean_sd <- function(dt) { - nat_dt <- dt |> - data.table::DT(, mean := exp(meanlog + sdlog ^ 2 / 2)) |> - data.table::DT(, - sd := exp(meanlog + (1 / 2) * sdlog ^ 2) * sqrt(exp(sdlog ^ 2) - 1) - ) + nat_dt <- data.table::copy(dt) + + nat_dt <- nat_dt[, mean := exp(meanlog + sdlog ^ 2 / 2)] + + nat_dt <- nat_dt[, sd := exp(meanlog + (1 / 2) * sdlog ^ 2) * sqrt(exp(sdlog ^ 2) - 1)] + return(nat_dt[]) } @@ -186,8 +181,7 @@ extract_epinowcast_draws <- function( ) } - draws <- draws |> - data.table::setDT() + draws <- data.table::setDT(draws) data.table::setnames( draws, c("refp_mean_int[1]", "refp_sd_int[1]"), c("meanlog", "sdlog"), @@ -207,10 +201,11 @@ extract_epinowcast_draws <- function( #' Primary event bias correction #' @export primary_censoring_bias_correction <- function(draws) { - draws <- data.table::copy(draws) |> - DT(, mean := mean - runif(.N, min = 0, max = 1)) |> - DT(, meanlog := log(mean^2 / sqrt(sd^2 + mean^2))) |> - DT(, sdlog := sqrt(log(1 + (sd^2 / mean^2)))) + draws <- data.table::copy(draws) + draws[, mean := mean - runif(.N, min = 0, max = 1)] + draws[, meanlog := log(mean^2 / sqrt(sd^2 + mean^2))] + draws[, sdlog := sqrt(log(1 + (sd^2 / mean^2)))] + return(draws[]) } @@ -234,7 +229,8 @@ make_relative_to_truth <- function(draws, secondary_dist, by = "parameter") { by = by ) - draws <- draws[, rel_value := value / true_value] + draws[, rel_value := value / true_value] + return(draws[]) } @@ -289,9 +285,11 @@ summarise_variable <- function(draws, variable, sf = 6, by = c()) { if (missing(variable)) { stop("variable must be specified") } - summarised_draws <- draws |> - copy() |> - DT(, value := variable, env = list(variable = variable)) |> - summarise_draws(sf = sf, by = by) + summarised_draws <- data.table::copy(draws) + + summarised_draws[, value := variable, env = list(variable = variable)] + + summarised_draws <- summarise_draws(summarised_draws, sf = sf, by = by) + return(summarised_draws[]) -} \ No newline at end of file +} diff --git a/R/models.R b/R/models.R index 7a212f7e2..1a23957a0 100644 --- a/R/models.R +++ b/R/models.R @@ -14,11 +14,10 @@ naive_delay <- function(formula = brms::bf(delay_daily ~ 1, sigma ~ 1), data, filtered_naive_delay <- function( formula = brms::bf(delay_daily ~ 1, sigma ~ 1), data, fn = brms::brm, family = "lognormal", truncation = 10, ...)
{ - data <- data |> - data.table::as.data.table() |> - ## NEED TO FILTER BASED ON PTIME - DT(ptime_daily <= (obs_at - truncation)) - + data <- data.table::as.data.table(data) + ## NEED TO FILTER BASED ON PTIME + data <- data[ptime_daily <= (obs_at - truncation)] + data <- drop_zero(data) fn( @@ -85,12 +84,11 @@ latent_censoring_adjusted_delay <- function( stanvars_all <- stanvars_functions + stanvars_parameters + stanvars_prior - data <- data |> - data.table::as.data.table() |> - DT(, id := 1:.N) |> - DT(, pwindow_upr := ptime_upr - ptime_lwr) |> - DT(, swindow_upr := stime_upr - stime_lwr) |> - DT(, delay_central := stime_lwr - ptime_lwr) + data <- data.table::as.data.table(data) + data[, id := 1:.N] + data[, pwindow_upr := ptime_upr - ptime_lwr] + data[, swindow_upr := stime_upr - stime_lwr] + data[, delay_central := stime_lwr - ptime_lwr] if (nrow(data) > 1) { data <- data[, id := as.factor(id)] @@ -111,9 +109,8 @@ filtered_censoring_adjusted_delay <- function( delay_lwr | cens(censored, delay_upr) ~ 1, sigma ~ 1 ), data, fn = brms::brm, family = "lognormal", truncation = 10, ...) { - data <- data |> - data.table::as.data.table() |> - DT(ptime_daily <= (obs_at - truncation)) + data <- data.table::as.data.table(data) + data <- data[ptime_daily <= (obs_at - truncation)] data <- pad_zero(data) @@ -201,22 +198,18 @@ latent_truncation_censoring_adjusted_delay <- function( ... ) { - data <- data |> - data.table::as.data.table() |> - DT(, id := 1:.N) |> - DT(, obs_t := obs_at - ptime_lwr) |> - DT(, pwindow_upr := ifelse( - stime_lwr < ptime_upr, ## if overlap - stime_upr - ptime_lwr, - ptime_upr - ptime_lwr - ) - ) |> - DT(, - woverlap := as.numeric(stime_lwr < ptime_upr) - ) |> - DT(, swindow_upr := stime_upr - stime_lwr) |> - DT(, delay_central := stime_lwr - ptime_lwr) |> - DT(, row_id := 1:.N) + data <- data.table::as.data.table(data) + data[, id := 1:.N] + data[, obs_t := obs_at - ptime_lwr] + data[, pwindow_upr := ifelse( + stime_lwr < ptime_upr, ## if overlap + stime_upr - ptime_lwr, + ptime_upr - ptime_lwr + )] + data[, woverlap := as.numeric(stime_lwr < ptime_upr)] + data[, swindow_upr := stime_upr - stime_lwr] + data[, delay_central := stime_lwr - ptime_lwr] + data[, row_id := 1:.N] if (nrow(data) > 1) { data <- data[, id := as.factor(id)] @@ -323,9 +316,8 @@ dynamical_censoring_adjusted_delay <- function( ) } cols <- colnames(data)[map_lgl(data, is.integer)] - data <- data |> - data.table::as.data.table() |> - DT(, (cols) := lapply(.SD, as.double), .SDcols = cols) + data <- data.table::as.data.table(data) + data[, (cols) := lapply(.SD, as.double), .SDcols = cols] data <- drop_zero(data) ## need to do this because lognormal doesn't like zero @@ -433,12 +425,11 @@ epinowcast_delay <- function(formula = ~ 1, data, by = c(), "epinowcast is not installed. 
Please install it to use this function" ) } - data_as_counts <- data |> - data.table::as.data.table() |> - DT(, .(new_confirm = .N), by = c("ptime_daily", "stime_daily", by)) |> - DT(order(ptime_daily, stime_daily)) |> - DT(, reference_date := as.Date("2000-01-01") + ptime_daily) |> - DT(, report_date := as.Date("2000-01-01") + stime_daily) + data_as_counts <- data.table::as.data.table(data) + data_as_counts <- data_as_counts[, .(new_confirm = .N), by = c("ptime_daily", "stime_daily", by)] + data_as_counts <- data_as_counts[order(ptime_daily, stime_daily)] + data_as_counts[, reference_date := as.Date("2000-01-01") + ptime_daily] + data_as_counts[, report_date := as.Date("2000-01-01") + stime_daily] # Actual largest observerable delay preprocess_delay <- min( diff --git a/R/observe.R b/R/observe.R index 663f72915..28f751f23 100644 --- a/R/observe.R +++ b/R/observe.R @@ -1,42 +1,37 @@ #' Observation process for primary and secondary events #' @export observe_process <- function(linelist) { - clinelist <- linelist |> - data.table::copy() |> - DT(, ptime_daily := floor(ptime)) |> - DT(, ptime_lwr := ptime_daily) |> - DT(, ptime_upr := ptime_daily + 1) |> - # How the second event would be recorded in the data - DT(, stime_daily := floor(stime)) |> - DT(, stime_lwr := stime_daily) |> - DT(, stime_upr := stime_daily + 1) |> - # How would we observe the delay distribution - # previously: delay_daily=floor(delay) - DT(, delay_daily := stime_daily - ptime_daily) |> - DT(, delay_lwr := purrr::map_dbl(delay_daily, ~ max(0, . - 1))) |> - DT(, delay_upr := delay_daily + 1) |> - # We assume observation time is the ceiling of the maximum delay - DT(, obs_at := stime |> - max() |> - ceiling() - ) + clinelist <- data.table::copy(linelist) + clinelist[, ptime_daily := floor(ptime)] + clinelist[, ptime_lwr := ptime_daily] + clinelist[, ptime_upr := ptime_daily + 1] + # How the second event would be recorded in the data + clinelist[, stime_daily := floor(stime)] + clinelist[, stime_lwr := stime_daily] + clinelist[, stime_upr := stime_daily + 1] + # How would we observe the delay distribution + # previously: delay_daily=floor(delay) + clinelist[, delay_daily := stime_daily - ptime_daily] + clinelist[, delay_lwr := purrr::map_dbl(delay_daily, ~ max(0, . 
- 1))] + clinelist[, delay_upr := delay_daily + 1] + # We assume observation time is the ceiling of the maximum delay + clinelist[, obs_at := stime |> + max() |> + ceiling()] + + return(clinelist) } #' Filter observations based on a observation time of secondary events #' @export filter_obs_by_obs_time <- function(linelist, obs_time) { - truncated_linelist <- linelist |> - data.table::copy() |> - # Update observation time by when we are looking - DT(, obs_at := obs_time) |> - DT(, obs_time := obs_time - ptime) |> - # Assuming truncation at the beginning of the censoring window - DT(, - censored_obs_time := obs_at - ptime_lwr - ) |> - DT(, censored := "interval") |> - DT(stime_upr <= obs_at) + truncated_linelist <- data.table::copy(linelist) + truncated_linelist[, obs_at := obs_time] + truncated_linelist[, obs_time := obs_time - ptime] + truncated_linelist[, censored_obs_time := obs_at - ptime_lwr] + truncated_linelist[, censored := "interval"] + truncated_linelist <- truncated_linelist[stime_upr <= obs_at] + + return(truncated_linelist) } @@ -47,30 +42,26 @@ filter_obs_by_ptime <- function(linelist, obs_time, obs_at <- match.arg(obs_at) pfilt_t <- obs_time - truncated_linelist <- linelist |> - data.table::copy() |> - DT(, censored := "interval") |> - DT(ptime_upr <= pfilt_t) + truncated_linelist <- data.table::copy(linelist) + + truncated_linelist[, censored := "interval"] + truncated_linelist <- truncated_linelist[ptime_upr <= pfilt_t] if (obs_at == "obs_secondary") { - truncated_linelist <- truncated_linelist |> - # Update observation time to be the same as the maximum secondary time - DT(, obs_at := stime_upr) + # Update observation time to be the same as the maximum secondary time + truncated_linelist[, obs_at := stime_upr] } else if (obs_at == "max_secondary") { - truncated_linelist <- truncated_linelist |> - DT(, obs_at := stime_upr |> max() |> ceiling()) + truncated_linelist[, obs_at := stime_upr |> max() |> ceiling()] } # make observation time as specified - truncated_linelist <- truncated_linelist |> - DT(, obs_time := obs_at - ptime) |> - # Assuming truncation at the beginning of the censoring window - DT(, censored_obs_time := obs_at - ptime_lwr) + truncated_linelist[, obs_time := obs_at - ptime] + # Assuming truncation at the beginning of the censoring window + truncated_linelist[, censored_obs_time := obs_at - ptime_lwr] # set observation time to artifial observation time if (obs_at == "obs_secondary") { - truncated_linelist <- truncated_linelist |> - DT(, obs_at := pfilt_t) + truncated_linelist[, obs_at := pfilt_t] } return(truncated_linelist) } @@ -78,18 +69,20 @@ filter_obs_by_ptime <- function(linelist, obs_time, #' Pad zero observations as unstable in a lognormal distribution #' @export pad_zero <- function(data, pad = 1e-3) { - data <- data |> - data.table::copy() |> - # Need upper bound to be greater than lower bound - DT(censored_obs_time == 0, censored_obs_time := 2 * pad) |> - DT(delay_lwr == 0, delay_lwr := pad) |> - DT(delay_daily == 0, delay_daily := pad) + data <- data.table::copy(data) + # Need upper bound to be greater than lower bound + data[censored_obs_time == 0, censored_obs_time := 2 * pad] + data[delay_lwr == 0, delay_lwr := pad] + data[delay_daily == 0, delay_daily := pad] + + return(data) } #' Drop zero observations as unstable in a lognormal distribution #' @export drop_zero <- function(data) { - data <- data |> - data.table::copy() |> - DT(delay_daily != 0) + data <- data.table::copy(data) + data <- data[delay_daily != 0] + + return(data) } diff --git 
a/R/plot-helpers.R b/R/plot-helpers.R index 5e1f89dfc..437392aab 100644 --- a/R/plot-helpers.R +++ b/R/plot-helpers.R @@ -4,22 +4,20 @@ calculate_cohort_mean <- function(data, type = c("cohort", "cumulative"), by = c(), obs_at) { type <- match.arg(type) out <- copy(data) - - out <- out |> - DT(, .( - mean = mean(delay_daily), n = .N), - by = c("ptime_daily", by) - ) |> - DT(order(rank(ptime_daily))) - + + out <- out[, .( + mean = mean(delay_daily), n = .N), + by = c("ptime_daily", by)] + + out <- out[order(rank(ptime_daily))] + if (type == "cumulative") { out[, mean := cumsum(mean * n) / cumsum(n), by = by] out[, n := cumsum(n), by = by] } if (!missing(obs_at)) { - out <- out |> - DT(, ptime_daily := ptime_daily - obs_at) + out[, ptime_daily := ptime_daily - obs_at] } return(out[]) @@ -56,21 +54,18 @@ calculate_truncated_means <- function(draws, obs_at, ptime, integrate_for_trunc_mean, otherwise = NA_real_ ) - trunc_mean <- draws |> - copy() |> - DT(, - obs_horizon := list(seq(ptime[1] - obs_at, ptime[2] - obs_at)) - ) |> - DT(, - .(obs_horizon = unlist(obs_horizon)), - by = setdiff(colnames(draws), "obs_horizon") - ) |> - DT(, - trunc_mean := purrr::pmap_dbl( - list(x = obs_horizon, m = meanlog, s = sdlog), - safe_integrate_for_trunc_mean, - .progress = TRUE - ) - ) + trunc_mean <- data.table::copy(draws) + + trunc_mean[, obs_horizon := list(seq(ptime[1] - obs_at, ptime[2] - obs_at))] + trunc_mean <- trunc_mean[, + .(obs_horizon = unlist(obs_horizon)), + by = setdiff(colnames(draws), "obs_horizon")] + trunc_mean[, + trunc_mean := purrr::pmap_dbl( + list(x = obs_horizon, m = meanlog, s = sdlog), + safe_integrate_for_trunc_mean, + .progress = TRUE + )] + return(trunc_mean) } \ No newline at end of file diff --git a/R/plot.R b/R/plot.R index 8970d66b2..c7f2220ca 100644 --- a/R/plot.R +++ b/R/plot.R @@ -28,8 +28,7 @@ plot_relative_recovery <- function(relative_data, alpha = 0.8, #' Plot cases by observation window #' @export plot_cases_by_obs_window <- function(cases) { - cases |> - DT(case_type == "primary") |> + cases[case_type == "primary"] |> ggplot() + aes(x = time, y = cases) + geom_col(aes(fill = factor(obs_at)), alpha = 1, col = "#696767b1") + diff --git a/R/preprocess.R b/R/preprocess.R index 915af12ca..df6c38055 100644 --- a/R/preprocess.R +++ b/R/preprocess.R @@ -2,21 +2,22 @@ #' @export linelist_to_counts <- function(linelist, target_time = "ptime_daily", additional_by = c(), pad_zeros = FALSE) { - cases <- linelist |> - data.table::copy() |> - data.table::DT(, time := get(target_time)) |> - data.table::DT(, .(cases = .N), by = c("time", additional_by)) |> - data.table::DT(order(time)) - + cases <- data.table::copy(linelist) + cases[, time := get(target_time)] + cases <- cases[, .(cases = .N), by = c("time", additional_by)] + cases <- cases[order(time)] + if (pad_zeros) { - cases <- cases |> - merge( - data.table::data.table(time = 1:max(linelist[[target_time]])), - by = "time", - all = TRUE - ) |> - data.table::DT(is.na(cases), cases := 0) + cases <- merge( + cases, + data.table::data.table(time = 1:max(linelist[[target_time]])), + by = "time", + all = TRUE + ) + + cases[is.na(cases), cases := 0] } + return(cases[]) } @@ -43,9 +44,12 @@ linelist_to_cases <- function(linelist) { #' For the observation observed at variable reverse the factor ordering #' @export reverse_obs_at <- function(dt) { - dt |> - DT(, obs_at := factor(obs_at)) |> - DT(, obs_at := factor(obs_at, levels = rev(levels(obs_at)))) + dt_rev <- data.table::copy(dt) + + dt_rev[, obs_at := factor(obs_at)] + dt_rev[, obs_at := 
factor(obs_at, levels = rev(levels(obs_at)))] + + return(dt_rev) } #' Construct case counts by observation window based on secondary observations @@ -58,17 +62,14 @@ construct_cases_by_obs_window <- function(linelist, windows = c(25, 45), obs_type <- match.arg(obs_type) if (obs_type == "stime") { filter_fn <- function(dt, lw, uw) { - dt |> - filter_obs_by_obs_time(obs_time = uw) |> - data.table::DT(stime > lw) + filter_obs_by_obs_time(dt, obs_time = uw)[stime > lw] } }else { filter_fn <- function(dt, lw, uw) { - dt |> - filter_obs_by_ptime(obs_time = uw) |> - data.table::DT(ptime > lw) + filter_obs_by_ptime(dt, obs_time = uw)[ptime > lw] } } + cases <- purrr::map2( lower_window, upper_window, ~ filter_fn(linelist, .x, .y) ) |> @@ -106,29 +107,27 @@ combine_obs <- function(truncated_obs, obs) { #' Calculate the mean difference between continuous and discrete event time #' @export calculate_censor_delay <- function(truncated_obs, additional_by = c()) { - truncated_obs_psumm <- truncated_obs |> - copy() |> - DT(, ptime_delay := ptime - ptime_daily) |> - DT(, .( - mean = mean(ptime_delay), - lwr = ifelse(length(ptime_delay) > 1, t.test(ptime_delay)[[4]][1], 0), - upr = ifelse(length(ptime_delay) > 1, t.test(ptime_delay)[[4]][2], 1)), - by = c("ptime_daily", additional_by)) |> - DT(, lwr := ifelse(lwr < 0, 0, lwr)) |> - DT(, upr := ifelse(upr > 1, 1, upr)) |> - DT(, type := "ptime") - - truncated_obs_ssumm <- truncated_obs |> - copy() |> - DT(, stime_delay := stime - stime_daily) |> - DT(, .( - mean = mean(stime_delay), - lwr = ifelse(length(stime_delay) > 1, t.test(stime_delay)[[4]][1], 0), - upr = ifelse(length(stime_delay) > 1, t.test(stime_delay)[[4]][2], 1)), - by = c("stime_daily", additional_by)) |> - DT(, lwr := ifelse(lwr < 0, 0, lwr)) |> - DT(, upr := ifelse(upr > 1, 1, upr)) |> - DT(, type := "stime") + truncated_obs_psumm <- data.table::copy(truncated_obs) + truncated_obs_psumm[, ptime_delay := ptime - ptime_daily] + truncated_obs_psumm <- truncated_obs_psumm[, .( + mean = mean(ptime_delay), + lwr = ifelse(length(ptime_delay) > 1, t.test(ptime_delay)[[4]][1], 0), + upr = ifelse(length(ptime_delay) > 1, t.test(ptime_delay)[[4]][2], 1)), + by = c("ptime_daily", additional_by)] + truncated_obs_psumm[, lwr := ifelse(lwr < 0, 0, lwr)] + truncated_obs_psumm[, upr := ifelse(upr > 1, 1, upr)] + truncated_obs_psumm[, type := "ptime"] + + truncated_obs_ssumm <- data.table::copy(truncated_obs) + truncated_obs_ssumm[, stime_delay := stime - stime_daily] + truncated_obs_ssumm <- truncated_obs_ssumm[, .( + mean = mean(stime_delay), + lwr = ifelse(length(stime_delay) > 1, t.test(stime_delay)[[4]][1], 0), + upr = ifelse(length(stime_delay) > 1, t.test(stime_delay)[[4]][2], 1)), + by = c("stime_daily", additional_by)] + truncated_obs_ssumm[, lwr := ifelse(lwr < 0, 0, lwr)] + truncated_obs_ssumm[, upr := ifelse(upr > 1, 1, upr)] + truncated_obs_ssumm[, type := "stime"] names(truncated_obs_psumm)[1] <- names(truncated_obs_ssumm)[1] <- "cohort" @@ -140,8 +139,12 @@ calculate_censor_delay <- function(truncated_obs, additional_by = c()) { #' Convert from event based to incidence based data #' @export event_to_incidence <- function(data, by = c()) { - data |> - DT(, .(cases = .N), by = c("ptime_daily", by)) |> - DT(order(ptime_daily)) |> - setnames(old = c("ptime_daily"), new = c("time")) + dd <- data.table::copy(data) + + dd <- dd[, .(cases = .N), by = c("ptime_daily", by)] + dd <- dd[order(ptime_daily)] + + setnames(dd, old = c("ptime_daily"), new = c("time")) + + return(dd) } \ No newline at end of file diff --git 
a/R/simulate.R b/R/simulate.R index 0d5e6ca57..7cfafac70 100644 --- a/R/simulate.R +++ b/R/simulate.R @@ -130,11 +130,11 @@ simulate_gillespie <- function(r = 0.2, #' #' @export simulate_secondary <- function(linelist, dist = rlnorm, ...) { - obs <- linelist |> - data.table::copy() |> - DT(, delay := dist(.N, ...)) |> - # When the second event actually happens - DT(, stime := ptime + delay) + obs <- data.table::copy(linelist) + + obs[, delay := dist(.N, ...)] + obs[, stime := ptime + delay] + return(obs) } diff --git a/vignettes/epidist.Rmd b/vignettes/epidist.Rmd index c16c63e29..5cd8d7475 100644 --- a/vignettes/epidist.Rmd +++ b/vignettes/epidist.Rmd @@ -74,8 +74,8 @@ Observe the outbreak after 25 days and take 100 samples. ```{r observe-data} truncated_obs <- obs |> - filter_obs_by_obs_time(obs_time = 25) |> - DT(sample(1:.N, 200, replace = FALSE)) + filter_obs_by_obs_time(obs_time = 25) +truncated_obs <- truncated_obs[sample(1:.N, 200, replace = FALSE)] ``` Plot primary cases (columns), and secondary cases (dots) by the observation time of their secondary events. This reflects would could be observed in real-time. @@ -179,8 +179,9 @@ epinowcast_fit <- epinowcast_delay( data = truncated_obs, parallel_chains = 4, adapt_delta = 0.95, show_messages = FALSE, refresh = 0, with_epinowcast_output = FALSE ) -epinowcast_draws <- extract_epinowcast_draws(epinowcast_fit) |> - DT(, model := "Joint incidence and forward delay") +epinowcast_draws <- extract_epinowcast_draws(epinowcast_fit) +epinowcast_draws[, model := "Joint incidence and forward delay"] + ``` ### Summarise model posteriors and compare to known truth @@ -206,12 +207,12 @@ Extract and summarise lognormal posterior estimates. draws <- models |> map(extract_lognormal_draws) |> rbindlist(idcol = "model") |> - rbind(epinowcast_draws, use.names = TRUE) |> - DT(, + rbind(epinowcast_draws, use.names = TRUE) + +draws <- draws[, model := factor( model, levels = c("Joint incidence and forward delay", rev(names(models))) - ) - ) + )] summarised_draws <- draws |> draws_to_long() |> @@ -242,11 +243,11 @@ truncated_draws <- draws |> obs_at = max(truncated_obs$stime_daily), ptime = range(truncated_obs$ptime_daily) ) |> - summarise_variable(variable = "trunc_mean", by = c("obs_horizon", "model")) |> - DT(, model := factor( + summarise_variable(variable = "trunc_mean", by = c("obs_horizon", "model")) + +truncated_draws[, model := factor( model, levels = c("Joint incidence and forward delay", rev(names(models))) - ) - ) + )] truncated_draws |> plot_mean_posterior_pred(
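The refactor applied throughout this patch follows one pattern: take an explicit `data.table::copy()` of the input, replace the piped `DT()` calls with plain `data.table` bracket syntax, and reassign whenever a call returns a new table (subsetting, ordering, aggregation) rather than modifying a column by reference with `:=`. A minimal sketch of that pattern, using an illustrative `toy_linelist` table that is not part of the package:

```r
library(data.table)

# Illustrative input: primary and secondary event times (not package data)
toy_linelist <- data.table(ptime = c(0.2, 1.7, 3.4), stime = c(1.9, 4.1, 5.0))

# Before (pipe style, via the DT() alias):
# out <- toy_linelist |>
#   data.table::copy() |>
#   DT(, delay := stime - ptime) |>
#   DT(delay > 1)

# After (plain data.table syntax, as in this patch):
out <- data.table::copy(toy_linelist)  # copy so the caller's table is untouched
out[, delay := stime - ptime]          # `:=` adds the column by reference to the copy
out <- out[delay > 1]                  # subsetting returns a new table, so reassign
```

The easy mistake when translating from the pipe is dropping that final reassignment: `[` without `:=` never changes the table in place, so its result has to be captured.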