Skip to content

Commit

Permalink
Merge pull request #138 from danforthcenter/dom_edits
Browse files Browse the repository at this point in the history
several edits from working with katie and dom
  • Loading branch information
joshqsumner authored Nov 26, 2024
2 parents c1a851f + dfd3d6e commit 60c03da
Show file tree
Hide file tree
Showing 14 changed files with 220 additions and 139 deletions.
6 changes: 3 additions & 3 deletions CRAN-SUBMISSION
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
Version: 1.0.0
Date: 2024-09-03 19:54:37 UTC
SHA: 0fcf3324368c9c706f9ba0fe0337ebd21fd4753a
Version: 1.1.1.0
Date: 2024-11-06 20:17:03 UTC
SHA: 824b226b066d55d579c39081b7876a11b53601d4
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Package: pcvr
Type: Package
Title: Plant Phenotyping and Bayesian Statistics
Version: 1.1.1.0
Version: 1.1.1.1
Authors@R:
c(person("Josh", "Sumner", email = "jsumner@danforthcenter.org",
role = c("aut", "cre"), comment = c(ORCID = "0000-0002-3399-9063")),
Expand Down
98 changes: 69 additions & 29 deletions R/brmPlot.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
#' although prediction using splines outside of the observed range is not necessarily reliable.
#' @param facetGroups logical, should groups be separated in facets? Defaults to TRUE.
#' @param hierarchy_value If a hierarchical model is being plotted, what value should the
#' hierarchical predictor be? If left NULL (the default) the mean value is used.
#' hierarchical predictor be? If left NULL (the default) the mean value is used. If more than one
#' value is given then the x axis will show the hierarchical variable, with predictions made at the
#' mean of timeRange (the mean of the unique x values in the model if timeRange is not specified).
#' @param vir_option Viridis color scale to use for plotting credible intervals. Defaults to "plasma".
#' @keywords growth-curve brms
#' @import ggplot2
Expand Down Expand Up @@ -122,6 +124,42 @@ brmPlot <- function(fit, form, df = NULL, groups = NULL, timeRange = NULL, facet
ggplot2::labs(fill = "Credible\nInterval")
return(p)
}

#' Resolve the inputs needed to draw a longitudinal brms plot
#'
#' Fills in defaults for the time range and the hierarchical predictor value, makes sure the
#' grouping columns exist, and decides which variable is drawn on the x axis.
#'
#' @param fitData data.frame holding the \code{x} column and, when present, the \code{group} and
#'   \code{hierarchical_predictor} columns (presumably the model's data; confirm against caller).
#' @param timeRange Values of \code{x} to predict over, or NULL to use the unique values of
#'   \code{fitData[[x]]}.
#' @param x Name of the x variable.
#' @param hierarchy_value NULL, one value, or several values of the hierarchical predictor.
#'   Ignored when \code{hierarchical_predictor} is NULL.
#' @param group Character vector of grouping column names. If any of them is missing from
#'   \code{fitData} then all of them are set to "".
#' @param hierarchical_predictor Name of the hierarchical predictor, or NULL if there is none.
#' @return A named list with elements \code{timeRange}, \code{x_plot_var}, \code{x_plot_label},
#'   \code{allow_data_lines}, \code{fitData} and \code{hierarchy_value}.
#' @keywords internal
#' @noRd

.brmLongitudinalPlotSetup <- function(fitData, timeRange, x,
                                      hierarchy_value, group, hierarchical_predictor) {
  x_plot_var <- x
  x_plot_label <- x
  allow_data_lines <- TRUE
  # A hierarchy value has no meaning without a hierarchical predictor. Dropping it here keeps the
  # axis label well formed and stops an unnamed extra column reaching the prediction grid.
  if (is.null(hierarchical_predictor)) {
    hierarchy_value <- NULL
  }
  if (is.null(timeRange)) {
    timeRange <- unique(fitData[[x]])
  }
  if (!all(group %in% colnames(fitData))) {
    fitData[, group] <- ""
  }
  if (!is.null(hierarchical_predictor) && is.null(hierarchy_value)) {
    # na.rm matches the handling of timeRange below so a single NA does not make the default NA
    hierarchy_value <- mean(fitData[[hierarchical_predictor]], na.rm = TRUE)
  }
  if (length(hierarchy_value) > 1) {
    # Several hierarchy values: plot against the hierarchical predictor at one fixed x (the mean
    # of timeRange). Observed data lines are indexed by x so they are turned off.
    timeRange <- mean(timeRange, na.rm = TRUE)
    x_plot_var <- hierarchical_predictor
    x_plot_label <- paste0(hierarchical_predictor, " (", x, " = ", round(timeRange, 1), ")")
    allow_data_lines <- FALSE
  } else if (length(hierarchy_value) == 1) {
    # Single hierarchy value: keep x on the axis and report the fixed predictor value in the label
    x_plot_label <- paste0(x, " (", hierarchical_predictor, " = ", round(hierarchy_value, 1), ")")
  }
  return(
    list(
      "timeRange" = timeRange, "x_plot_var" = x_plot_var,
      "x_plot_label" = x_plot_label,
      "allow_data_lines" = allow_data_lines, "fitData" = fitData, "hierarchy_value" = hierarchy_value
    )
  )
}

#' @keywords internal
#' @noRd

Expand All @@ -137,36 +175,36 @@ brmPlot <- function(fit, form, df = NULL, groups = NULL, timeRange = NULL, facet
facetGroups <- .no_dummy_labels(group, facetGroups)
df <- parsed_form$data
probs <- seq(from = 99, to = 1, by = -2) / 100
if (is.null(timeRange)) {
timeRange <- unique(fitData[[x]])
}
if (!all(group %in% colnames(fitData))) {
fitData[, group] <- ""
}
newData <- do.call(
expand.grid,
blp_setup <- .brmLongitudinalPlotSetup(
fitData, timeRange, x,
hierarchy_value, group, hierarchical_predictor
)
timeRange <- blp_setup$timeRange
x_plot_var <- blp_setup$x_plot_var
x_plot_label <- blp_setup$x_plot_label
allow_data_lines <- blp_setup$allow_data_lines
fitData <- blp_setup$fitData
hierarchy_value <- blp_setup$hierarchy_value
newDataArgs <- append(
c(
lapply(group, function(grp) {
unique(fitData[[grp]])
}),
list(
"new_1"
)
),
append(
list(timeRange),
c(
lapply(group, function(grp) {
unique(fitData[[grp]])
}),
list(
"new_1"
)
)
list(hierarchy_value)
)
)
colnames(newData) <- c(x, group, individual)
newDataArgs <- newDataArgs[!unlist(lapply(newDataArgs, is.null))]
newData <- do.call(expand.grid, newDataArgs)
colnames(newData) <- c(group, individual, x, hierarchical_predictor)
if (length(group) > 1 && paste(group, collapse = ".") %in% colnames(fitData)) {
newData[[paste(group, collapse = ".")]] <- interaction(newData[, group])
}
if (!is.null(hierarchical_predictor)) {
if (is.null(hierarchy_value)) {
hierarchy_value <- mean(fitData[[hierarchical_predictor]])
}
newData[[hierarchical_predictor]] <- hierarchy_value
}
predictions <- cbind(newData, predict(fit, newData, probs = probs))

if (!is.null(groups)) {
Expand Down Expand Up @@ -199,7 +237,7 @@ brmPlot <- function(fit, form, df = NULL, groups = NULL, timeRange = NULL, facet
do.call(rbind, lapply(seq(1, 49, 2), function(i) {
min <- paste0("Q", i)
max <- paste0("Q", 100 - i)
iter <- sub[, c(x, group, individual, "Estimate")]
iter <- sub[, c(x, group, individual, "Estimate", hierarchical_predictor)]
iter$q <- round(1 - (c1 * (i - max_obs) + max_prime), 2)
iter$min <- sub[[min]]
iter$max <- sub[[max]]
Expand All @@ -208,9 +246,9 @@ brmPlot <- function(fit, form, df = NULL, groups = NULL, timeRange = NULL, facet
}))
longPreds$plot_group <- as.character(interaction(longPreds[, group]))
#* `Make plot`
p <- ggplot2::ggplot(longPreds, ggplot2::aes(x = .data[[x]], y = .data$Estimate)) +
p <- ggplot2::ggplot(longPreds, ggplot2::aes(x = .data[[x_plot_var]], y = .data$Estimate)) +
facetLayer +
ggplot2::labs(x = x, y = y) +
ggplot2::labs(x = x_plot_label, y = y) +
pcv_theme()
p <- p +
lapply(unique(longPreds$q), function(q) {
Expand All @@ -227,10 +265,12 @@ brmPlot <- function(fit, form, df = NULL, groups = NULL, timeRange = NULL, facet
viridis::scale_fill_viridis(direction = -1, option = vir_option) +
ggplot2::labs(fill = "Credible\nInterval")

if (!is.null(df) && individual != "dummyIndividual") {
if (!is.null(df) && individual != "dummyIndividual" && allow_data_lines) {
df$plot_group <- as.character(interaction(df[, group]))
p <- p + ggplot2::geom_line(
data = df, ggplot2::aes(.data[[x]], .data[[y]],
data = df,
ggplot2::aes(
.data[[x_plot_var]], .data[[y]],
group = interaction(.data[[individual]], .data[["plot_group"]])
),
color = "gray20", linewidth = 0.2
Expand Down
8 changes: 4 additions & 4 deletions R/brmSS.R
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@

#* `Make parameter grouping formulae`

if (!is.null(pars)) {
if (as.logical(length(pars))) {
if (USEGROUP) {
parForm <- as.formula(paste0(paste(pars, collapse = "+"), "~0+", paste(group, collapse = "*")))
} else {
Expand All @@ -291,7 +291,7 @@
}

#* `Combine formulas into brms.formula object`
if (is.null(parForm)) {
if (is.null(pars)) {
NL <- FALSE
} else {
NL <- TRUE
Expand All @@ -308,14 +308,14 @@
#* ***** `Make priors` *****
out[["prior"]] <- .makePriors(priors, pars, df, group, USEGROUP, sigma, family, bayesForm)
#* ***** `Make initializer function` *****
if (!is.null(pars)) {
if (as.logical(length(pars))) {
initFun <- function(pars = "?", nPerChain = 1) {
init <- lapply(pars, function(i) array(rgamma(nPerChain, 1)))
names(init) <- paste0("b_", pars)
init
}
formals(initFun)$pars <- pars
formals(initFun)$nPerChain <- length(unique(interaction(df[, group])))
formals(initFun)$nPerChain <- length(table(df[, group]))
wrapper <- function() {
initFun()
}
Expand Down
36 changes: 30 additions & 6 deletions R/brmSSHelpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#' @noRd

.makePriors <- function(priors, pars, df, group, USEGROUP, sigma, family, formula) {
if (is.null(priors)) {
if (is.null(priors) || !as.logical(length(pars))) {
prior <- .explicitDefaultPrior(formula, df, family)
return(prior)
}
Expand Down Expand Up @@ -194,7 +194,8 @@
if (!grepl("changePoint|I$", par)) {
paste0("lognormal(log(", priors[[par]], "), 0.25)") # growth parameters are LN
} else {
paste0("student_t(5,", priors[[par]], ", 3)") # changepoints/intercepts are T_5(mu, 3)
# changepoints/intercepts are T_5(mu, mu / 5) by default
paste0("student_t(5,", priors[[par]], ", ", abs(priors[[par]] / 5), ")")
}
})
priorStanStrings <- unlist(priorStanStrings)
Expand Down Expand Up @@ -629,18 +630,41 @@
if (useGroup) {
by <- paste0(", by = ", paste(group, collapse = ".")) # special variable that is made if there are
# multiple groups and a gam involved.
group <- paste0("0 + ", group)
} else {
by <- NULL
group <- "1"
}
if (nTimes < 11) {
k <- paste0(", k = ", nTimes)
} else {
k <- NULL
}

form <- stats::as.formula(paste0(y, " ~ s(", x, by, k, ")"))
pars <- NULL

if (dpar) {
if (int) {
form <- list(
brms::nlf(stats::as.formula(paste0(y, " ~ ", y, "I + ", y, "spline"))),
stats::as.formula(paste0(y, "I ~ ", group)),
stats::as.formula(paste0(y, "spline ~ s(", x, by, k, ")"))
)
pars <- paste0(y, c("I", "spline"))
} else {
form <- stats::as.formula(paste0(y, " ~ s(", x, by, k, ")"))
pars <- NULL
}
} else {
if (int) {
form <- list(
brms::nlf(stats::as.formula(paste0(y, " ~ I + spline"))),
stats::as.formula(paste0("I ~ ", group)),
stats::as.formula(paste0("spline ~ s(", x, by, k, ")"))
)
pars <- c("I", "spline")
} else {
form <- stats::as.formula(paste0(y, " ~ s(", x, by, k, ")"))
pars <- NULL
}
}
return(list(form = form, pars = pars))
}
#' Helper function for brms formulas
Expand Down
4 changes: 3 additions & 1 deletion R/growthPlot.R
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@
#' Note that for brms models this is ignored except if used to specify a different viridis color map
#' to use.
#' @param hierarchy_value If a hierarchical model is being plotted, what value should the
#' hierarchical predictor be? If left NULL (the default) the mean value is used.
#' hierarchical predictor be? If left NULL (the default) the mean value is used. If more than one
#' value is given then the x axis will show the hierarchical variable, with predictions made at the
#' mean of timeRange (the mean of the unique x values in the model if timeRange is not specified).
#' @keywords growth-curve
#' @importFrom methods is
#' @seealso \link{growthSS} and \link{fitGrowth} for making compatible models, \link{testGrowth}
Expand Down
36 changes: 27 additions & 9 deletions R/growthSim.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
#' specifying "model1 + model2", see examples and \code{\link{growthSS}}.
#' Decay can be specified by including "decay" as part of the model such as "logistic decay" or
#' "linear + linear decay". Count data can be specified with the "count: " prefix,
#' similar to using "poisson: model" in \link{growthSS}.
#' similar to using "poisson: model" in \link{growthSS}. Similarly intercepts can be added with the
#' "int_" prefix, in which case an "I" parameter should be specified.
#' While "gam" models are supported by \code{growthSS}
#' they are not simulated by this function.
#' @param n Number of individuals to simulate over time per each group in params
Expand Down Expand Up @@ -115,6 +116,14 @@
#' geom_line(aes(color = group)) +
#' labs(title = "Linear")
#'
#' simdf <- growthSim("int_linear",
#' n = 20, t = 25,
#' params = list("A" = c(1.1, 0.95), I = c(100, 120))
#' )
#' ggplot(simdf, aes(time, y, group = interaction(group, id))) +
#' geom_line(aes(color = group)) +
#' labs(title = "Linear with Intercept")
#'
#' simdf <- growthSim("logarithmic",
#' n = 20, t = 25,
#' params = list("A" = c(2, 1.7))
Expand Down Expand Up @@ -243,6 +252,11 @@ growthSim <- function(
} else {
COUNT <- FALSE
}
int <- FALSE
if (grepl("^int", model)) {
int <- TRUE
model <- trimws(sub("^int_?", "", model))
}
if (is.null(names(params))) {
names(params) <- c(LETTERS[seq_along(params)])
}
Expand Down Expand Up @@ -275,9 +289,9 @@ growthSim <- function(
}
#* decide which internal function to use
if (!grepl("\\+", model)) {
out <- .singleGrowthSim(model, n, t, params, noise, D)
out <- .singleGrowthSim(model, n, t, params, noise, D, int)
} else {
out <- .multiGrowthSim(model, n, t, params, noise, D)
out <- .multiGrowthSim(model, n, t, params, noise, D, int)
}
if (COUNT) {
out <- do.call(rbind, lapply(split(out, interaction(out$group, out$id)), function(sub) {
Expand All @@ -293,7 +307,7 @@ growthSim <- function(
#' @keywords internal
#' @noRd

.multiGrowthSim <- function(model, n = 20, t = 25, params = list(), noise = NULL, D = 0) {
.multiGrowthSim <- function(model, n = 20, t = 25, params = list(), noise = NULL, D = 0, int) {
component_models <- trimws(strsplit(model, "\\+")[[1]])

firstModel <- component_models[1]
Expand All @@ -308,7 +322,7 @@ growthSim <- function(
stop("Simulating segmented data requires 'changePointX' parameters as described in growthSS.")
}

df1 <- do.call(rbind, lapply(1:n, function(i) {
df1 <- do.call(rbind, lapply(seq_len(n), function(i) {
firstChangepointsRand <- lapply(firstChangepoints, function(fc) {
round(rnorm(1, fc, firstNoise$change))
})
Expand All @@ -320,7 +334,7 @@ growthSim <- function(
lapply(firstParams, function(l) l[[g]]),
c(sub(paste0(firstModel, "1"), "", names(firstParams)))
),
noise = firstNoise, D
noise = firstNoise, D, int
)
}))
n_df$group <- rep(letters[seq_along(firstChangepointsRand)],
Expand Down Expand Up @@ -363,7 +377,7 @@ growthSim <- function(
lapply(iterParams, function(l) l[[g]]),
c(sub(paste0(iterModelFindParams, u), "", names(iterParams)))
),
noise = iterNoise, D
noise = iterNoise, D, int = FALSE
)
inner_df$group <- letters[g]
inner_df
Expand Down Expand Up @@ -397,7 +411,7 @@ growthSim <- function(
#' @keywords internal
#' @noRd

.singleGrowthSim <- function(model, n = 20, t = 25, params = list(), noise = NULL, D) {
.singleGrowthSim <- function(model, n = 20, t = 25, params = list(), noise = NULL, D, int) {
models <- c(
"logistic", "gompertz", "double logistic", "double gompertz",
"monomolecular", "exponential", "linear", "power law", "frechet", "weibull", "gumbel",
Expand Down Expand Up @@ -427,10 +441,14 @@ growthSim <- function(
out <- do.call(rbind, lapply(seq_along(params[[1]]), function(i) {
pars <- lapply(params, function(p) p[i])
as.data.frame(rbind(do.call(rbind, lapply(1:n, function(e) {
data.frame(
iter_data <- data.frame(
"id" = paste0("id_", e), "group" = letters[i], "time" = 1:t,
"y" = gsid(D = D, 1:t, pars, noise), stringsAsFactors = FALSE
)
if (int) {
iter_data$y <- iter_data$y + rnorm(1, mean = pars[["I"]], sd = noise[["I"]])
}
iter_data
}))))
}))

Expand Down
Loading

0 comments on commit 60c03da

Please sign in to comment.