Issue 3: Not using DT. #4

Merged: 1 commit, Apr 4, 2024
56 changes: 27 additions & 29 deletions R/fitting-and-postprocessing.R
@@ -3,8 +3,7 @@
sample_model <- function(model, data, scenario = data.table::data.table(id = 1),
diagnostics = TRUE, ...) {

- out <- scenario |>
-   copy()
+ out <- data.table::copy(scenario)

# Setup failure tolerant model fitting
fit_model <- function(model, data, ...) {
@@ -16,12 +15,10 @@
fit <- safe_fit_model(model, data, ...)

if (!is.null(fit$error)) {
- out <- out |>
-   DT(, error := list(fit$error[[1]]))
+ out[, error := list(fit$error[[1]])]

Check warning on line 18, col 11 (GitHub Actions / lint-changed-files): [object_usage_linter] no visible binding for global variable 'error'
Contributor:
An annoying side effect of linting. We need to add error as a global var in utils to get rid of this flag.
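For reference, a minimal sketch of that fix, assuming the declarations go in something like R/utils.R (the file name and the exact variable list here are illustrative, not part of this PR):

  # Declare data.table NSE column names so object_usage_linter and
  # R CMD check stop reporting "no visible binding for global variable".
  utils::globalVariables(c(
    "error", "fit", "no_at_max_treedepth", "max_treedepth",
    "per_at_max_treedepth", "run_time"
  ))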

Contributor:
Or we can update the linter to ignore global var checks (but R CMD check will still flag this anyway).
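A sketch of that linter route, assuming lintr >= 3.0 and a top-level .lintr file (neither is shown in this PR):

  # .lintr
  linters: linters_with_defaults(
      object_usage_linter = NULL
    )

As noted, this only silences lintr; the R CMD check NOTE would still need the globalVariables() route.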

Contributor:
Alternatively, the env features we use elsewhere may help here?
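A hedged sketch of that env-based route (this assumes a data.table version with the env argument; character scalars in env are substituted as column names, so the linter only ever sees the local variable err_col):

  err_col <- "error"  # an ordinary local binding, so nothing is flagged
  out[, err_col := list(fit$error[[1]]), env = list(err_col = err_col)]

This mirrors the env = list(variable = variable) call that already appears later in this diff.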

diagnostics <- FALSE
}else {
- out <- out |>
-   DT(, fit := list(fit$result))
+ out[, fit := list(fit$result)]
fit <- fit$result
}

@@ -40,12 +37,12 @@
per_divergent_transitions = sum(diag$divergent__) / nrow(diag),
max_treedepth = max(diag$treedepth__)
)
diagnostics[, no_at_max_treedepth := sum(diag$treedepth__ == max_treedepth)]

Check warning on line 40, col 19 (GitHub Actions / lint-changed-files): [object_usage_linter] no visible binding for global variable 'no_at_max_treedepth'
Check warning on line 40, col 66 (GitHub Actions / lint-changed-files): [object_usage_linter] no visible binding for global variable 'max_treedepth'
diagnostics[, per_at_max_treedepth := no_at_max_treedepth / nrow(diag)]

Check warning on line 41, col 19 (GitHub Actions / lint-changed-files): [object_usage_linter] no visible binding for global variable 'per_at_max_treedepth'
Check warning on line 41, col 43 (GitHub Actions / lint-changed-files): [object_usage_linter] no visible binding for global variable 'no_at_max_treedepth'
out <- cbind(out, diagnostics)

timing <- round(fit$time()$total, 1)
out[, run_time := timing]

Check warning on line 45, col 11 (GitHub Actions / lint-changed-files): [object_usage_linter] no visible binding for global variable 'run_time'
}
return(out[])
}
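The definition of safe_fit_model is collapsed in this diff; for orientation only, a failure-tolerant wrapper like the one used above is commonly built with purrr::safely (a sketch, not necessarily the package's actual definition):

  # safely() wraps a function so errors are captured rather than thrown:
  # the wrapped call returns list(result = ..., error = NULL) on success
  # and list(result = NULL, error = <condition>) on failure.
  safe_fit_model <- purrr::safely(fit_model)
  fit <- safe_fit_model(model, data, ...)
  if (!is.null(fit$error)) message(conditionMessage(fit$error))

This is what makes the fit$error / fit$result branching above work.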
@@ -57,8 +54,7 @@
diagnostics = TRUE, ...
) {

- out <- scenario |>
-   copy()
+ out <- data.table::copy(scenario)

# Setup failure tolerant model fitting
fit_model <- function(model, data, ...) {
@@ -72,12 +68,10 @@
fit <- safe_fit_model(model, data, ...)

if (!is.null(fit$error)) {
- out <- out |>
-   DT(, error := list(fit$error[[1]]))
+ out[, error := list(fit$error[[1]])]

Check warning on line 71, col 11 (GitHub Actions / lint-changed-files): [object_usage_linter] no visible binding for global variable 'error'
diagnostics <- FALSE
}else {
- out <- out |>
-   DT(, fit := list(fit$result))
+ out[, fit := list(fit$result)]
fit <- fit$result
}

@@ -110,8 +104,8 @@
per_divergent_transitions = sum(diag$divergent__) / nrow(diag),
max_treedepth = max(diag$treedepth__)
)
diagnostics[, no_at_max_treedepth := sum(diag$treedepth__ == max_treedepth)]

Check warning on line 107, col 19 (GitHub Actions / lint-changed-files): [object_usage_linter] no visible binding for global variable 'no_at_max_treedepth'
Check warning on line 107, col 66 (GitHub Actions / lint-changed-files): [object_usage_linter] no visible binding for global variable 'max_treedepth'
diagnostics[, per_at_max_treedepth := no_at_max_treedepth / nrow(diag)]

Check warning on line 108, col 19 (GitHub Actions / lint-changed-files): [object_usage_linter] no visible binding for global variable 'per_at_max_treedepth'
out <- cbind(out, diagnostics)

timing <- round(max(fit$metadata()$time$total), 1)
@@ -123,11 +117,12 @@
#' Add natural scale summary parameters for a lognormal distribution
#' @export
add_natural_scale_mean_sd <- function(dt) {
- nat_dt <- dt |>
-   data.table::DT(, mean := exp(meanlog + sdlog ^ 2 / 2)) |>
-   data.table::DT(,
-     sd := exp(meanlog + (1 / 2) * sdlog ^ 2) * sqrt(exp(sdlog ^ 2) - 1)
-   )
+ nat_dt <- data.table::copy(dt)
Contributor:
In the original (i.e. above) we aren't making a copy, but I think it's fine either way (though 'add' might imply no copy?).
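For context on that trade-off: data.table's := modifies in place, so without the copy() the caller's table silently gains the new columns. A tiny illustration with made-up data:

  library(data.table)
  dt <- data.table(meanlog = 1, sdlog = 0.5)
  f_no_copy <- function(d) d[, mean := exp(meanlog + sdlog^2 / 2)]
  f_no_copy(dt)
  names(dt)  # "meanlog" "sdlog" "mean": the caller's object was modified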


+ nat_dt <- nat_dt[,mean := exp(meanlog + sdlog ^ 2 / 2)]

+ nat_dt <- nat_dt[,sd := exp(meanlog + (1 / 2) * sdlog ^ 2) * sqrt(exp(sdlog ^ 2) - 1)]

return(nat_dt[])
}
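As a quick numerical check of the natural-scale formulas above (illustrative only, not part of the PR):

  set.seed(1)
  meanlog <- 1.5
  sdlog <- 0.5
  x <- stats::rlnorm(1e6, meanlog, sdlog)
  exp(meanlog + sdlog^2 / 2)                           # analytic mean
  mean(x)                                              # simulated mean, agrees closely
  exp(meanlog + sdlog^2 / 2) * sqrt(exp(sdlog^2) - 1)  # analytic sd
  sd(x)                                                # simulated sd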

@@ -186,8 +181,7 @@
)
}

- draws <- draws |>
-   data.table::setDT()
+ draws <- data.table::setDT(draws)

data.table::setnames(
draws, c("refp_mean_int[1]", "refp_sd_int[1]"), c("meanlog", "sdlog"),
@@ -207,10 +201,11 @@
#' Primary event bias correction
#' @export
primary_censoring_bias_correction <- function(draws) {
- draws <- data.table::copy(draws) |>
-   DT(, mean := mean - runif(.N, min = 0, max = 1)) |>
-   DT(, meanlog := log(mean^2 / sqrt(sd^2 + mean^2))) |>
-   DT(, sdlog := sqrt(log(1 + (sd^2 / mean^2))))
+ draws <- data.table::copy(draws)
+ draws[, mean := mean - runif(.N, min = 0, max = 1)]
+ draws[, meanlog := log(mean^2 / sqrt(sd^2 + mean^2))]
+ draw[, sdlog := sqrt(log(1 + (sd^2 / mean^2)))]
Contributor:
Suggested change:
- draw[, sdlog := sqrt(log(1 + (sd^2 / mean^2)))]
+ draws[, sdlog := sqrt(log(1 + (sd^2 / mean^2)))]


return(draws[])
}
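The meanlog/sdlog lines above are the standard lognormal moment inversion; a round-trip check (illustrative only):

  meanlog <- 1.2
  sdlog <- 0.4
  m <- exp(meanlog + sdlog^2 / 2)      # natural-scale mean
  s <- m * sqrt(exp(sdlog^2) - 1)      # natural-scale sd
  log(m^2 / sqrt(s^2 + m^2))           # 1.2: recovers meanlog
  sqrt(log(1 + s^2 / m^2))             # 0.4: recovers sdlog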

@@ -234,7 +229,8 @@
by = by
)

- draws <- draws[, rel_value := value / true_value]
+ draws[, rel_value := value / true_value]

return(draws[])
}

@@ -289,9 +285,11 @@
if (missing(variable)) {
stop("variable must be specified")
}
- summarised_draws <- draws |>
-   copy() |>
-   DT(, value := variable, env = list(variable = variable)) |>
-   summarise_draws(sf = sf, by = by)
+ summarised_draws <- data.table::copy(draws)

+ summarised_draws[, value := variable, env = list(variable = variable)]

+ summarised_draws <- summarise_draws(summarised_draws, sf = sf, by = by)

return(summarised_draws[])
}
}
69 changes: 30 additions & 39 deletions R/models.R
@@ -14,11 +14,10 @@ naive_delay <- function(formula = brms::bf(delay_daily ~ 1, sigma ~ 1), data,
filtered_naive_delay <- function(
formula = brms::bf(delay_daily ~ 1, sigma ~ 1), data, fn = brms::brm,
family = "lognormal", truncation = 10, ...) {
- data <- data |>
-   data.table::as.data.table() |>
-   ## NEED TO FILTER BASED ON PTIME
-   DT(ptime_daily <= (obs_at - truncation))
+ data <- data.table::as.data.table(data)
+ ## NEED TO FILTER BASED ON PTIME
+ data <- data[ptime_daily <= (obs_at - truncation)]

data <- drop_zero(data)

fn(
@@ -85,12 +84,11 @@ latent_censoring_adjusted_delay <- function(

stanvars_all <- stanvars_functions + stanvars_parameters + stanvars_prior

- data <- data |>
-   data.table::as.data.table() |>
-   DT(, id := 1:.N) |>
-   DT(, pwindow_upr := ptime_upr - ptime_lwr) |>
-   DT(, swindow_upr := stime_upr - stime_lwr) |>
-   DT(, delay_central := stime_lwr - ptime_lwr)
+ data <- data.table::as.data.table(data)
+ data[, id := 1:.N]
+ data[, pwindow_upr := ptime_upr - ptime_lwr]
+ data[, swindow_upr := stime_upr - stime_lwr]
+ data[, delay_central := stime_lwr - ptime_lwr]

if (nrow(data) > 1) {
data <- data[, id := as.factor(id)]
@@ -111,9 +109,8 @@ filtered_censoring_adjusted_delay <- function(
delay_lwr | cens(censored, delay_upr) ~ 1, sigma ~ 1
), data, fn = brms::brm, family = "lognormal", truncation = 10, ...) {

- data <- data |>
-   data.table::as.data.table() |>
-   DT(ptime_daily <= (obs_at - truncation))
+ data <- data.table::as.data.table(data)
+ data <- data[ptime_daily <= (obs_at - truncation)]

data <- pad_zero(data)

@@ -201,22 +198,18 @@ latent_truncation_censoring_adjusted_delay <- function(
...
) {

- data <- data |>
-   data.table::as.data.table() |>
-   DT(, id := 1:.N) |>
-   DT(, obs_t := obs_at - ptime_lwr) |>
-   DT(, pwindow_upr := ifelse(
-     stime_lwr < ptime_upr, ## if overlap
-     stime_upr - ptime_lwr,
-     ptime_upr - ptime_lwr
-   )
-   ) |>
-   DT(,
-     woverlap := as.numeric(stime_lwr < ptime_upr)
-   ) |>
-   DT(, swindow_upr := stime_upr - stime_lwr) |>
-   DT(, delay_central := stime_lwr - ptime_lwr) |>
-   DT(, row_id := 1:.N)
+ data <- data.table::as.data.table(data)
+ data[, id := 1:.N]
+ data[, obs_t := obs_at - ptime_lwr]
+ data[, pwindow_upr := ifelse(
+   stime_lwr < ptime_upr, ## if overlap
+   stime_upr - ptime_lwr,
+   ptime_upr - ptime_lwr
+ )]
+ data[, woverlap := as.numeric(stime_lwr < ptime_upr)]
+ data[, swindow_upr := stime_upr - stime_lwr]
+ data[, delay_central := stime_lwr - ptime_lwr]
+ data[, row_id := 1:.N]

if (nrow(data) > 1) {
data <- data[, id := as.factor(id)]
@@ -323,9 +316,8 @@ dynamical_censoring_adjusted_delay <- function(
)
}
cols <- colnames(data)[map_lgl(data, is.integer)]
- data <- data |>
-   data.table::as.data.table() |>
-   DT(, (cols) := lapply(.SD, as.double), .SDcols = cols)
+ data <- data.table::as.data.table(data)
+ data[, (cols) := lapply(.SD, as.double), .SDcols = cols]

data <- drop_zero(data)
## need to do this because lognormal doesn't like zero
@@ -433,12 +425,11 @@ epinowcast_delay <- function(formula = ~ 1, data, by = c(),
"epinowcast is not installed. Please install it to use this function"
)
}
- data_as_counts <- data |>
-   data.table::as.data.table() |>
-   DT(, .(new_confirm = .N), by = c("ptime_daily", "stime_daily", by)) |>
-   DT(order(ptime_daily, stime_daily)) |>
-   DT(, reference_date := as.Date("2000-01-01") + ptime_daily) |>
-   DT(, report_date := as.Date("2000-01-01") + stime_daily)
+ data_as_counts <- data.table::as.data.table(data)
+ data_as_counts <- data_as_counts[, .(new_confirm = .N), by = c("ptime_daily", "stime_daily", by)]
+ data_as_counts <- data_as_counts[order(ptime_daily, stime_daily)]
+ data_as_counts[, reference_date := as.Date("2000-01-01") + ptime_daily]
+ data_as_counts[, report_date := as.Date("2000-01-01") + stime_daily]

# Actual largest observerable delay
preprocess_delay <- min(
101 changes: 47 additions & 54 deletions R/observe.R
@@ -1,42 +1,37 @@
#' Observation process for primary and secondary events
#' @export
observe_process <- function(linelist) {
- clinelist <- linelist |>
-   data.table::copy() |>
-   DT(, ptime_daily := floor(ptime)) |>
-   DT(, ptime_lwr := ptime_daily) |>
-   DT(, ptime_upr := ptime_daily + 1) |>
-   # How the second event would be recorded in the data
-   DT(, stime_daily := floor(stime)) |>
-   DT(, stime_lwr := stime_daily) |>
-   DT(, stime_upr := stime_daily + 1) |>
-   # How would we observe the delay distribution
-   # previously: delay_daily=floor(delay)
-   DT(, delay_daily := stime_daily - ptime_daily) |>
-   DT(, delay_lwr := purrr::map_dbl(delay_daily, ~ max(0, . - 1))) |>
-   DT(, delay_upr := delay_daily + 1) |>
-   # We assume observation time is the ceiling of the maximum delay
-   DT(, obs_at := stime |>
-     max() |>
-     ceiling()
-   )
+ clinelist <- data.table::copy(linelist)
+ clinelist[, ptime_daily := floor(ptime)]
+ clinelist[, ptime_lwr := ptime_daily]
+ clinelist[, ptime_upr := ptime_daily + 1]
+ # How the second event would be recorded in the data
+ clinelist[, stime_daily := floor(stime)]
+ clinelist[, stime_lwr := stime_daily]
+ clinelist[, stime_upr := stime_daily + 1]
+ # How would we observe the delay distribution
+ # previously: delay_daily=floor(delay)
+ clinelist[, delay_daily := stime_daily - ptime_daily]
+ clinelist[, delay_lwr := purrr::map_dbl(delay_daily, ~ max(0, . - 1))]
+ clinelist[, delay_upr := delay_daily + 1]
+ # We assume observation time is the ceiling of the maximum delay
+ clinelist[, obs_at := stime |>
+   max() |>
+   ceiling()]

return(clinelist)
}
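To make the censoring bookkeeping above concrete, one made-up record works through as:

  # ptime = 3.7, stime = 6.2
  # ptime_daily = floor(3.7) = 3, so the primary event is censored to [3, 4)
  # stime_daily = floor(6.2) = 6, so the secondary event is censored to [6, 7)
  # delay_daily = 6 - 3 = 3, with delay_lwr = max(0, 3 - 1) = 2 and delay_upr = 4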

#' Filter observations based on a observation time of secondary events
#' @export
filter_obs_by_obs_time <- function(linelist, obs_time) {
- truncated_linelist <- linelist |>
-   data.table::copy() |>
-   # Update observation time by when we are looking
-   DT(, obs_at := obs_time) |>
-   DT(, obs_time := obs_time - ptime) |>
-   # Assuming truncation at the beginning of the censoring window
-   DT(,
-     censored_obs_time := obs_at - ptime_lwr
-   ) |>
-   DT(, censored := "interval") |>
-   DT(stime_upr <= obs_at)
+ truncated_linelist <- data.table::copy(linelist)
+ truncated_linelist[, obs_at := obs_time]
+ truncated_linelist[, obs_time := obs_time - ptime]
+ truncated_linelist[, censored_obs_time := obs_at - ptime_lwr]
+ truncated_linelist[, censored := "interval"]
+ truncated_linelist <- truncated_linelist[stime_upr <= obs_at]

return(truncated_linelist)
}

@@ -47,49 +42,47 @@ filter_obs_by_ptime <- function(linelist, obs_time,
obs_at <- match.arg(obs_at)

pfilt_t <- obs_time
- truncated_linelist <- linelist |>
-   data.table::copy() |>
-   DT(, censored := "interval") |>
-   DT(ptime_upr <= pfilt_t)
+ truncated_linelist <- data.table::copy(linelist)

+ truncated_linelist[, censored := "interval"]
+ truncated_linelist <- truncated_linelist[ptime_upr <= pfilt_t]

if (obs_at == "obs_secondary") {
- truncated_linelist <- truncated_linelist |>
-   # Update observation time to be the same as the maximum secondary time
-   DT(, obs_at := stime_upr)
+ # Update observation time to be the same as the maximum secondary time
+ truncated_linelist[, obs_at := stime_upr]
} else if (obs_at == "max_secondary") {
- truncated_linelist <- truncated_linelist |>
-   DT(, obs_at := stime_upr |> max() |> ceiling())
+ truncated_linelist[, obs_at := stime_upr |> max() |> ceiling()]
}

# make observation time as specified
- truncated_linelist <- truncated_linelist |>
-   DT(, obs_time := obs_at - ptime) |>
-   # Assuming truncation at the beginning of the censoring window
-   DT(, censored_obs_time := obs_at - ptime_lwr)
+ truncated_linelist[, obs_time := obs_at - ptime]
+ # Assuming truncation at the beginning of the censoring window
+ truncated_linelist[, censored_obs_time := obs_at - ptime_lwr]

# set observation time to artifial observation time
if (obs_at == "obs_secondary") {
- truncated_linelist <- truncated_linelist |>
-   DT(, obs_at := pfilt_t)
+ truncated_linelist[, obs_at := pfilt_t]
}
return(truncated_linelist)
}

#' Pad zero observations as unstable in a lognormal distribution
#' @export
pad_zero <- function(data, pad = 1e-3) {
- data <- data |>
-   data.table::copy() |>
-   # Need upper bound to be greater than lower bound
-   DT(censored_obs_time == 0, censored_obs_time := 2 * pad) |>
-   DT(delay_lwr == 0, delay_lwr := pad) |>
-   DT(delay_daily == 0, delay_daily := pad)
+ data <- data.table::copy(data)
+ # Need upper bound to be greater than lower bound
+ data[censored_obs_time == 0, censored_obs_time := 2 * pad]
+ data[delay_lwr == 0, delay_lwr := pad]
+ data[delay_daily == 0, delay_daily := pad]

return(data)
}

#' Drop zero observations as unstable in a lognormal distribution
#' @export
drop_zero <- function(data) {
- data <- data |>
-   data.table::copy() |>
-   DT(delay_daily != 0)
+ data <- data.table::copy(data)
+ data[delay_daily != 0]

return(data)
}
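Why both helpers exist: the lognormal likelihood degenerates at zero, e.g.:

  stats::dlnorm(0, meanlog = 0, sdlog = 1)              # 0
  stats::dlnorm(0, meanlog = 0, sdlog = 1, log = TRUE)  # -Inf, so zero delays break the fit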