From e274b1f83b0bed65915716dbe4e9a1769f65d2b0 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Mon, 3 Feb 2025 00:21:39 -0600 Subject: [PATCH 01/21] Refactoring out the "pre-initialized" parameter in the custom sampler --- R/bart.R | 10 +-- R/bcf.R | 12 ++-- R/config.R | 104 ++++++++++++++++++++++++++++ R/forest.R | 31 ++++++++- R/model.R | 7 +- man/Forest.Rd | 26 +++++++ man/ForestModel.Rd | 3 +- man/resetForestModel.Rd | 3 +- vignettes/CustomSamplingRoutine.Rmd | 16 +++++ 9 files changed, 197 insertions(+), 15 deletions(-) create mode 100644 R/config.R diff --git a/R/bart.R b/R/bart.R index 1c814dd3..9ebe44cf 100644 --- a/R/bart.R +++ b/R/bart.R @@ -601,11 +601,13 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train if (requires_basis) init_values_mean_forest <- rep(0., ncol(leaf_basis_train)) else init_values_mean_forest <- 0. active_forest_mean$prepare_for_sampler(forest_dataset_train, outcome_train, forest_model_mean, leaf_model_mean_forest, init_values_mean_forest) + active_forest_mean$adjust_residual(forest_dataset_train, outcome_train, forest_model_mean, requires_basis, F) } # Initialize the leaves of each tree in the variance forest if (include_variance_forest) { active_forest_variance$prepare_for_sampler(forest_dataset_train, outcome_train, forest_model_variance, leaf_model_variance_forest, variance_forest_init) + } # Run GFR (warm start) if specified @@ -626,14 +628,14 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train forest_model_mean$sample_one_iteration( forest_dataset_train, outcome_train, forest_samples_mean, active_forest_mean, rng, feature_types, leaf_model_mean_forest, current_leaf_scale, variable_weights_mean, - a_forest, b_forest, current_sigma2, cutpoint_grid_size, keep_forest = keep_sample, gfr = T, pre_initialized = T + a_forest, b_forest, current_sigma2, cutpoint_grid_size, keep_forest = keep_sample, gfr = T ) } if (include_variance_forest) { forest_model_variance$sample_one_iteration( forest_dataset_train, outcome_train, forest_samples_variance, active_forest_variance, rng, feature_types, leaf_model_variance_forest, current_leaf_scale, variable_weights_variance, - a_forest, b_forest, current_sigma2, cutpoint_grid_size, keep_forest = keep_sample, gfr = T, pre_initialized = T + a_forest, b_forest, current_sigma2, cutpoint_grid_size, keep_forest = keep_sample, gfr = T ) } if (sample_sigma_global) { @@ -748,14 +750,14 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train forest_model_mean$sample_one_iteration( forest_dataset_train, outcome_train, forest_samples_mean, active_forest_mean, rng, feature_types, leaf_model_mean_forest, current_leaf_scale, variable_weights_mean, - a_forest, b_forest, current_sigma2, cutpoint_grid_size, keep_forest = keep_sample, gfr = F, pre_initialized = T + a_forest, b_forest, current_sigma2, cutpoint_grid_size, keep_forest = keep_sample, gfr = F ) } if (include_variance_forest) { forest_model_variance$sample_one_iteration( forest_dataset_train, outcome_train, forest_samples_variance, active_forest_variance, rng, feature_types, leaf_model_variance_forest, current_leaf_scale, variable_weights_variance, - a_forest, b_forest, current_sigma2, cutpoint_grid_size, keep_forest = keep_sample, gfr = F, pre_initialized = T + a_forest, b_forest, current_sigma2, cutpoint_grid_size, keep_forest = keep_sample, gfr = F ) } if (sample_sigma_global) { diff --git a/R/bcf.R b/R/bcf.R index 49fa908a..bd90201a 100644 --- a/R/bcf.R +++ b/R/bcf.R @@ -812,7 +812,7 @@ bcf 
<- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id forest_model_mu$sample_one_iteration( forest_dataset_train, outcome_train, forest_samples_mu, active_forest_mu, rng, feature_types, 0, current_leaf_scale_mu, variable_weights_mu, a_forest, b_forest, - current_sigma2, cutpoint_grid_size, keep_forest = keep_sample, gfr = T, pre_initialized = T + current_sigma2, cutpoint_grid_size, keep_forest = keep_sample, gfr = T ) # Sample variance parameters (if requested) @@ -829,7 +829,7 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id forest_model_tau$sample_one_iteration( forest_dataset_train, outcome_train, forest_samples_tau, active_forest_tau, rng, feature_types, 1, current_leaf_scale_tau, variable_weights_tau, a_forest, b_forest, - current_sigma2, cutpoint_grid_size, keep_forest = keep_sample, gfr = T, pre_initialized = T + current_sigma2, cutpoint_grid_size, keep_forest = keep_sample, gfr = T ) # Sample coding parameters (if requested) @@ -874,7 +874,7 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id forest_model_variance$sample_one_iteration( forest_dataset_train, outcome_train, forest_samples_variance, active_forest_variance, rng, feature_types, leaf_model_variance_forest, current_leaf_scale_mu, variable_weights_variance, - a_forest, b_forest, current_sigma2, cutpoint_grid_size, keep_forest = keep_sample, gfr = T, pre_initialized = T + a_forest, b_forest, current_sigma2, cutpoint_grid_size, keep_forest = keep_sample, gfr = T ) } if (sample_sigma_global) { @@ -1040,7 +1040,7 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id forest_model_mu$sample_one_iteration( forest_dataset_train, outcome_train, forest_samples_mu, active_forest_mu, rng, feature_types, 0, current_leaf_scale_mu, variable_weights_mu, a_forest, b_forest, - current_sigma2, cutpoint_grid_size, keep_forest = keep_sample, gfr = F, pre_initialized = T + current_sigma2, cutpoint_grid_size, keep_forest = keep_sample, gfr = F ) # Sample variance parameters (if requested) @@ -1057,7 +1057,7 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id forest_model_tau$sample_one_iteration( forest_dataset_train, outcome_train, forest_samples_tau, active_forest_tau, rng, feature_types, 1, current_leaf_scale_tau, variable_weights_tau, a_forest, b_forest, - current_sigma2, cutpoint_grid_size, keep_forest = keep_sample, gfr = F, pre_initialized = T + current_sigma2, cutpoint_grid_size, keep_forest = keep_sample, gfr = F ) # Sample coding parameters (if requested) @@ -1102,7 +1102,7 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id forest_model_variance$sample_one_iteration( forest_dataset_train, outcome_train, forest_samples_variance, active_forest_variance, rng, feature_types, leaf_model_variance_forest, current_leaf_scale_mu, variable_weights_variance, - a_forest, b_forest, current_sigma2, cutpoint_grid_size, keep_forest = keep_sample, gfr = F, pre_initialized = T + a_forest, b_forest, current_sigma2, cutpoint_grid_size, keep_forest = keep_sample, gfr = F ) } if (sample_sigma_global) { diff --git a/R/config.R b/R/config.R new file mode 100644 index 00000000..4cd6f8aa --- /dev/null +++ b/R/config.R @@ -0,0 +1,104 @@ +#' #' Dataset used to get / set parameters and other model configuration options +#' #' for the "low-level" stochtree interface +#' #' +#' #' @description +#' #' The "low-level" stochtree interface enables a high degreee of sampler +#' #' 
customization, in which users employ R wrappers around C++ objects +#' #' like ForestDataset, Outcome, CppRng, and ForestModel to run the +#' #' Gibbs sampler of a BART model with custom modifications. +#' #' ModelConfig allows users to specify / query the parameters of a +#' #' tree model they wish to run. +#' +#' ModelConfig <- R6::R6Class( +#' classname = "ModelConfig", +#' cloneable = FALSE, +#' public = list( +#' +#' #' @field data_ptr External pointer to a C++ ModelConfig class +#' data_ptr = NULL, +#' +#' #' @description +#' #' Create a new ForestDataset object. +#' #' @param covariates Matrix of covariates +#' #' @param basis (Optional) Matrix of bases used to define a leaf regression +#' #' @param variance_weights (Optional) Vector of observation-specific variance weights +#' #' @return A new `ForestDataset` object. +#' initialize = function(covariates, basis=NULL, variance_weights=NULL) { +#' self$data_ptr <- create_forest_dataset_cpp() +#' forest_dataset_add_covariates_cpp(self$data_ptr, covariates) +#' if (!is.null(basis)) { +#' forest_dataset_add_basis_cpp(self$data_ptr, basis) +#' } +#' if (!is.null(variance_weights)) { +#' forest_dataset_add_weights_cpp(self$data_ptr, variance_weights) +#' } +#' }, +#' +#' #' @description +#' #' Update basis matrix in a dataset +#' #' @param basis Updated matrix of bases used to define a leaf regression +#' update_basis = function(basis) { +#' stopifnot(self$has_basis()) +#' forest_dataset_update_basis_cpp(self$data_ptr, basis) +#' }, +#' +#' #' @description +#' #' Return number of observations in a `ForestDataset` object +#' #' @return Observation count +#' num_observations = function() { +#' return(dataset_num_rows_cpp(self$data_ptr)) +#' }, +#' +#' #' @description +#' #' Return number of covariates in a `ForestDataset` object +#' #' @return Covariate count +#' num_covariates = function() { +#' return(dataset_num_covariates_cpp(self$data_ptr)) +#' }, +#' +#' #' @description +#' #' Return number of bases in a `ForestDataset` object +#' #' @return Basis count +#' num_basis = function() { +#' return(dataset_num_basis_cpp(self$data_ptr)) +#' }, +#' +#' #' @description +#' #' Whether or not a dataset has a basis matrix +#' #' @return True if basis matrix is loaded, false otherwise +#' has_basis = function() { +#' return(dataset_has_basis_cpp(self$data_ptr)) +#' }, +#' +#' #' @description +#' #' Whether or not a dataset has variance weights +#' #' @return True if variance weights are loaded, false otherwise +#' has_variance_weights = function() { +#' return(dataset_has_variance_weights_cpp(self$data_ptr)) +#' } +#' ), +#' private = list( +#' feature_types = NULL, +#' num_trees = NULL, +#' num_observations = NULL, +#' alpha = NULL, +#' beta = NULL, +#' min_samples_leaf = NULL, +#' max_depth = NULL, +#' ) +#' ) +#' +#' #' Create an model config object +#' #' +#' #' @return `ModelConfig` object +#' #' @export +#' #' +#' #' @examples +#' #' X <- matrix(runif(10*100), ncol = 10) +#' #' y <- -5 + 10*(X[,1] > 0.5) + rnorm(100) +#' #' config <- createModelConfig(y) +#' createModelConfig <- function(){ +#' return(invisible(( +#' createModelConfig$new() +#' ))) +#' } diff --git a/R/forest.R b/R/forest.R index 224e8c7c..e93957c2 100644 --- a/R/forest.R +++ b/R/forest.R @@ -558,6 +558,9 @@ Forest <- R6::R6Class( #' @field forest_ptr External pointer to a C++ TreeEnsemble class forest_ptr = NULL, + #' @field internal_forest_is_empty Whether the forest has not yet been "initialized" such that its `predict` function can be called. 
+ internal_forest_is_empty = TRUE,
+
 #' @description
 #' Create a new Forest object.
 #' @param num_trees Number of trees in the forest
@@ -567,6 +570,7 @@ Forest <- R6::R6Class(
 #' @return A new `Forest` object.
 initialize = function(num_trees, leaf_dimension=1, is_leaf_constant=F, is_exponentiated=F) {
 self$forest_ptr <- active_forest_cpp(num_trees, leaf_dimension, is_leaf_constant, is_exponentiated)
+ self$internal_forest_is_empty <- TRUE
 },

 #' @description
@@ -610,6 +614,7 @@ Forest <- R6::R6Class(
 #' @param leaf_value Constant leaf value(s) to be fixed for each tree in the ensemble indexed by `forest_num`. Can be either a single number or a vector, depending on the forest's leaf dimension.
 set_root_leaves = function(leaf_value) {
 stopifnot(!is.null(self$forest_ptr))
+ stopifnot(self$internal_forest_is_empty)

 # Set leaf values
 if (length(leaf_value) == 1) {
@@ -621,6 +626,8 @@ Forest <- R6::R6Class(
 } else {
 stop("leaf_value must be a numeric value or vector of length >= 1")
 }
+
+ self$internal_forest_is_empty <- FALSE
 },

 #' @description
@@ -636,12 +643,15 @@ Forest <- R6::R6Class(
 stopifnot(!is.null(outcome$data_ptr))
 stopifnot(!is.null(forest_model$tracker_ptr))
 stopifnot(!is.null(self$forest_ptr))
+ stopifnot(self$internal_forest_is_empty)

 # Initialize the model
 initialize_forest_model_active_forest_cpp(
 dataset$data_ptr, outcome$data_ptr, self$forest_ptr,
 forest_model$tracker_ptr, leaf_value, leaf_model_int
 )
+
+ self$internal_forest_is_empty <- FALSE
 },

 #' @description
@@ -745,6 +755,23 @@ Forest <- R6::R6Class(
 #' @return Average maximum depth
 average_max_depth = function() {
 return(ensemble_average_max_depth_active_forest_cpp(self$forest_ptr))
+ },
+
+ #' @description
+ #' When a forest object is created, it is "empty" in the sense that none
+ #' of its component trees have leaves with values. There are two ways to
+ #' "initialize" a Forest object. First, the `set_root_leaves()` method
+ #' simply initializes every tree in the forest to a single node carrying
+ #' the same (user-specified) leaf value. Second, the `prepare_for_sampler()`
+ #' method initializes every tree in the forest to a single node with the
+ #' same value and also propagates this information through to a ForestModel
+ #' object, which must be synchronized with a Forest during a forest
+ #' sampler loop.
+ #' @return `TRUE` if a Forest has not yet been initialized with a constant
+ #' root value, `FALSE` if the forest has already been
+ #' initialized / grown.
+ is_empty = function() {
+ return(self$internal_forest_is_empty)
 }
 )
 )
@@ -818,6 +845,7 @@ createForest <- function(num_trees, leaf_dimension=1, is_leaf_constant=F, is_exp
 resetActiveForest <- function(active_forest, forest_samples=NULL, forest_num=NULL) {
 if (is.null(forest_samples)) {
 root_reset_active_forest_cpp(active_forest$forest_ptr)
+ active_forest$internal_forest_is_empty <- TRUE
 } else {
 if (is.null(forest_num)) {
 stop("`forest_num` must be specified if `forest_samples` is provided")
@@ -863,11 +891,12 @@ resetActiveForest <- function(active_forest, forest_samples=NULL, forest_num=NUL
 #' forest_model <- createForestModel(forest_dataset, feature_types, num_trees, n, alpha, beta, min_samples_leaf, max_depth)
 #' active_forest <- createForest(num_trees, leaf_dimension, is_leaf_constant, is_exponentiated)
 #' forest_samples <- createForestSamples(num_trees, leaf_dimension, is_leaf_constant, is_exponentiated)
+#' active_forest$prepare_for_sampler(forest_dataset, outcome, forest_model, 0, 0.) 
#' forest_model$sample_one_iteration( #' forest_dataset, outcome, forest_samples, active_forest, #' rng, feature_types, leaf_model, leaf_scale, variable_weights, #' a_forest, b_forest, sigma2, cutpoint_grid_size, keep_forest = TRUE, -#' gfr = FALSE, pre_initialized = TRUE +#' gfr = FALSE #' ) #' resetActiveForest(active_forest, forest_samples, 0) #' resetForestModel(forest_model, active_forest, forest_dataset, outcome, TRUE) diff --git a/R/model.R b/R/model.R index 5dc1de69..69896864 100644 --- a/R/model.R +++ b/R/model.R @@ -79,7 +79,12 @@ ForestModel <- R6::R6Class( sample_one_iteration = function(forest_dataset, residual, forest_samples, active_forest, rng, feature_types, leaf_model_int, leaf_model_scale, variable_weights, a_forest, b_forest, global_scale, cutpoint_grid_size = 500, - keep_forest = T, gfr = T, pre_initialized = F) { + keep_forest = T, gfr = T) { + if (active_forest$is_empty()) { + stop("`active_forest` has not yet been initialized, which is necessary to run the sampler. Please set constant values for `active_forest`'s leaves using either the `set_root_leaves` or `prepare_for_sampler` methods.") + } + pre_initialized = T + if (gfr) { sample_gfr_one_iteration_cpp( forest_dataset$data_ptr, residual$data_ptr, diff --git a/man/Forest.Rd b/man/Forest.Rd index 075460f4..fd5cc045 100644 --- a/man/Forest.Rd +++ b/man/Forest.Rd @@ -10,6 +10,8 @@ Wrapper around a C++ tree ensemble \if{html}{\out{
}} \describe{ \item{\code{forest_ptr}}{External pointer to a C++ TreeEnsemble class} + +\item{\code{internal_forest_is_empty}}{Whether the forest has not yet been "initialized" such that its \code{predict} function can be called.} } \if{html}{\out{
}} } @@ -32,6 +34,7 @@ Wrapper around a C++ tree ensemble \item \href{#method-Forest-get_forest_split_counts}{\code{Forest$get_forest_split_counts()}} \item \href{#method-Forest-tree_max_depth}{\code{Forest$tree_max_depth()}} \item \href{#method-Forest-average_max_depth}{\code{Forest$average_max_depth()}} +\item \href{#method-Forest-is_empty}{\code{Forest$is_empty()}} } } \if{html}{\out{
}} @@ -361,4 +364,27 @@ Average the maximum depth of each tree in the forest Average maximum depth } } +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-Forest-is_empty}{}}} +\subsection{Method \code{is_empty()}}{ +When a forest object is created, it is "empty" in the sense that none +of its component trees have leaves with values. There are two ways to +"initialize" a Forest object. First, the \code{set_root_leaves()} method +simply initializes every tree in the forest to a single node carrying +the same (user-specified) leaf value. Second, the \code{prepare_for_sampler()} +method initializes every tree in the forest to a single node with the +same value and also propagates this information through to a ForestModel +object, which must be synchronized with a Forest during a forest +sampler loop. +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{Forest$is_empty()}\if{html}{\out{
}}
}

\subsection{Returns}{
\code{TRUE} if a Forest has not yet been initialized with a constant
root value, \code{FALSE} if the forest has already been
initialized / grown.
}
}
}
diff --git a/man/ForestModel.Rd b/man/ForestModel.Rd
index 61b688a3..5e0ecdae 100644
--- a/man/ForestModel.Rd
+++ b/man/ForestModel.Rd
@@ -94,8 +94,7 @@ Run a single iteration of the forest sampling algorithm (MCMC or GFR)
 global_scale,
 cutpoint_grid_size = 500,
 keep_forest = T,
- gfr = T,
- pre_initialized = F
+ gfr = T
)}\if{html}{\out{}}
}
diff --git a/man/resetForestModel.Rd b/man/resetForestModel.Rd
index 07f3a8fa..5ed766f7 100644
--- a/man/resetForestModel.Rd
+++ b/man/resetForestModel.Rd
@@ -50,11 +50,12 @@ rng <- createCppRNG(1234)
 forest_model <- createForestModel(forest_dataset, feature_types, num_trees, n, alpha, beta, min_samples_leaf, max_depth)
 active_forest <- createForest(num_trees, leaf_dimension, is_leaf_constant, is_exponentiated)
 forest_samples <- createForestSamples(num_trees, leaf_dimension, is_leaf_constant, is_exponentiated)
+active_forest$prepare_for_sampler(forest_dataset, outcome, forest_model, 0, 0.)
 forest_model$sample_one_iteration(
 forest_dataset, outcome, forest_samples, active_forest,
 rng, feature_types, leaf_model, leaf_scale, variable_weights,
 a_forest, b_forest, sigma2, cutpoint_grid_size, keep_forest = TRUE,
- gfr = FALSE, pre_initialized = TRUE
+ gfr = FALSE
)
resetActiveForest(active_forest, forest_samples, 0)
resetForestModel(forest_model, active_forest, forest_dataset, outcome, TRUE)
diff --git a/vignettes/CustomSamplingRoutine.Rmd b/vignettes/CustomSamplingRoutine.Rmd
index 9ed1725c..a6d08ba7 100644
--- a/vignettes/CustomSamplingRoutine.Rmd
+++ b/vignettes/CustomSamplingRoutine.Rmd
@@ -145,6 +145,10 @@ if (leaf_regression) {
 forest_samples <- createForestSamples(num_trees, 1, T)
 active_forest <- createForest(num_trees, 1, T)
 }
+
+# Initialize the leaves of each tree in the forest
+active_forest$prepare_for_sampler(forest_dataset, outcome, forest_model, outcome_model_type, mean(resid))
+active_forest$adjust_residual(forest_dataset, outcome, forest_model, ifelse(outcome_model_type==1, T, F), F)
 ```

Prepare to run the sampler
@@ -335,6 +339,10 @@ if (leaf_regression) {
 active_forest <- createForest(num_trees, 1, T)
 }

+# Initialize the leaves of each tree in the forest
+active_forest$prepare_for_sampler(forest_dataset, outcome, forest_model, outcome_model_type, mean(resid))
+active_forest$adjust_residual(forest_dataset, outcome, forest_model, ifelse(outcome_model_type==1, T, F), F)
+
 # Random effects dataset
 rfx_basis <- as.matrix(rfx_basis)
 group_ids <- as.integer(group_ids)
@@ -570,6 +578,10 @@ if (leaf_regression) {
 active_forest <- createForest(num_trees, 1, T)
 }

+# Initialize the leaves of each tree in the forest
+active_forest$prepare_for_sampler(forest_dataset, outcome, forest_model, outcome_model_type, mean(resid))
+active_forest$adjust_residual(forest_dataset, outcome, forest_model, ifelse(outcome_model_type==1, T, F), F)
+
 # Random effects dataset
 rfx_basis <- as.matrix(rfx_basis)
 group_ids <- as.integer(group_ids)
@@ -803,6 +815,10 @@ if (leaf_regression) {
 forest_samples <- createForestSamples(num_trees, 1, T)
 active_forest <- createForest(num_trees, 1, T)
 }
+
+# Initialize the leaves of each tree in the forest
+active_forest$prepare_for_sampler(forest_dataset, outcome, forest_model, outcome_model_type, mean(resid))
+active_forest$adjust_residual(forest_dataset, outcome, forest_model, ifelse(outcome_model_type==1, T, F), F)
 ```

Prepare to run the 
sampler From a372b21ba365793a270bcfcda863d2af2fbe2376 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Mon, 3 Feb 2025 00:39:47 -0600 Subject: [PATCH 02/21] Updated vignette --- vignettes/CustomSamplingRoutine.Rmd | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/vignettes/CustomSamplingRoutine.Rmd b/vignettes/CustomSamplingRoutine.Rmd index a6d08ba7..e2c03fea 100644 --- a/vignettes/CustomSamplingRoutine.Rmd +++ b/vignettes/CustomSamplingRoutine.Rmd @@ -1184,8 +1184,7 @@ if (num_gfr > 0){ forest_model_mu$sample_one_iteration( forest_dataset_mu, outcome, forest_samples_mu, active_forest_mu, rng, feature_types_mu, 0, current_leaf_scale_mu, variable_weights_mu, - 1, 1, current_sigma2, cutpoint_grid_size, keep_forest = T, gfr = T, - pre_initialized = T + 1, 1, current_sigma2, cutpoint_grid_size, keep_forest = T, gfr = T ) # Sample variance parameters (if requested) @@ -1198,8 +1197,7 @@ if (num_gfr > 0){ forest_model_tau$sample_one_iteration( forest_dataset_tau, outcome, forest_samples_tau, active_forest_tau, rng, feature_types_tau, 1, current_leaf_scale_tau, variable_weights_tau, - 1, 1, current_sigma2, cutpoint_grid_size, keep_forest = T, gfr = T, - pre_initialized = T + 1, 1, current_sigma2, cutpoint_grid_size, keep_forest = T, gfr = T ) # Sample adaptive coding parameters @@ -1236,7 +1234,7 @@ if (num_burnin + num_mcmc > 0) { forest_model_mu$sample_one_iteration( forest_dataset_mu, outcome, forest_samples_mu, active_forest_mu, rng, feature_types_mu, 0, current_leaf_scale_mu, variable_weights_mu, 1, 1, current_sigma2, - cutpoint_grid_size, keep_forest = T, gfr = F, pre_initialized = T + cutpoint_grid_size, keep_forest = T, gfr = F ) # Sample global variance parameter @@ -1247,7 +1245,7 @@ if (num_burnin + num_mcmc > 0) { forest_model_tau$sample_one_iteration( forest_dataset_tau, outcome, forest_samples_tau, active_forest_tau, rng, feature_types_tau, 1, current_leaf_scale_tau, variable_weights_tau, 1, 1, current_sigma2, - cutpoint_grid_size, keep_forest = T, gfr = F, pre_initialized = T + cutpoint_grid_size, keep_forest = T, gfr = F ) # Sample coding parameters From 57c7d26fde92a02daf410434d26ea2680c659f09 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Mon, 3 Feb 2025 00:55:20 -0600 Subject: [PATCH 03/21] Refactored R / cpp11 interface --- R/cpp11.R | 8 ++++---- R/model.R | 5 ++--- src/cpp11.cpp | 16 ++++++++-------- src/sampler.cpp | 14 ++++++++++---- 4 files changed, 24 insertions(+), 19 deletions(-) diff --git a/R/cpp11.R b/R/cpp11.R index 8ad8ba24..7188e9f7 100644 --- a/R/cpp11.R +++ b/R/cpp11.R @@ -556,12 +556,12 @@ compute_leaf_indices_cpp <- function(forest_container, covariates, forest_nums) .Call(`_stochtree_compute_leaf_indices_cpp`, forest_container, covariates, forest_nums) } -sample_gfr_one_iteration_cpp <- function(data, residual, forest_samples, active_forest, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest, pre_initialized) { - invisible(.Call(`_stochtree_sample_gfr_one_iteration_cpp`, data, residual, forest_samples, active_forest, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest, pre_initialized)) +sample_gfr_one_iteration_cpp <- function(data, residual, forest_samples, active_forest, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, 
global_variance, leaf_model_int, keep_forest) { + invisible(.Call(`_stochtree_sample_gfr_one_iteration_cpp`, data, residual, forest_samples, active_forest, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest)) } -sample_mcmc_one_iteration_cpp <- function(data, residual, forest_samples, active_forest, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest, pre_initialized) { - invisible(.Call(`_stochtree_sample_mcmc_one_iteration_cpp`, data, residual, forest_samples, active_forest, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest, pre_initialized)) +sample_mcmc_one_iteration_cpp <- function(data, residual, forest_samples, active_forest, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest) { + invisible(.Call(`_stochtree_sample_mcmc_one_iteration_cpp`, data, residual, forest_samples, active_forest, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest)) } sample_sigma2_one_iteration_cpp <- function(residual, dataset, rng, a, b) { diff --git a/R/model.R b/R/model.R index 69896864..ef3e981a 100644 --- a/R/model.R +++ b/R/model.R @@ -83,21 +83,20 @@ ForestModel <- R6::R6Class( if (active_forest$is_empty()) { stop("`active_forest` has not yet been initialized, which is necessary to run the sampler. 
Please set constant values for `active_forest`'s leaves using either the `set_root_leaves` or `prepare_for_sampler` methods.") } - pre_initialized = T if (gfr) { sample_gfr_one_iteration_cpp( forest_dataset$data_ptr, residual$data_ptr, forest_samples$forest_container_ptr, active_forest$forest_ptr, self$tracker_ptr, self$tree_prior_ptr, rng$rng_ptr, feature_types, cutpoint_grid_size, leaf_model_scale, - variable_weights, a_forest, b_forest, global_scale, leaf_model_int, keep_forest, pre_initialized + variable_weights, a_forest, b_forest, global_scale, leaf_model_int, keep_forest ) } else { sample_mcmc_one_iteration_cpp( forest_dataset$data_ptr, residual$data_ptr, forest_samples$forest_container_ptr, active_forest$forest_ptr, self$tracker_ptr, self$tree_prior_ptr, rng$rng_ptr, feature_types, cutpoint_grid_size, leaf_model_scale, - variable_weights, a_forest, b_forest, global_scale, leaf_model_int, keep_forest, pre_initialized + variable_weights, a_forest, b_forest, global_scale, leaf_model_int, keep_forest ) } }, diff --git a/src/cpp11.cpp b/src/cpp11.cpp index 2364da8f..00a3fcbc 100644 --- a/src/cpp11.cpp +++ b/src/cpp11.cpp @@ -1028,18 +1028,18 @@ extern "C" SEXP _stochtree_compute_leaf_indices_cpp(SEXP forest_container, SEXP END_CPP11 } // sampler.cpp -void sample_gfr_one_iteration_cpp(cpp11::external_pointer data, cpp11::external_pointer residual, cpp11::external_pointer forest_samples, cpp11::external_pointer active_forest, cpp11::external_pointer tracker, cpp11::external_pointer split_prior, cpp11::external_pointer rng, cpp11::integers feature_types, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_model_scale_input, cpp11::doubles variable_weights, double a_forest, double b_forest, double global_variance, int leaf_model_int, bool keep_forest, bool pre_initialized); -extern "C" SEXP _stochtree_sample_gfr_one_iteration_cpp(SEXP data, SEXP residual, SEXP forest_samples, SEXP active_forest, SEXP tracker, SEXP split_prior, SEXP rng, SEXP feature_types, SEXP cutpoint_grid_size, SEXP leaf_model_scale_input, SEXP variable_weights, SEXP a_forest, SEXP b_forest, SEXP global_variance, SEXP leaf_model_int, SEXP keep_forest, SEXP pre_initialized) { +void sample_gfr_one_iteration_cpp(cpp11::external_pointer data, cpp11::external_pointer residual, cpp11::external_pointer forest_samples, cpp11::external_pointer active_forest, cpp11::external_pointer tracker, cpp11::external_pointer split_prior, cpp11::external_pointer rng, cpp11::integers feature_types, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_model_scale_input, cpp11::doubles variable_weights, double a_forest, double b_forest, double global_variance, int leaf_model_int, bool keep_forest); +extern "C" SEXP _stochtree_sample_gfr_one_iteration_cpp(SEXP data, SEXP residual, SEXP forest_samples, SEXP active_forest, SEXP tracker, SEXP split_prior, SEXP rng, SEXP feature_types, SEXP cutpoint_grid_size, SEXP leaf_model_scale_input, SEXP variable_weights, SEXP a_forest, SEXP b_forest, SEXP global_variance, SEXP leaf_model_int, SEXP keep_forest) { BEGIN_CPP11 - sample_gfr_one_iteration_cpp(cpp11::as_cpp>>(data), cpp11::as_cpp>>(residual), cpp11::as_cpp>>(forest_samples), cpp11::as_cpp>>(active_forest), cpp11::as_cpp>>(tracker), cpp11::as_cpp>>(split_prior), cpp11::as_cpp>>(rng), cpp11::as_cpp>(feature_types), cpp11::as_cpp>(cutpoint_grid_size), cpp11::as_cpp>>(leaf_model_scale_input), cpp11::as_cpp>(variable_weights), cpp11::as_cpp>(a_forest), cpp11::as_cpp>(b_forest), cpp11::as_cpp>(global_variance), cpp11::as_cpp>(leaf_model_int), 
cpp11::as_cpp>(keep_forest), cpp11::as_cpp>(pre_initialized)); + sample_gfr_one_iteration_cpp(cpp11::as_cpp>>(data), cpp11::as_cpp>>(residual), cpp11::as_cpp>>(forest_samples), cpp11::as_cpp>>(active_forest), cpp11::as_cpp>>(tracker), cpp11::as_cpp>>(split_prior), cpp11::as_cpp>>(rng), cpp11::as_cpp>(feature_types), cpp11::as_cpp>(cutpoint_grid_size), cpp11::as_cpp>>(leaf_model_scale_input), cpp11::as_cpp>(variable_weights), cpp11::as_cpp>(a_forest), cpp11::as_cpp>(b_forest), cpp11::as_cpp>(global_variance), cpp11::as_cpp>(leaf_model_int), cpp11::as_cpp>(keep_forest)); return R_NilValue; END_CPP11 } // sampler.cpp -void sample_mcmc_one_iteration_cpp(cpp11::external_pointer data, cpp11::external_pointer residual, cpp11::external_pointer forest_samples, cpp11::external_pointer active_forest, cpp11::external_pointer tracker, cpp11::external_pointer split_prior, cpp11::external_pointer rng, cpp11::integers feature_types, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_model_scale_input, cpp11::doubles variable_weights, double a_forest, double b_forest, double global_variance, int leaf_model_int, bool keep_forest, bool pre_initialized); -extern "C" SEXP _stochtree_sample_mcmc_one_iteration_cpp(SEXP data, SEXP residual, SEXP forest_samples, SEXP active_forest, SEXP tracker, SEXP split_prior, SEXP rng, SEXP feature_types, SEXP cutpoint_grid_size, SEXP leaf_model_scale_input, SEXP variable_weights, SEXP a_forest, SEXP b_forest, SEXP global_variance, SEXP leaf_model_int, SEXP keep_forest, SEXP pre_initialized) { +void sample_mcmc_one_iteration_cpp(cpp11::external_pointer data, cpp11::external_pointer residual, cpp11::external_pointer forest_samples, cpp11::external_pointer active_forest, cpp11::external_pointer tracker, cpp11::external_pointer split_prior, cpp11::external_pointer rng, cpp11::integers feature_types, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_model_scale_input, cpp11::doubles variable_weights, double a_forest, double b_forest, double global_variance, int leaf_model_int, bool keep_forest); +extern "C" SEXP _stochtree_sample_mcmc_one_iteration_cpp(SEXP data, SEXP residual, SEXP forest_samples, SEXP active_forest, SEXP tracker, SEXP split_prior, SEXP rng, SEXP feature_types, SEXP cutpoint_grid_size, SEXP leaf_model_scale_input, SEXP variable_weights, SEXP a_forest, SEXP b_forest, SEXP global_variance, SEXP leaf_model_int, SEXP keep_forest) { BEGIN_CPP11 - sample_mcmc_one_iteration_cpp(cpp11::as_cpp>>(data), cpp11::as_cpp>>(residual), cpp11::as_cpp>>(forest_samples), cpp11::as_cpp>>(active_forest), cpp11::as_cpp>>(tracker), cpp11::as_cpp>>(split_prior), cpp11::as_cpp>>(rng), cpp11::as_cpp>(feature_types), cpp11::as_cpp>(cutpoint_grid_size), cpp11::as_cpp>>(leaf_model_scale_input), cpp11::as_cpp>(variable_weights), cpp11::as_cpp>(a_forest), cpp11::as_cpp>(b_forest), cpp11::as_cpp>(global_variance), cpp11::as_cpp>(leaf_model_int), cpp11::as_cpp>(keep_forest), cpp11::as_cpp>(pre_initialized)); + sample_mcmc_one_iteration_cpp(cpp11::as_cpp>>(data), cpp11::as_cpp>>(residual), cpp11::as_cpp>>(forest_samples), cpp11::as_cpp>>(active_forest), cpp11::as_cpp>>(tracker), cpp11::as_cpp>>(split_prior), cpp11::as_cpp>>(rng), cpp11::as_cpp>(feature_types), cpp11::as_cpp>(cutpoint_grid_size), cpp11::as_cpp>>(leaf_model_scale_input), cpp11::as_cpp>(variable_weights), cpp11::as_cpp>(a_forest), cpp11::as_cpp>(b_forest), cpp11::as_cpp>(global_variance), cpp11::as_cpp>(leaf_model_int), cpp11::as_cpp>(keep_forest)); return R_NilValue; END_CPP11 } @@ -1583,8 +1583,8 @@ static const 
R_CallMethodDef CallEntries[] = { {"_stochtree_rng_cpp", (DL_FUNC) &_stochtree_rng_cpp, 1}, {"_stochtree_root_reset_active_forest_cpp", (DL_FUNC) &_stochtree_root_reset_active_forest_cpp, 1}, {"_stochtree_root_reset_rfx_tracker_cpp", (DL_FUNC) &_stochtree_root_reset_rfx_tracker_cpp, 4}, - {"_stochtree_sample_gfr_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_gfr_one_iteration_cpp, 17}, - {"_stochtree_sample_mcmc_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_mcmc_one_iteration_cpp, 17}, + {"_stochtree_sample_gfr_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_gfr_one_iteration_cpp, 16}, + {"_stochtree_sample_mcmc_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_mcmc_one_iteration_cpp, 16}, {"_stochtree_sample_sigma2_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_sigma2_one_iteration_cpp, 5}, {"_stochtree_sample_tau_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_tau_one_iteration_cpp, 4}, {"_stochtree_set_leaf_value_active_forest_cpp", (DL_FUNC) &_stochtree_set_leaf_value_active_forest_cpp, 2}, diff --git a/src/sampler.cpp b/src/sampler.cpp index 5b5d8afb..4dbe5e13 100644 --- a/src/sampler.cpp +++ b/src/sampler.cpp @@ -24,9 +24,12 @@ void sample_gfr_one_iteration_cpp(cpp11::external_pointer feature_types_(feature_types.size()); for (int i = 0; i < feature_types.size(); i++) { @@ -93,9 +96,12 @@ void sample_mcmc_one_iteration_cpp(cpp11::external_pointer feature_types_(feature_types.size()); for (int i = 0; i < feature_types.size(); i++) { From f42460494fb73039884d9383b1063ccc6dfe9943 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Mon, 3 Feb 2025 16:48:55 -0600 Subject: [PATCH 04/21] Updated BART and BCF to use modelconfig --- NAMESPACE | 1 + R/bart.R | 75 ++++-- R/bcf.R | 111 ++++++-- R/config.R | 461 ++++++++++++++++++++++++-------- R/model.R | 51 ++-- man/ForestModel.Rd | 31 +-- man/ForestModelConfig.Rd | 467 +++++++++++++++++++++++++++++++++ man/createForestModel.Rd | 32 +-- man/createForestModelConfig.Rd | 67 +++++ 9 files changed, 1075 insertions(+), 221 deletions(-) create mode 100644 man/ForestModelConfig.Rd create mode 100644 man/createForestModelConfig.Rd diff --git a/NAMESPACE b/NAMESPACE index bcb23c2d..3ea7ccd7 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -28,6 +28,7 @@ export(createCppRNG) export(createForest) export(createForestDataset) export(createForestModel) +export(createForestModelConfig) export(createForestSamples) export(createOutcome) export(createPreprocessorFromJson) diff --git a/R/bart.R b/R/bart.R index 9ebe44cf..4b1352b9 100644 --- a/R/bart.R +++ b/R/bart.R @@ -541,10 +541,20 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train # Sampling data structures feature_types <- as.integer(feature_types) if (include_mean_forest) { - forest_model_mean <- createForestModel(forest_dataset_train, feature_types, num_trees_mean, nrow(X_train), alpha_mean, beta_mean, min_samples_leaf_mean, max_depth_mean) + forest_model_config_mean <- createForestModelConfig(feature_types=feature_types, num_trees=num_trees_mean, num_features=ncol(X_train), + num_observations=nrow(X_train), variable_weights=variable_weights_mean, leaf_dimension=leaf_dimension, + alpha=alpha_mean, beta=beta_mean, min_samples_leaf=min_samples_leaf_mean, max_depth=max_depth_mean, + leaf_model_type=leaf_model_mean_forest, leaf_model_scale=current_leaf_scale, + global_error_variance=current_sigma2, cutpoint_grid_size=cutpoint_grid_size) + forest_model_mean <- createForestModel(forest_dataset_train, forest_model_config_mean) } if (include_variance_forest) { - forest_model_variance 
<- createForestModel(forest_dataset_train, feature_types, num_trees_variance, nrow(X_train), alpha_variance, beta_variance, min_samples_leaf_variance, max_depth_variance) + forest_model_config_variance <- createForestModelConfig(feature_types=feature_types, num_trees=num_trees_variance, num_features=ncol(X_train), + num_observations=nrow(X_train), variable_weights=variable_weights_variance, leaf_dimension=1, + alpha=alpha_variance, beta=beta_variance, min_samples_leaf=min_samples_leaf_variance, + max_depth=max_depth_variance, leaf_model_type=leaf_model_variance_forest, + global_error_variance=current_sigma2, cutpoint_grid_size=cutpoint_grid_size) + forest_model_variance <- createForestModel(forest_dataset_train, forest_model_config_variance) } # Container of forest samples @@ -626,26 +636,28 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train if (include_mean_forest) { forest_model_mean$sample_one_iteration( - forest_dataset_train, outcome_train, forest_samples_mean, active_forest_mean, - rng, feature_types, leaf_model_mean_forest, current_leaf_scale, variable_weights_mean, - a_forest, b_forest, current_sigma2, cutpoint_grid_size, keep_forest = keep_sample, gfr = T + forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_mean, + active_forest = active_forest_mean, rng = rng, model_config = forest_model_config_mean, keep_forest = keep_sample, gfr = T ) } if (include_variance_forest) { forest_model_variance$sample_one_iteration( - forest_dataset_train, outcome_train, forest_samples_variance, active_forest_variance, - rng, feature_types, leaf_model_variance_forest, current_leaf_scale, variable_weights_variance, - a_forest, b_forest, current_sigma2, cutpoint_grid_size, keep_forest = keep_sample, gfr = T + forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_variance, + active_forest = active_forest_variance, rng = rng, model_config = forest_model_config_variance, + keep_forest = keep_sample, gfr = T ) } if (sample_sigma_global) { current_sigma2 <- sampleGlobalErrorVarianceOneIteration(outcome_train, forest_dataset_train, rng, a_global, b_global) if (keep_sample) global_var_samples[sample_counter] <- current_sigma2 + if (include_mean_forest) forest_model_config_mean$update_global_error_variance(current_sigma2) + if (include_variance_forest) forest_model_config_variance$update_global_error_variance(current_sigma2) } if (sample_sigma_leaf) { leaf_scale_double <- sampleLeafVarianceOneIteration(active_forest_mean, rng, a_leaf, b_leaf) current_leaf_scale <- as.matrix(leaf_scale_double) if (keep_sample) leaf_scale_samples[sample_counter] <- leaf_scale_double + forest_model_config_mean$update_leaf_model_scale(current_leaf_scale) } if (has_rfx) { rfx_model$sample_random_effect(rfx_dataset_train, outcome_train, rfx_tracker_train, rfx_samples, keep_sample, current_sigma2, rng) @@ -665,6 +677,7 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train if (sample_sigma_leaf) { leaf_scale_double <- leaf_scale_samples[forest_ind + 1] current_leaf_scale <- as.matrix(leaf_scale_double) + forest_model_config_mean$update_leaf_model_scale(current_leaf_scale) } } if (include_variance_forest) { @@ -675,7 +688,15 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train resetRandomEffectsModel(rfx_model, rfx_samples, forest_ind, sigma_alpha_init) resetRandomEffectsTracker(rfx_tracker_train, rfx_model, rfx_dataset_train, outcome_train, rfx_samples) } - 
if (sample_sigma_global) current_sigma2 <- global_var_samples[forest_ind + 1] + if (sample_sigma_global) { + current_sigma2 <- global_var_samples[forest_ind + 1] + if (include_mean_forest) { + forest_model_config_mean$update_global_error_variance(current_sigma2) + } + if (include_variance_forest) { + forest_model_config_variance$update_global_error_variance(current_sigma2) + } + } } else if (has_prev_model) { if (include_mean_forest) { resetActiveForest(active_forest_mean, previous_forest_samples_mean, previous_model_warmstart_sample_num - 1) @@ -683,6 +704,7 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train if (sample_sigma_leaf && (!is.null(previous_leaf_var_samples))) { leaf_scale_double <- previous_leaf_var_samples[previous_model_warmstart_sample_num] current_leaf_scale <- as.matrix(leaf_scale_double) + forest_model_config_mean$update_leaf_model_scale(current_leaf_scale) } } if (include_variance_forest) { @@ -698,6 +720,12 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train if (sample_sigma_global) { if (!is.null(previous_global_var_samples)) { current_sigma2 <- previous_global_var_samples[previous_model_warmstart_sample_num] + if (include_mean_forest) { + forest_model_config_mean$update_global_error_variance(current_sigma2) + } + if (include_variance_forest) { + forest_model_config_variance$update_global_error_variance(current_sigma2) + } } } } else { @@ -707,6 +735,7 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train resetForestModel(forest_model_mean, active_forest_mean, forest_dataset_train, outcome_train, TRUE) if (sample_sigma_leaf) { current_leaf_scale <- as.matrix(sigma_leaf_init) + forest_model_config_mean$update_leaf_model_scale(current_leaf_scale) } } if (include_variance_forest) { @@ -719,7 +748,15 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train sigma_xi_init, sigma_xi_shape, sigma_xi_scale) rootResetRandomEffectsTracker(rfx_tracker_train, rfx_model, rfx_dataset_train, outcome_train) } - if (sample_sigma_global) current_sigma2 <- sigma2_init + if (sample_sigma_global) { + current_sigma2 <- sigma2_init + if (include_mean_forest) { + forest_model_config_mean$update_global_error_variance(current_sigma2) + } + if (include_variance_forest) { + forest_model_config_variance$update_global_error_variance(current_sigma2) + } + } } for (i in (num_gfr+1):num_samples) { is_mcmc <- i > (num_gfr + num_burnin) @@ -748,26 +785,32 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train if (include_mean_forest) { forest_model_mean$sample_one_iteration( - forest_dataset_train, outcome_train, forest_samples_mean, active_forest_mean, - rng, feature_types, leaf_model_mean_forest, current_leaf_scale, variable_weights_mean, - a_forest, b_forest, current_sigma2, cutpoint_grid_size, keep_forest = keep_sample, gfr = F + forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_mean, + active_forest = active_forest_mean, rng = rng, model_config = forest_model_config_mean, keep_forest = keep_sample, gfr = F ) } if (include_variance_forest) { forest_model_variance$sample_one_iteration( - forest_dataset_train, outcome_train, forest_samples_variance, active_forest_variance, - rng, feature_types, leaf_model_variance_forest, current_leaf_scale, variable_weights_variance, - a_forest, b_forest, current_sigma2, cutpoint_grid_size, keep_forest = keep_sample, gfr = F + forest_dataset = forest_dataset_train, residual = 
outcome_train, forest_samples = forest_samples_variance, + active_forest = active_forest_variance, rng = rng, model_config = forest_model_config_variance, + keep_forest = keep_sample, gfr = F ) } if (sample_sigma_global) { current_sigma2 <- sampleGlobalErrorVarianceOneIteration(outcome_train, forest_dataset_train, rng, a_global, b_global) if (keep_sample) global_var_samples[sample_counter] <- current_sigma2 + if (include_mean_forest) { + forest_model_config_mean$update_global_error_variance(current_sigma2) + } + if (include_variance_forest) { + forest_model_config_variance$update_global_error_variance(current_sigma2) + } } if (sample_sigma_leaf) { leaf_scale_double <- sampleLeafVarianceOneIteration(active_forest_mean, rng, a_leaf, b_leaf) current_leaf_scale <- as.matrix(leaf_scale_double) if (keep_sample) leaf_scale_samples[sample_counter] <- leaf_scale_double + forest_model_config_mean$update_leaf_model_scale(current_leaf_scale) } if (has_rfx) { rfx_model$sample_random_effect(rfx_dataset_train, outcome_train, rfx_tracker_train, rfx_samples, keep_sample, current_sigma2, rng) diff --git a/R/bcf.R b/R/bcf.R index bd90201a..a14c7dba 100644 --- a/R/bcf.R +++ b/R/bcf.R @@ -687,8 +687,20 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id current_leaf_scale_mu <- as.matrix(sigma_leaf_mu) current_leaf_scale_tau <- as.matrix(sigma_leaf_tau) + # Set mu and tau leaf models / dimensions + leaf_model_mu_forest <- 0 + leaf_dimension_mu_forest <- 1 + if (ncol(Z_train) > 1) { + leaf_model_tau_forest <- 2 + leaf_dimension_tau_forest <- ncol(Z_train) + } else { + leaf_model_tau_forest <- 1 + leaf_dimension_tau_forest <- 1 + } + # Set variance leaf model type (currently only one option) leaf_model_variance_forest <- 3 + leaf_dimension_variance_forest <- 1 # Random effects prior parameters if (has_rfx) { @@ -763,10 +775,26 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id rng <- createCppRNG(random_seed) # Sampling data structures - forest_model_mu <- createForestModel(forest_dataset_train, feature_types, num_trees_mu, nrow(X_train), alpha_mu, beta_mu, min_samples_leaf_mu, max_depth_mu) - forest_model_tau <- createForestModel(forest_dataset_train, feature_types, num_trees_tau, nrow(X_train), alpha_tau, beta_tau, min_samples_leaf_tau, max_depth_tau) + forest_model_config_mu <- createForestModelConfig(feature_types=feature_types, num_trees=num_trees_mu, num_features=ncol(X_train), + num_observations=nrow(X_train), variable_weights=variable_weights_mu, leaf_dimension=leaf_dimension_mu_forest, + alpha=alpha_mu, beta=beta_mu, min_samples_leaf=min_samples_leaf_mu, max_depth=max_depth_mu, + leaf_model_type=leaf_model_mu_forest, leaf_model_scale=current_leaf_scale_mu, + global_error_variance=current_sigma2, cutpoint_grid_size=cutpoint_grid_size) + forest_model_config_tau <- createForestModelConfig(feature_types=feature_types, num_trees=num_trees_tau, num_features=ncol(X_train), + num_observations=nrow(X_train), variable_weights=variable_weights_tau, leaf_dimension=leaf_dimension_tau_forest, + alpha=alpha_tau, beta=beta_tau, min_samples_leaf=min_samples_leaf_tau, max_depth=max_depth_tau, + leaf_model_type=leaf_model_tau_forest, leaf_model_scale=current_leaf_scale_tau, + global_error_variance=current_sigma2, cutpoint_grid_size=cutpoint_grid_size) + forest_model_mu <- createForestModel(forest_dataset_train, forest_model_config_mu) + forest_model_tau <- createForestModel(forest_dataset_train, forest_model_config_tau) if (include_variance_forest) { - 
forest_model_variance <- createForestModel(forest_dataset_train, feature_types, num_trees_variance, nrow(X_train), alpha_variance, beta_variance, min_samples_leaf_variance, max_depth_variance) + forest_model_config_variance <- createForestModelConfig(feature_types=feature_types, num_trees=num_trees_variance, num_features=ncol(X_train), + num_observations=nrow(X_train), variable_weights=variable_weights_variance, + leaf_dimension=leaf_dimension_variance_forest, alpha=alpha_variance, beta=beta_variance, + min_samples_leaf=min_samples_leaf_variance, max_depth=max_depth_variance, + leaf_model_type=leaf_model_variance_forest, global_error_variance=current_sigma2, + cutpoint_grid_size=cutpoint_grid_size) + forest_model_variance <- createForestModel(forest_dataset_train, forest_model_config_variance) } # Container of forest samples @@ -810,26 +838,28 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id # Sample the prognostic forest forest_model_mu$sample_one_iteration( - forest_dataset_train, outcome_train, forest_samples_mu, active_forest_mu, - rng, feature_types, 0, current_leaf_scale_mu, variable_weights_mu, a_forest, b_forest, - current_sigma2, cutpoint_grid_size, keep_forest = keep_sample, gfr = T + forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_mu, + active_forest = active_forest_mu, rng = rng, model_config = forest_model_config_mu, keep_forest = keep_sample, gfr = T ) # Sample variance parameters (if requested) if (sample_sigma_global) { current_sigma2 <- sampleGlobalErrorVarianceOneIteration(outcome_train, forest_dataset_train, rng, a_global, b_global) + forest_model_config_mu$update_global_error_variance(current_sigma2) + forest_model_config_tau$update_global_error_variance(current_sigma2) + if (include_variance_forest) forest_model_config_variance$update_global_error_variance(current_sigma2) } if (sample_sigma_leaf_mu) { leaf_scale_mu_double <- sampleLeafVarianceOneIteration(active_forest_mu, rng, a_leaf_mu, b_leaf_mu) current_leaf_scale_mu <- as.matrix(leaf_scale_mu_double) if (keep_sample) leaf_scale_mu_samples[sample_counter] <- leaf_scale_mu_double + forest_model_config_mu$update_leaf_model_scale(current_leaf_scale_mu) } # Sample the treatment forest forest_model_tau$sample_one_iteration( - forest_dataset_train, outcome_train, forest_samples_tau, active_forest_tau, - rng, feature_types, 1, current_leaf_scale_tau, variable_weights_tau, a_forest, b_forest, - current_sigma2, cutpoint_grid_size, keep_forest = keep_sample, gfr = T + forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_tau, + active_forest = active_forest_tau, rng = rng, model_config = forest_model_config_tau, keep_forest = keep_sample, gfr = T ) # Sample coding parameters (if requested) @@ -872,19 +902,23 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id # Sample variance parameters (if requested) if (include_variance_forest) { forest_model_variance$sample_one_iteration( - forest_dataset_train, outcome_train, forest_samples_variance, active_forest_variance, - rng, feature_types, leaf_model_variance_forest, current_leaf_scale_mu, variable_weights_variance, - a_forest, b_forest, current_sigma2, cutpoint_grid_size, keep_forest = keep_sample, gfr = T + forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_variance, + active_forest = active_forest_variance, rng = rng, model_config = forest_model_config_variance, + keep_forest 
= keep_sample, gfr = T
 )
 }
 if (sample_sigma_global) {
 current_sigma2 <- sampleGlobalErrorVarianceOneIteration(outcome_train, forest_dataset_train, rng, a_global, b_global)
 if (keep_sample) global_var_samples[sample_counter] <- current_sigma2
+ forest_model_config_mu$update_global_error_variance(current_sigma2)
+ forest_model_config_tau$update_global_error_variance(current_sigma2)
+ if (include_variance_forest) forest_model_config_variance$update_global_error_variance(current_sigma2)
 }
 if (sample_sigma_leaf_tau) {
 leaf_scale_tau_double <- sampleLeafVarianceOneIteration(active_forest_tau, rng, a_leaf_tau, b_leaf_tau)
 current_leaf_scale_tau <- as.matrix(leaf_scale_tau_double)
 if (keep_sample) leaf_scale_tau_samples[sample_counter] <- leaf_scale_tau_double
+ forest_model_config_tau$update_leaf_model_scale(current_leaf_scale_tau)
 }
 # Sample random effects parameters (if requested)
@@ -907,10 +941,12 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
 if (sample_sigma_leaf_mu) {
 leaf_scale_mu_double <- leaf_scale_mu_samples[forest_ind + 1]
 current_leaf_scale_mu <- as.matrix(leaf_scale_mu_double)
+ forest_model_config_mu$update_leaf_model_scale(current_leaf_scale_mu)
 }
 if (sample_sigma_leaf_tau) {
 leaf_scale_tau_double <- leaf_scale_tau_samples[forest_ind + 1]
 current_leaf_scale_tau <- as.matrix(leaf_scale_tau_double)
+ forest_model_config_tau$update_leaf_model_scale(current_leaf_scale_tau)
 }
 if (include_variance_forest) {
 resetActiveForest(active_forest_variance, forest_samples_variance, forest_ind)
@@ -931,7 +967,14 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
 }
 forest_model_tau$propagate_basis_update(forest_dataset_train, outcome_train, active_forest_tau)
 }
- if (sample_sigma_global) current_sigma2 <- global_var_samples[forest_ind + 1]
+ if (sample_sigma_global) {
+ current_sigma2 <- global_var_samples[forest_ind + 1]
+ forest_model_config_mu$update_global_error_variance(current_sigma2)
+ forest_model_config_tau$update_global_error_variance(current_sigma2)
+ if (include_variance_forest) {
+ forest_model_config_variance$update_global_error_variance(current_sigma2)
+ }
+ }
 } else if (has_prev_model) {
 resetActiveForest(active_forest_mu, previous_forest_samples_mu, previous_model_warmstart_sample_num - 1)
 resetForestModel(forest_model_mu, active_forest_mu, forest_dataset_train, outcome_train, TRUE)
@@ -944,10 +987,12 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
 if (sample_sigma_leaf_mu && (!is.null(previous_leaf_var_mu_samples))) {
 leaf_scale_mu_double <- previous_leaf_var_mu_samples[previous_model_warmstart_sample_num]
 current_leaf_scale_mu <- as.matrix(leaf_scale_mu_double)
+ forest_model_config_mu$update_leaf_model_scale(current_leaf_scale_mu)
 }
 if (sample_sigma_leaf_tau && (!is.null(previous_leaf_var_tau_samples))) {
 leaf_scale_tau_double <- previous_leaf_var_tau_samples[previous_model_warmstart_sample_num]
 current_leaf_scale_tau <- as.matrix(leaf_scale_tau_double)
+ forest_model_config_tau$update_leaf_model_scale(current_leaf_scale_tau)
 }
 if (adaptive_coding) {
 if (!is.null(previous_b_1_samples)) {
@@ -974,6 +1019,11 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
 if (!is.null(previous_global_var_samples)) {
 current_sigma2 <- previous_global_var_samples[previous_model_warmstart_sample_num]
 }
+ forest_model_config_mu$update_global_error_variance(current_sigma2)
+ forest_model_config_tau$update_global_error_variance(current_sigma2)
+ if 
(include_variance_forest) { + forest_model_config_variance$update_global_error_variance(current_sigma2) + } } } else { resetActiveForest(active_forest_mu) @@ -984,9 +1034,11 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id resetForestModel(forest_model_tau, active_forest_tau, forest_dataset_train, outcome_train, TRUE) if (sample_sigma_leaf_mu) { current_leaf_scale_mu <- as.matrix(sigma_leaf_mu) + forest_model_config_mu$update_leaf_model_scale(current_leaf_scale_mu) } if (sample_sigma_leaf_tau) { current_leaf_scale_tau <- as.matrix(sigma_leaf_tau) + forest_model_config_tau$update_leaf_model_scale(current_leaf_scale_tau) } if (include_variance_forest) { resetActiveForest(active_forest_variance) @@ -1009,7 +1061,14 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id } forest_model_tau$propagate_basis_update(forest_dataset_train, outcome_train, active_forest_tau) } - if (sample_sigma_global) current_sigma2 <- sigma2_init + if (sample_sigma_global) { + current_sigma2 <- sigma2_init + forest_model_config_mu$update_global_error_variance(current_sigma2) + forest_model_config_tau$update_global_error_variance(current_sigma2) + if (include_variance_forest) { + forest_model_config_variance$update_global_error_variance(current_sigma2) + } + } } for (i in (num_gfr+1):num_samples) { is_mcmc <- i > (num_gfr + num_burnin) @@ -1038,26 +1097,28 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id # Sample the prognostic forest forest_model_mu$sample_one_iteration( - forest_dataset_train, outcome_train, forest_samples_mu, active_forest_mu, - rng, feature_types, 0, current_leaf_scale_mu, variable_weights_mu, a_forest, b_forest, - current_sigma2, cutpoint_grid_size, keep_forest = keep_sample, gfr = F + forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_mu, + active_forest = active_forest_mu, rng = rng, model_config = forest_model_config_mu, keep_forest = keep_sample, gfr = F ) # Sample variance parameters (if requested) if (sample_sigma_global) { current_sigma2 <- sampleGlobalErrorVarianceOneIteration(outcome_train, forest_dataset_train, rng, a_global, b_global) + forest_model_config_mu$update_global_error_variance(current_sigma2) + forest_model_config_tau$update_global_error_variance(current_sigma2) + if (include_variance_forest) forest_model_config_variance$update_global_error_variance(current_sigma2) } if (sample_sigma_leaf_mu) { leaf_scale_mu_double <- sampleLeafVarianceOneIteration(active_forest_mu, rng, a_leaf_mu, b_leaf_mu) current_leaf_scale_mu <- as.matrix(leaf_scale_mu_double) if (keep_sample) leaf_scale_mu_samples[sample_counter] <- leaf_scale_mu_double + forest_model_config_mu$update_leaf_model_scale(current_leaf_scale_mu) } # Sample the treatment forest forest_model_tau$sample_one_iteration( - forest_dataset_train, outcome_train, forest_samples_tau, active_forest_tau, - rng, feature_types, 1, current_leaf_scale_tau, variable_weights_tau, a_forest, b_forest, - current_sigma2, cutpoint_grid_size, keep_forest = keep_sample, gfr = F + forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_tau, + active_forest = active_forest_tau, rng = rng, model_config = forest_model_config_tau, keep_forest = keep_sample, gfr = F ) # Sample coding parameters (if requested) @@ -1100,19 +1161,23 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id # Sample variance parameters (if requested) if 
(include_variance_forest) { forest_model_variance$sample_one_iteration( - forest_dataset_train, outcome_train, forest_samples_variance, active_forest_variance, - rng, feature_types, leaf_model_variance_forest, current_leaf_scale_mu, variable_weights_variance, - a_forest, b_forest, current_sigma2, cutpoint_grid_size, keep_forest = keep_sample, gfr = F + forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_variance, + active_forest = active_forest_variance, rng = rng, model_config = forest_model_config_variance, + keep_forest = keep_sample, gfr = F ) } if (sample_sigma_global) { current_sigma2 <- sampleGlobalErrorVarianceOneIteration(outcome_train, forest_dataset_train, rng, a_global, b_global) if (keep_sample) global_var_samples[sample_counter] <- current_sigma2 + forest_model_config_mu$update_global_error_variance(current_sigma2) + forest_model_config_tau$update_global_error_variance(current_sigma2) + if (include_variance_forest) forest_model_config_variance$update_global_error_variance(current_sigma2) } if (sample_sigma_leaf_tau) { leaf_scale_tau_double <- sampleLeafVarianceOneIteration(active_forest_tau, rng, a_leaf_tau, b_leaf_tau) current_leaf_scale_tau <- as.matrix(leaf_scale_tau_double) if (keep_sample) leaf_scale_tau_samples[sample_counter] <- leaf_scale_tau_double + forest_model_config_tau$update_leaf_model_scale(current_leaf_scale_tau) } # Sample random effects parameters (if requested) diff --git a/R/config.R b/R/config.R index 4cd6f8aa..edd1830d 100644 --- a/R/config.R +++ b/R/config.R @@ -1,104 +1,357 @@ -#' #' Dataset used to get / set parameters and other model configuration options -#' #' for the "low-level" stochtree interface -#' #' -#' #' @description -#' #' The "low-level" stochtree interface enables a high degreee of sampler -#' #' customization, in which users employ R wrappers around C++ objects -#' #' like ForestDataset, Outcome, CppRng, and ForestModel to run the -#' #' Gibbs sampler of a BART model with custom modifications. -#' #' ModelConfig allows users to specify / query the parameters of a -#' #' tree model they wish to run. -#' -#' ModelConfig <- R6::R6Class( -#' classname = "ModelConfig", -#' cloneable = FALSE, -#' public = list( -#' -#' #' @field data_ptr External pointer to a C++ ModelConfig class -#' data_ptr = NULL, -#' -#' #' @description -#' #' Create a new ForestDataset object. -#' #' @param covariates Matrix of covariates -#' #' @param basis (Optional) Matrix of bases used to define a leaf regression -#' #' @param variance_weights (Optional) Vector of observation-specific variance weights -#' #' @return A new `ForestDataset` object. 
-#' initialize = function(covariates, basis=NULL, variance_weights=NULL) { -#' self$data_ptr <- create_forest_dataset_cpp() -#' forest_dataset_add_covariates_cpp(self$data_ptr, covariates) -#' if (!is.null(basis)) { -#' forest_dataset_add_basis_cpp(self$data_ptr, basis) -#' } -#' if (!is.null(variance_weights)) { -#' forest_dataset_add_weights_cpp(self$data_ptr, variance_weights) -#' } -#' }, -#' -#' #' @description -#' #' Update basis matrix in a dataset -#' #' @param basis Updated matrix of bases used to define a leaf regression -#' update_basis = function(basis) { -#' stopifnot(self$has_basis()) -#' forest_dataset_update_basis_cpp(self$data_ptr, basis) -#' }, -#' -#' #' @description -#' #' Return number of observations in a `ForestDataset` object -#' #' @return Observation count -#' num_observations = function() { -#' return(dataset_num_rows_cpp(self$data_ptr)) -#' }, -#' -#' #' @description -#' #' Return number of covariates in a `ForestDataset` object -#' #' @return Covariate count -#' num_covariates = function() { -#' return(dataset_num_covariates_cpp(self$data_ptr)) -#' }, -#' -#' #' @description -#' #' Return number of bases in a `ForestDataset` object -#' #' @return Basis count -#' num_basis = function() { -#' return(dataset_num_basis_cpp(self$data_ptr)) -#' }, -#' -#' #' @description -#' #' Whether or not a dataset has a basis matrix -#' #' @return True if basis matrix is loaded, false otherwise -#' has_basis = function() { -#' return(dataset_has_basis_cpp(self$data_ptr)) -#' }, -#' -#' #' @description -#' #' Whether or not a dataset has variance weights -#' #' @return True if variance weights are loaded, false otherwise -#' has_variance_weights = function() { -#' return(dataset_has_variance_weights_cpp(self$data_ptr)) -#' } -#' ), -#' private = list( -#' feature_types = NULL, -#' num_trees = NULL, -#' num_observations = NULL, -#' alpha = NULL, -#' beta = NULL, -#' min_samples_leaf = NULL, -#' max_depth = NULL, -#' ) -#' ) -#' -#' #' Create an model config object -#' #' -#' #' @return `ModelConfig` object -#' #' @export -#' #' -#' #' @examples -#' #' X <- matrix(runif(10*100), ncol = 10) -#' #' y <- -5 + 10*(X[,1] > 0.5) + rnorm(100) -#' #' config <- createModelConfig(y) -#' createModelConfig <- function(){ -#' return(invisible(( -#' createModelConfig$new() -#' ))) -#' } +#' Object used to get / set parameters and other model configuration options +#' for the "low-level" stochtree interface +#' +#' @description +#' The "low-level" stochtree interface enables a high degree of sampler +#' customization, in which users employ R wrappers around C++ objects +#' like ForestDataset, Outcome, CppRng, and ForestModel to run the +#' Gibbs sampler of a BART model with custom modifications. +#' ForestModelConfig allows users to specify / query the parameters of a +#' tree model they wish to run.
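For illustration, here is a minimal end-to-end sketch of the workflow this configuration class is meant to enable: build a `ForestModelConfig` once, hand it to the sampling machinery, and draw one grow-from-root iteration. The data and prior values below are made up, and the calls follow the constructor and sampler signatures introduced in this patch:

```r
library(stochtree)

# Toy data (illustrative values only)
n <- 100
p <- 5
X <- matrix(runif(n * p), ncol = p)
y <- -5 + 10 * (X[, 1] > 0.5) + rnorm(n)

# Low-level data structures
forest_dataset <- createForestDataset(X)
outcome <- createOutcome(y)
rng <- createCppRNG(1234)

# One config object replaces the long list of loose sampler arguments
forest_model_config <- createForestModelConfig(
  feature_types = as.integer(rep(0, p)), num_trees = 50,
  num_features = p, num_observations = n,
  variable_weights = rep(1 / p, p), leaf_model_type = 0,
  alpha = 0.95, beta = 2.0, min_samples_leaf = 5, max_depth = 10,
  global_error_variance = 1.0, cutpoint_grid_size = 100
)
forest_model <- createForestModel(forest_dataset, forest_model_config)

# Forest state: 50 trees, leaf dimension 1, constant leaves, not exponentiated
active_forest <- createForest(50, 1, TRUE, FALSE)
forest_samples <- createForestSamples(50, 1, TRUE, FALSE)
active_forest$prepare_for_sampler(forest_dataset, outcome, forest_model, 0, 0.)

# One GFR draw; tree prior and leaf model settings are read from the config
forest_model$sample_one_iteration(
  forest_dataset, outcome, forest_samples, active_forest,
  rng, forest_model_config, keep_forest = TRUE, gfr = TRUE
)
```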
+ +ForestModelConfig <- R6::R6Class( + classname = "ForestModelConfig", + cloneable = FALSE, + public = list( + + #' @field feature_types Vector of integer-coded feature types (integers where 0 = numeric, 1 = ordered categorical, 2 = unordered categorical) + feature_types = NULL, + + #' @field num_trees Number of trees in the forest being sampled + num_trees = NULL, + + #' @field num_features Number of features in training dataset + num_features = NULL, + + #' @field num_observations Number of observations in training dataset + num_observations = NULL, + + #' @field leaf_dimension Dimension of the leaf model + leaf_dimension = NULL, + + #' @field alpha Root node split probability in tree prior + alpha = NULL, + + #' @field beta Depth prior penalty in tree prior + beta = NULL, + + #' @field min_samples_leaf Minimum number of samples in a tree leaf + min_samples_leaf = NULL, + + #' @field max_depth Maximum depth of any tree in the ensemble in the model. Setting to `-1` does not enforce any depth limits on trees. + max_depth = NULL, + + #' @field leaf_model_type Integer specifying the leaf model type (0 = constant leaf, 1 = univariate leaf regression, 2 = multivariate leaf regression) + leaf_model_type = NULL, + + #' @field leaf_model_scale Scale parameter used in Gaussian leaf models + leaf_model_scale = NULL, + + #' @field variable_weights Vector specifying sampling probability for all p covariates in ForestDataset + variable_weights = NULL, + + #' @field variance_forest_shape Shape parameter for IG leaf models (applicable when `leaf_model_type = 3`) + variance_forest_shape = NULL, + + #' @field variance_forest_scale Scale parameter for IG leaf models (applicable when `leaf_model_type = 3`) + variance_forest_scale = NULL, + + #' @field global_error_variance Global error variance parameter + global_error_variance = NULL, + + #' @field cutpoint_grid_size Number of unique cutpoints to consider + cutpoint_grid_size = NULL, + + #' @description + #' Create a new ForestModelConfig object. + #' + #' @param feature_types Vector of integer-coded feature types (integers where 0 = numeric, 1 = ordered categorical, 2 = unordered categorical) + #' @param num_trees Number of trees in the forest being sampled + #' @param num_features Number of features in training dataset + #' @param num_observations Number of observations in training dataset + #' @param variable_weights Vector specifying sampling probability for all p covariates in ForestDataset + #' @param leaf_dimension Dimension of the leaf model (default: `1`) + #' @param alpha Root node split probability in tree prior (default: `0.95`) + #' @param beta Depth prior penalty in tree prior (default: `2.0`) + #' @param min_samples_leaf Minimum number of samples in a tree leaf (default: `5`) + #' @param max_depth Maximum depth of any tree in the ensemble in the model. Setting to `-1` does not enforce any depth limits on trees. Default: `-1`. + #' @param leaf_model_type Integer specifying the leaf model type (0 = constant leaf, 1 = univariate leaf regression, 2 = multivariate leaf regression). Default: `1`. + #' @param leaf_model_scale Scale parameter used in Gaussian leaf models (can either be a scalar or a q x q matrix, where q is the dimensionality of the basis and is only >1 when `leaf_model_type = 2`). Calibrated internally as `1/num_trees`, propagated along diagonal if needed for multivariate leaf models. + #' @param variance_forest_shape Shape parameter for IG leaf models (applicable when `leaf_model_type = 3`). Default: `1`.
+ #' @param variance_forest_scale Scale parameter for IG leaf models (applicable when `leaf_model_type = 3`). Default: `1`. + #' @param global_error_variance Global error variance parameter (default: `1.0`) + #' @param cutpoint_grid_size Number of unique cutpoints to consider (default: `100`) + #' + #' @return A new ForestModelConfig object. + initialize = function(feature_types = NULL, num_trees = NULL, num_features = NULL, + num_observations = NULL, variable_weights = NULL, leaf_dimension = 1, + alpha = 0.95, beta = 2.0, min_samples_leaf = 5, max_depth = -1, + leaf_model_type = 1, leaf_model_scale = NULL, variance_forest_shape = 1.0, + variance_forest_scale = 1.0, global_error_variance = 1.0, cutpoint_grid_size = 100) { + if (is.null(feature_types)) { + if (is.null(num_features)) { + stop("Neither of `num_features` nor `feature_types` (a vector from which `num_features` can be inferred) was provided. Please provide at least one of these inputs when creating a ForestModelConfig object.") + } + warning("`feature_types` not provided, will be assumed to be numeric") + feature_types <- rep(0, num_features) + } else { + if (is.null(num_features)) { + num_features <- length(feature_types) + } + } + if (is.null(variable_weights)) { + warning("`variable_weights` not provided, will be assumed to be equal-weighted") + variable_weights <- rep(1/num_features, num_features) + } + if (num_features != length(feature_types)) { + stop("`feature_types` must have `num_features` total elements") + } + if (num_features != length(variable_weights)) { + stop("`variable_weights` must have `num_features` total elements") + } + self$feature_types <- feature_types + self$variable_weights <- variable_weights + self$num_trees <- num_trees + self$num_features <- num_features + self$num_observations <- num_observations + self$leaf_dimension <- leaf_dimension + self$alpha <- alpha + self$beta <- beta + self$min_samples_leaf <- min_samples_leaf + self$max_depth <- max_depth + self$variance_forest_shape <- variance_forest_shape + self$variance_forest_scale <- variance_forest_scale + self$global_error_variance <- global_error_variance + self$cutpoint_grid_size <- cutpoint_grid_size + + if (!(as.integer(leaf_model_type) == leaf_model_type)) { + stop("`leaf_model_type` must be an integer between 0 and 3") + } + if ((leaf_model_type < 0) | (leaf_model_type > 3)) { + stop("`leaf_model_type` must be an integer between 0 and 3") + } + self$leaf_model_type <- leaf_model_type + + if (is.null(leaf_model_scale)) { + self$leaf_model_scale <- diag(1/num_trees, leaf_dimension) + } else if (is.matrix(leaf_model_scale)) { + if (ncol(leaf_model_scale) != nrow(leaf_model_scale)) { + stop("`leaf_model_scale` must be a square matrix") + } + if (ncol(leaf_model_scale) != leaf_dimension) { + stop("`leaf_model_scale` must have `leaf_dimension` rows and columns") + } + self$leaf_model_scale <- leaf_model_scale + } else { + if (leaf_model_scale <= 0) { + stop("`leaf_model_scale` must be positive, if provided as scalar") + } + self$leaf_model_scale <- diag(leaf_model_scale, leaf_dimension) + } + }, + + #' @description + #' Update feature types + #' @param feature_types Vector of integer-coded feature types (integers where 0 = numeric, 1 = ordered categorical, 2 = unordered categorical) + update_feature_types = function(feature_types) { + stopifnot(length(feature_types) == self$num_features) + self$feature_types <- feature_types + }, + + #' @description + #' Update variable weights + #' @param variable_weights Vector specifying sampling probability
for all p covariates in ForestDataset + update_variable_weights = function(variable_weights) { + stopifnot(length(variable_weights) == self$num_features) + self$variable_weights <- variable_weights + }, + + #' @description + #' Update root node split probability in tree prior + #' @param alpha Root node split probability in tree prior + update_alpha = function(alpha) { + self$alpha <- alpha + }, + + #' @description + #' Update depth prior penalty in tree prior + #' @param beta Depth prior penalty in tree prior + update_beta = function(beta) { + self$beta <- beta + }, + + #' @description + #' Update minimum number of samples in a tree leaf + #' @param min_samples_leaf Minimum number of samples in a tree leaf + update_min_samples_leaf = function(min_samples_leaf) { + self$min_samples_leaf <- min_samples_leaf + }, + + #' @description + #' Update maximum depth of any tree in the ensemble in the model + #' @param max_depth Maximum depth of any tree in the ensemble in the model + update_max_depth = function(max_depth) { + self$max_depth <- max_depth + }, + + #' @description + #' Update scale parameter used in Gaussian leaf models + #' @param leaf_model_scale Scale parameter used in Gaussian leaf models + update_leaf_model_scale = function(leaf_model_scale) { + if (is.matrix(leaf_model_scale)) { + if (ncol(leaf_model_scale) != nrow(leaf_model_scale)) { + stop("`leaf_model_scale` must be a square matrix") + } + if (ncol(leaf_model_scale) != self$leaf_dimension) { + stop("`leaf_model_scale` must have `leaf_dimension` rows and columns") + } + self$leaf_model_scale <- leaf_model_scale + } else { + if (leaf_model_scale <= 0) { + stop("`leaf_model_scale` must be positive, if provided as scalar") + } + self$leaf_model_scale <- diag(leaf_model_scale, self$leaf_dimension) + } + }, + + #' @description + #' Update shape parameter for IG leaf models + #' @param variance_forest_shape Shape parameter for IG leaf models + update_variance_forest_shape = function(variance_forest_shape) { + self$variance_forest_shape <- variance_forest_shape + }, + + #' @description + #' Update scale parameter for IG leaf models + #' @param variance_forest_scale Scale parameter for IG leaf models + update_variance_forest_scale = function(variance_forest_scale) { + self$variance_forest_scale <- variance_forest_scale + }, + + #' @description + #' Update global error variance parameter + #' @param global_error_variance Global error variance parameter + update_global_error_variance = function(global_error_variance) { + self$global_error_variance <- global_error_variance + }, + + #' @description + #' Update number of unique cutpoints to consider + #' @param cutpoint_grid_size Number of unique cutpoints to consider + update_cutpoint_grid_size = function(cutpoint_grid_size) { + self$cutpoint_grid_size <- cutpoint_grid_size + }, + + #' @description + #' Query feature types for this ForestModelConfig object + #' @returns Vector of integer-coded feature types (integers where 0 = numeric, 1 = ordered categorical, 2 = unordered categorical) + get_feature_types = function() { + return(self$feature_types) + }, + + #' @description + #' Query variable weights for this ForestModelConfig object + #' @returns Vector specifying sampling probability for all p covariates in ForestDataset + get_variable_weights = function() { + return(self$variable_weights) + }, + + #' @description + #' Query root node split probability in tree prior for this ForestModelConfig object + #' @returns Root node split probability in tree prior + get_alpha = function() { + return(self$alpha)
+ }, + + #' @description + #' Query depth prior penalty in tree prior for this ForestModelConfig object + #' @returns Depth prior penalty in tree prior + get_beta = function() { + return(self$beta) + }, + + #' @description + #' Query minimum number of samples in a tree leaf for this ForestModelConfig object + #' @returns Minimum number of samples in a tree leaf + get_min_samples_leaf = function() { + return(self$min_samples_leaf) + }, + + #' @description + #' Query maximum depth of any tree in the ensemble for this ForestModelConfig object + #' @returns Maximum depth of any tree in the ensemble in the model + get_max_depth = function() { + return(self$max_depth) + }, + + #' @description + #' Query scale parameter used in Gaussian leaf models for this ForestModelConfig object + #' @returns Scale parameter used in Gaussian leaf models + get_leaf_model_scale = function() { + return(self$leaf_model_scale) + }, + + #' @description + #' Query shape parameter for IG leaf models for this ForestModelConfig object + #' @returns Shape parameter for IG leaf models + get_variance_forest_shape = function() { + return(self$variance_forest_shape) + }, + + #' @description + #' Query scale parameter for IG leaf models for this ForestModelConfig object + #' @returns Scale parameter for IG leaf models + get_variance_forest_scale = function() { + return(self$variance_forest_scale) + }, + + #' @description + #' Query global error variance parameter for this ForestModelConfig object + #' @returns Global error variance parameter + get_global_error_variance = function() { + return(self$global_error_variance) + }, + + #' @description + #' Query number of unique cutpoints to consider for this ForestModelConfig object + #' @returns Number of unique cutpoints to consider + get_cutpoint_grid_size = function() { + return(self$cutpoint_grid_size) + } + ) +) + +#' Create a model config object +#' +#' @param feature_types Vector of integer-coded feature types (integers where 0 = numeric, 1 = ordered categorical, 2 = unordered categorical) +#' @param num_trees Number of trees in the forest being sampled +#' @param num_features Number of features in training dataset +#' @param num_observations Number of observations in training dataset +#' @param variable_weights Vector specifying sampling probability for all p covariates in ForestDataset +#' @param leaf_dimension Dimension of the leaf model (default: `1`) +#' @param alpha Root node split probability in tree prior (default: `0.95`) +#' @param beta Depth prior penalty in tree prior (default: `2.0`) +#' @param min_samples_leaf Minimum number of samples in a tree leaf (default: `5`) +#' @param max_depth Maximum depth of any tree in the ensemble in the model. Setting to `-1` does not enforce any depth limits on trees. Default: `-1`. +#' @param leaf_model_type Integer specifying the leaf model type (0 = constant leaf, 1 = univariate leaf regression, 2 = multivariate leaf regression). Default: `1`. +#' @param leaf_model_scale Scale parameter used in Gaussian leaf models (can either be a scalar or a q x q matrix, where q is the dimensionality of the basis and is only >1 when `leaf_model_type = 2`). Calibrated internally as `1/num_trees`, propagated along diagonal if needed for multivariate leaf models. +#' @param variance_forest_shape Shape parameter for IG leaf models (applicable when `leaf_model_type = 3`). Default: `1`.
+#' @param global_error_variance Global error variance parameter (default: `1.0`) +#' @param cutpoint_grid_size Number of unique cutpoints to consider (default: `100`) +#' @return ForestModelConfig object +#' @export +#' +#' @examples +#' config <- createForestModelConfig(num_trees = 10, num_features = 5, num_observations = 100) +createForestModelConfig <- function(feature_types = NULL, num_trees = NULL, num_features = NULL, + num_observations = NULL, variable_weights = NULL, leaf_dimension = 1, + alpha = 0.95, beta = 2.0, min_samples_leaf = 5, max_depth = -1, + leaf_model_type = 1, leaf_model_scale = NULL, variance_forest_shape = 1.0, + variance_forest_scale = 1.0, global_error_variance = 1.0, cutpoint_grid_size = 100){ + return(invisible(( + ForestModelConfig$new(feature_types, num_trees, num_features, num_observations, + variable_weights, leaf_dimension, alpha, beta, min_samples_leaf, + max_depth, leaf_model_type, leaf_model_scale, variance_forest_shape, + variance_forest_scale, global_error_variance, cutpoint_grid_size) + ))) +} diff --git a/R/model.R b/R/model.R index ef3e981a..3a4db170 100644 --- a/R/model.R +++ b/R/model.R @@ -65,24 +65,24 @@ ForestModel <- R6::R6Class( #' @param forest_samples Container of forest samples #' @param active_forest "Active" forest updated by the sampler in each iteration #' @param rng Wrapper around C++ random number generator - #' @param feature_types Vector specifying the type of all p covariates in `forest_dataset` (0 = numeric, 1 = ordered categorical, 2 = unordered categorical) - #' @param leaf_model_int Integer specifying the leaf model type (0 = constant leaf, 1 = univariate leaf regression, 2 = multivariate leaf regression) - #' @param leaf_model_scale Scale parameter used in the leaf node model (should be a q x q matrix where q is the dimensionality of the basis and is only >1 when `leaf_model_int = 2`) - #' @param variable_weights Vector specifying sampling probability for all p covariates in `forest_dataset` - #' @param a_forest Shape parameter on variance forest model (if applicable) - #' @param b_forest Scale parameter on variance forest model (if applicable) - #' @param global_scale Global variance parameter - #' @param cutpoint_grid_size (Optional) Number of unique cutpoints to consider (default: `500`, currently only used when `GFR = TRUE`) - #' @param keep_forest (Optional) Whether the updated forest sample should be saved to `forest_samples`. Default: `T`. - #' @param gfr (Optional) Whether or not the forest should be sampled using the "grow-from-root" (GFR) algorithm. Default: `T`. - #' @param pre_initialized (Optional) Whether or not the leaves are pre-initialized outside of the sampling loop (before any samples are drawn). In multi-forest implementations like BCF, this is true, though in the single-forest supervised learning implementation, we can let C++ do the initialization. Default: `F`. - sample_one_iteration = function(forest_dataset, residual, forest_samples, active_forest, rng, feature_types, - leaf_model_int, leaf_model_scale, variable_weights, - a_forest, b_forest, global_scale, cutpoint_grid_size = 500, - keep_forest = T, gfr = T) { + #' @param model_config ModelConfig object containing forest model parameters and settings + #' @param keep_forest (Optional) Whether the updated forest sample should be saved to `forest_samples`. Default: `TRUE`. + #' @param gfr (Optional) Whether or not the forest should be sampled using the "grow-from-root" (GFR) algorithm. Default: `TRUE`. 
+ sample_one_iteration = function(forest_dataset, residual, forest_samples, active_forest, + rng, model_config, keep_forest = T, gfr = T) { if (active_forest$is_empty()) { stop("`active_forest` has not yet been initialized, which is necessary to run the sampler. Please set constant values for `active_forest`'s leaves using either the `set_root_leaves` or `prepare_for_sampler` methods.") } + + # Unpack parameters from model config object + feature_types <- model_config$feature_types + leaf_model_int <- model_config$leaf_model_type + leaf_model_scale <- model_config$leaf_model_scale + variable_weights <- model_config$variable_weights + a_forest <- model_config$variance_forest_shape + b_forest <- model_config$variance_forest_scale + global_scale <- model_config$global_error_variance + cutpoint_grid_size <- model_config$cutpoint_grid_size if (gfr) { sample_gfr_one_iteration_cpp( @@ -186,14 +186,8 @@ createCppRNG <- function(random_seed = -1){ #' Create a forest model object #' -#' @param forest_dataset `ForestDataset` object, used to initialize forest sampling data structures -#' @param feature_types Feature types (integers where 0 = numeric, 1 = ordered categorical, 2 = unordered categorical) -#' @param num_trees Number of trees in the forest being sampled -#' @param n Number of observations in `forest_dataset` -#' @param alpha Root node split probability in tree prior -#' @param beta Depth prior penalty in tree prior -#' @param min_samples_leaf Minimum number of samples in a tree leaf -#' @param max_depth Maximum depth of any tree in the ensemble in the mean model. Setting to ``-1`` does not enforce any depth limits on trees. +#' @param forest_dataset ForestDataset object, used to initialize forest sampling data structures +#' @param model_config ForestModelConfig object containing forest model parameters and settings #' #' @return `ForestModel` object #' @export @@ -209,10 +203,15 @@ createCppRNG <- function(random_seed = -1){ #' feature_types <- as.integer(rep(0, p)) #' X <- matrix(runif(n*p), ncol = p) #' forest_dataset <- createForestDataset(X) -#' forest_model <- createForestModel(forest_dataset, feature_types, num_trees, n, alpha, beta, min_samples_leaf, max_depth) +#' model_config <- createForestModelConfig(feature_types=feature_types, num_trees=num_trees, num_features=p, +#' num_observations=n, alpha=alpha, beta=beta, min_samples_leaf=min_samples_leaf, +#' max_depth=max_depth, leaf_model_type=1) +#' forest_model <- createForestModel(forest_dataset, model_config) +createForestModel <- function(forest_dataset, model_config) { return(invisible(( - ForestModel$new(forest_dataset, feature_types, num_trees, n, alpha, beta, min_samples_leaf, max_depth) + ForestModel$new(forest_dataset, model_config$feature_types, model_config$num_trees, + model_config$num_observations, model_config$alpha, model_config$beta, + model_config$min_samples_leaf, model_config$max_depth) ))) } diff --git a/man/ForestModel.Rd b/man/ForestModel.Rd index 5e0ecdae..f9abb045 100644 --- a/man/ForestModel.Rd +++ b/man/ForestModel.Rd @@ -85,14 +85,7 @@ Run a single iteration of the forest sampling algorithm (MCMC or GFR) forest_samples, active_forest, rng, - feature_types, - leaf_model_int, - leaf_model_scale, - variable_weights, - a_forest, - b_forest, - global_scale, - cutpoint_grid_size = 500, + model_config, keep_forest = T, gfr = T )}\if{html}{\out{}} @@ -111,27 +104,11 @@ Run a single iteration of the forest
sampling algorithm (MCMC or GFR) \item{\code{rng}}{Wrapper around C++ random number generator} -\item{\code{feature_types}}{Vector specifying the type of all p covariates in \code{forest_dataset} (0 = numeric, 1 = ordered categorical, 2 = unordered categorical)} - -\item{\code{leaf_model_int}}{Integer specifying the leaf model type (0 = constant leaf, 1 = univariate leaf regression, 2 = multivariate leaf regression)} - -\item{\code{leaf_model_scale}}{Scale parameter used in the leaf node model (should be a q x q matrix where q is the dimensionality of the basis and is only >1 when \code{leaf_model_int = 2})} - -\item{\code{variable_weights}}{Vector specifying sampling probability for all p covariates in \code{forest_dataset}} - -\item{\code{a_forest}}{Shape parameter on variance forest model (if applicable)} - -\item{\code{b_forest}}{Scale parameter on variance forest model (if applicable)} - -\item{\code{global_scale}}{Global variance parameter} - -\item{\code{cutpoint_grid_size}}{(Optional) Number of unique cutpoints to consider (default: \code{500}, currently only used when \code{GFR = TRUE})} - -\item{\code{keep_forest}}{(Optional) Whether the updated forest sample should be saved to \code{forest_samples}. Default: \code{T}.} +\item{\code{model_config}}{ForestModelConfig object containing forest model parameters and settings} -\item{\code{gfr}}{(Optional) Whether or not the forest should be sampled using the "grow-from-root" (GFR) algorithm. Default: \code{T}.} +\item{\code{keep_forest}}{(Optional) Whether the updated forest sample should be saved to \code{forest_samples}. Default: \code{TRUE}.} -\item{\code{pre_initialized}}{(Optional) Whether or not the leaves are pre-initialized outside of the sampling loop (before any samples are drawn). In multi-forest implementations like BCF, this is true, though in the single-forest supervised learning implementation, we can let C++ do the initialization. Default: \code{F}.} +\item{\code{gfr}}{(Optional) Whether or not the forest should be sampled using the "grow-from-root" (GFR) algorithm. Default: \code{TRUE}.} } \if{html}{\out{}} } diff --git a/man/ForestModelConfig.Rd b/man/ForestModelConfig.Rd new file mode 100644 index 00000000..0c843d38 --- /dev/null +++ b/man/ForestModelConfig.Rd @@ -0,0 +1,467 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/config.R +\name{ForestModelConfig} +\alias{ForestModelConfig} +\title{Object used to get / set parameters and other model configuration options +for the "low-level" stochtree interface} +\value{ +Vector of integer-coded feature types (integers where 0 = numeric, 1 = ordered categorical, 2 = unordered categorical) + +Vector specifying sampling probability for all p covariates in ForestDataset + +Root node split probability in tree prior + +Depth prior penalty in tree prior + +Minimum number of samples in a tree leaf + +Maximum depth of any tree in the ensemble in the model + +Scale parameter used in Gaussian leaf models + +Shape parameter for IG leaf models + +Scale parameter for IG leaf models + +Global error variance parameter + +Number of unique cutpoints to consider +} +\description{ +The "low-level" stochtree interface enables a high degree of sampler +customization, in which users employ R wrappers around C++ objects +like ForestDataset, Outcome, CppRng, and ForestModel to run the +Gibbs sampler of a BART model with custom modifications. +ForestModelConfig allows users to specify / query the parameters of a +tree model they wish to run.
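As a usage illustration of the getter and setter methods documented here (a sketch; it assumes the constructor behavior shown in R/config.R above, including the warning that `variable_weights` defaults to equal weights when omitted):

```r
library(stochtree)

# num_features is inferred from feature_types; omitting variable_weights
# triggers a warning and equal weighting across the 3 features
config <- createForestModelConfig(
  feature_types = as.integer(c(0, 0, 1)), num_trees = 25, num_observations = 50
)
config$get_feature_types()     # c(0, 0, 1)
config$get_variable_weights()  # rep(1/3, 3)

# Getter / setter pairs cover the parameters a sampler may change mid-run
config$update_alpha(0.9)
config$get_alpha()                   # 0.9
config$update_global_error_variance(0.5)
config$get_global_error_variance()   # 0.5
```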
+} +\section{Public fields}{ +\if{html}{\out{
}} +\describe{ +\item{\code{feature_types}}{Vector of integer-coded feature types (integers where 0 = numeric, 1 = ordered categorical, 2 = unordered categorical)} + +\item{\code{num_trees}}{Number of trees in the forest being sampled} + +\item{\code{num_features}}{Number of features in training dataset} + +\item{\code{num_observations}}{Number of observations in training dataset} + +\item{\code{leaf_dimension}}{Dimension of the leaf model} + +\item{\code{alpha}}{Root node split probability in tree prior} + +\item{\code{beta}}{Depth prior penalty in tree prior} + +\item{\code{min_samples_leaf}}{Minimum number of samples in a tree leaf} + +\item{\code{max_depth}}{Maximum depth of any tree in the ensemble in the model. Setting to \code{-1} does not enforce any depth limits on trees.} + +\item{\code{leaf_model_type}}{Integer specifying the leaf model type (0 = constant leaf, 1 = univariate leaf regression, 2 = multivariate leaf regression)} + +\item{\code{leaf_model_scale}}{Scale parameter used in Gaussian leaf models} + +\item{\code{variable_weights}}{Vector specifying sampling probability for all p covariates in ForestDataset} + +\item{\code{variance_forest_shape}}{Shape parameter for IG leaf models (applicable when \code{leaf_model_type = 3})} + +\item{\code{variance_forest_scale}}{Scale parameter for IG leaf models (applicable when \code{leaf_model_type = 3})} + +\item{\code{global_error_variance}}{Global error variance parameter} + +\item{\code{cutpoint_grid_size}}{Number of unique cutpoints to consider} +}
}} +} +\section{Methods}{ +\subsection{Public methods}{ +\itemize{ +\item \href{#method-ForestModelConfig-new}{\code{ForestModelConfig$new()}} +\item \href{#method-ForestModelConfig-update_feature_types}{\code{ForestModelConfig$update_feature_types()}} +\item \href{#method-ForestModelConfig-update_variable_weights}{\code{ForestModelConfig$update_variable_weights()}} +\item \href{#method-ForestModelConfig-update_alpha}{\code{ForestModelConfig$update_alpha()}} +\item \href{#method-ForestModelConfig-update_beta}{\code{ForestModelConfig$update_beta()}} +\item \href{#method-ForestModelConfig-update_min_samples_leaf}{\code{ForestModelConfig$update_min_samples_leaf()}} +\item \href{#method-ForestModelConfig-update_max_depth}{\code{ForestModelConfig$update_max_depth()}} +\item \href{#method-ForestModelConfig-update_leaf_model_scale}{\code{ForestModelConfig$update_leaf_model_scale()}} +\item \href{#method-ForestModelConfig-update_variance_forest_shape}{\code{ForestModelConfig$update_variance_forest_shape()}} +\item \href{#method-ForestModelConfig-update_variance_forest_scale}{\code{ForestModelConfig$update_variance_forest_scale()}} +\item \href{#method-ForestModelConfig-update_global_error_variance}{\code{ForestModelConfig$update_global_error_variance()}} +\item \href{#method-ForestModelConfig-update_cutpoint_grid_size}{\code{ForestModelConfig$update_cutpoint_grid_size()}} +\item \href{#method-ForestModelConfig-get_feature_types}{\code{ForestModelConfig$get_feature_types()}} +\item \href{#method-ForestModelConfig-get_variable_weights}{\code{ForestModelConfig$get_variable_weights()}} +\item \href{#method-ForestModelConfig-get_alpha}{\code{ForestModelConfig$get_alpha()}} +\item \href{#method-ForestModelConfig-get_beta}{\code{ForestModelConfig$get_beta()}} +\item \href{#method-ForestModelConfig-get_min_samples_leaf}{\code{ForestModelConfig$get_min_samples_leaf()}} +\item \href{#method-ForestModelConfig-get_max_depth}{\code{ForestModelConfig$get_max_depth()}} +\item \href{#method-ForestModelConfig-get_leaf_model_scale}{\code{ForestModelConfig$get_leaf_model_scale()}} +\item \href{#method-ForestModelConfig-get_variance_forest_shape}{\code{ForestModelConfig$get_variance_forest_shape()}} +\item \href{#method-ForestModelConfig-get_variance_forest_scale}{\code{ForestModelConfig$get_variance_forest_scale()}} +\item \href{#method-ForestModelConfig-get_global_error_variance}{\code{ForestModelConfig$get_global_error_variance()}} +\item \href{#method-ForestModelConfig-get_cutpoint_grid_size}{\code{ForestModelConfig$get_cutpoint_grid_size()}} +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ForestModelConfig-new}{}}} +\subsection{Method \code{new()}}{ +Create a new ForestModelConfig object. +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ForestModelConfig$new( + feature_types = NULL, + num_trees = NULL, + num_features = NULL, + num_observations = NULL, + variable_weights = NULL, + leaf_dimension = 1, + alpha = 0.95, + beta = 2, + min_samples_leaf = 5, + max_depth = -1, + leaf_model_type = 1, + leaf_model_scale = NULL, + variance_forest_shape = 1, + variance_forest_scale = 1, + global_error_variance = 1, + cutpoint_grid_size = 100 +)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{feature_types}}{Vector of integer-coded feature types (integers where 0 = numeric, 1 = ordered categorical, 2 = unordered categorical)} + +\item{\code{num_trees}}{Number of trees in the forest being sampled} + +\item{\code{num_features}}{Number of features in training dataset} + +\item{\code{num_observations}}{Number of observations in training dataset} + +\item{\code{variable_weights}}{Vector specifying sampling probability for all p covariates in ForestDataset} + +\item{\code{leaf_dimension}}{Dimension of the leaf model (default: \code{1})} + +\item{\code{alpha}}{Root node split probability in tree prior (default: \code{0.95})} + +\item{\code{beta}}{Depth prior penalty in tree prior (default: \code{2.0})} + +\item{\code{min_samples_leaf}}{Minimum number of samples in a tree leaf (default: \code{5})} + +\item{\code{max_depth}}{Maximum depth of any tree in the ensemble in the model. Setting to \code{-1} does not enforce any depth limits on trees. Default: \code{-1}.} + +\item{\code{leaf_model_type}}{Integer specifying the leaf model type (0 = constant leaf, 1 = univariate leaf regression, 2 = multivariate leaf regression). Default: \code{1}.} + +\item{\code{leaf_model_scale}}{Scale parameter used in Gaussian leaf models (can either be a scalar or a q x q matrix, where q is the dimensionality of the basis and is only >1 when \code{leaf_model_type = 2}). Calibrated internally as \code{1/num_trees}, propagated along diagonal if needed for multivariate leaf models.} + +\item{\code{variance_forest_shape}}{Shape parameter for IG leaf models (applicable when \code{leaf_model_type = 3}). Default: \code{1}.} + +\item{\code{variance_forest_scale}}{Scale parameter for IG leaf models (applicable when \code{leaf_model_type = 3}). Default: \code{1}.} + +\item{\code{global_error_variance}}{Global error variance parameter (default: \code{1.0})} + +\item{\code{cutpoint_grid_size}}{Number of unique cutpoints to consider (default: \code{100})} +} +\if{html}{\out{
}} +} +\subsection{Returns}{ +A new ForestModelConfig object. +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ForestModelConfig-update_feature_types}{}}} +\subsection{Method \code{update_feature_types()}}{ +Update feature types +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ForestModelConfig$update_feature_types(feature_types)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{feature_types}}{Vector of integer-coded feature types (integers where 0 = numeric, 1 = ordered categorical, 2 = unordered categorical)} +} +\if{html}{\out{
}} +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ForestModelConfig-update_variable_weights}{}}} +\subsection{Method \code{update_variable_weights()}}{ +Update variable weights +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ForestModelConfig$update_variable_weights(variable_weights)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{variable_weights}}{Vector specifying sampling probability for all p covariates in ForestDataset} +} +\if{html}{\out{
}} +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ForestModelConfig-update_alpha}{}}} +\subsection{Method \code{update_alpha()}}{ +Update root node split probability in tree prior +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ForestModelConfig$update_alpha(alpha)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{alpha}}{Root node split probability in tree prior} +} +\if{html}{\out{
}} +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ForestModelConfig-update_beta}{}}} +\subsection{Method \code{update_beta()}}{ +Update depth prior penalty in tree prior +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ForestModelConfig$update_beta(beta)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{beta}}{Depth prior penalty in tree prior} +} +\if{html}{\out{
}} +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ForestModelConfig-update_min_samples_leaf}{}}} +\subsection{Method \code{update_min_samples_leaf()}}{ +Update minimum number of samples in a tree leaf +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ForestModelConfig$update_min_samples_leaf(min_samples_leaf)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{min_samples_leaf}}{Minimum number of samples in a tree leaf} +} +\if{html}{\out{
}} +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ForestModelConfig-update_max_depth}{}}} +\subsection{Method \code{update_max_depth()}}{ +Update maximum depth of any tree in the ensemble in the model +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ForestModelConfig$update_max_depth(max_depth)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{max_depth}}{Maximum depth of any tree in the ensemble in the model} +} +\if{html}{\out{
}} +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ForestModelConfig-update_leaf_model_scale}{}}} +\subsection{Method \code{update_leaf_model_scale()}}{ +Update scale parameter used in Gaussian leaf models +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ForestModelConfig$update_leaf_model_scale(leaf_model_scale)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{leaf_model_scale}}{Scale parameter used in Gaussian leaf models} +} +\if{html}{\out{
}} +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ForestModelConfig-update_variance_forest_shape}{}}} +\subsection{Method \code{update_variance_forest_shape()}}{ +Update shape parameter for IG leaf models +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ForestModelConfig$update_variance_forest_shape(variance_forest_shape)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{variance_forest_shape}}{Shape parameter for IG leaf models} +} +\if{html}{\out{
}} +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ForestModelConfig-update_variance_forest_scale}{}}} +\subsection{Method \code{update_variance_forest_scale()}}{ +Update scale parameter for IG leaf models +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ForestModelConfig$update_variance_forest_scale(variance_forest_scale)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{variance_forest_scale}}{Scale parameter for IG leaf models} +} +\if{html}{\out{
}} +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ForestModelConfig-update_global_error_variance}{}}} +\subsection{Method \code{update_global_error_variance()}}{ +Update global error variance parameter +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ForestModelConfig$update_global_error_variance(global_error_variance)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{global_error_variance}}{Global error variance parameter} +} +\if{html}{\out{
}} +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ForestModelConfig-update_cutpoint_grid_size}{}}} +\subsection{Method \code{update_cutpoint_grid_size()}}{ +Update number of unique cutpoints to consider +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ForestModelConfig$update_cutpoint_grid_size(cutpoint_grid_size)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{cutpoint_grid_size}}{Number of unique cutpoints to consider} +} +\if{html}{\out{
}} +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ForestModelConfig-get_feature_types}{}}} +\subsection{Method \code{get_feature_types()}}{ +Query feature types for this ForestModelConfig object +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ForestModelConfig$get_feature_types()}\if{html}{\out{
}} +} + +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ForestModelConfig-get_variable_weights}{}}} +\subsection{Method \code{get_variable_weights()}}{ +Query variable weights for this ForestModelConfig object +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ForestModelConfig$get_variable_weights()}\if{html}{\out{
}} +} + +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ForestModelConfig-get_alpha}{}}} +\subsection{Method \code{get_alpha()}}{ +Query root node split probability in tree prior for this ForestModelConfig object +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ForestModelConfig$get_alpha()}\if{html}{\out{
}} +} + +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ForestModelConfig-get_beta}{}}} +\subsection{Method \code{get_beta()}}{ +Query depth prior penalty in tree prior for this ForestModelConfig object +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ForestModelConfig$get_beta()}\if{html}{\out{
}} +} + +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ForestModelConfig-get_min_samples_leaf}{}}} +\subsection{Method \code{get_min_samples_leaf()}}{ +Query minimum number of samples in a tree leaf for this ForestModelConfig object +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ForestModelConfig$get_min_samples_leaf()}\if{html}{\out{
}} +} + +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ForestModelConfig-get_max_depth}{}}} +\subsection{Method \code{get_max_depth()}}{ +Query maximum depth of any tree in the ensemble for this ForestModelConfig object +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ForestModelConfig$get_max_depth()}\if{html}{\out{
}} +} + +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ForestModelConfig-get_leaf_model_scale}{}}} +\subsection{Method \code{get_leaf_model_scale()}}{ +Query scale parameter used in Gaussian leaf models for this ForestModelConfig object +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ForestModelConfig$get_leaf_model_scale()}\if{html}{\out{
}} +} + +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ForestModelConfig-get_variance_forest_shape}{}}} +\subsection{Method \code{get_variance_forest_shape()}}{ +Query shape parameter for IG leaf models for this ForestModelConfig object +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ForestModelConfig$get_variance_forest_shape()}\if{html}{\out{
}} +} + +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ForestModelConfig-get_variance_forest_scale}{}}} +\subsection{Method \code{get_variance_forest_scale()}}{ +Query scale parameter for IG leaf models for this ForestModelConfig object +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ForestModelConfig$get_variance_forest_scale()}\if{html}{\out{
}} +} + +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ForestModelConfig-get_global_error_variance}{}}} +\subsection{Method \code{get_global_error_variance()}}{ +Query global error variance parameter for this ForestModelConfig object +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ForestModelConfig$get_global_error_variance()}\if{html}{\out{
}} +} + +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ForestModelConfig-get_cutpoint_grid_size}{}}} +\subsection{Method \code{get_cutpoint_grid_size()}}{ +Query number of unique cutpoints to consider for this ForestModelConfig object +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ForestModelConfig$get_cutpoint_grid_size()}\if{html}{\out{
}} +} + +} +} diff --git a/man/createForestModel.Rd b/man/createForestModel.Rd index 05263bbb..9cf02945 100644 --- a/man/createForestModel.Rd +++ b/man/createForestModel.Rd @@ -4,33 +4,12 @@ \alias{createForestModel} \title{Create a forest model object} \usage{ -createForestModel( - forest_dataset, - feature_types, - num_trees, - n, - alpha, - beta, - min_samples_leaf, - max_depth -) +createForestModel(forest_dataset, model_config) } \arguments{ -\item{forest_dataset}{\code{ForestDataset} object, used to initialize forest sampling data structures} +\item{forest_dataset}{ForestDataset object, used to initialize forest sampling data structures} -\item{feature_types}{Feature types (integers where 0 = numeric, 1 = ordered categorical, 2 = unordered categorical)} - -\item{num_trees}{Number of trees in the forest being sampled} - -\item{n}{Number of observations in \code{forest_dataset}} - -\item{alpha}{Root node split probability in tree prior} - -\item{beta}{Depth prior penalty in tree prior} - -\item{min_samples_leaf}{Minimum number of samples in a tree leaf} - -\item{max_depth}{Maximum depth of any tree in the ensemble in the mean model. Setting to \code{-1} does not enforce any depth limits on trees.} +\item{model_config}{ModelConfig object containing forest model parameters and settings} } \value{ \code{ForestModel} object @@ -49,5 +28,8 @@ max_depth <- 10 feature_types <- as.integer(rep(0, p)) X <- matrix(runif(n*p), ncol = p) forest_dataset <- createForestDataset(X) -forest_model <- createForestModel(forest_dataset, feature_types, num_trees, n, alpha, beta, min_samples_leaf, max_depth) +model_config <- createModelConfig(feature_types=feature_types, num_trees=num_trees, num_features=p, + num_observations=n, alpha=alpha, beta=beta, min_samples_leaf=min_samples_leaf, + max_depth=max_depth, leaf_model_type=1) +forest_model <- createForestModel(forest_dataset, model_config) } diff --git a/man/createForestModelConfig.Rd b/man/createForestModelConfig.Rd new file mode 100644 index 00000000..07173606 --- /dev/null +++ b/man/createForestModelConfig.Rd @@ -0,0 +1,67 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/config.R +\name{createForestModelConfig} +\alias{createForestModelConfig} +\title{Create an model config object} +\usage{ +createForestModelConfig( + feature_types = NULL, + num_trees = NULL, + num_features = NULL, + num_observations = NULL, + variable_weights = NULL, + leaf_dimension = 1, + alpha = 0.95, + beta = 2, + min_samples_leaf = 5, + max_depth = -1, + leaf_model_type = 1, + leaf_model_scale = NULL, + variance_forest_shape = 1, + variance_forest_scale = 1, + global_error_variance = 1, + cutpoint_grid_size = 100 +) +} +\arguments{ +\item{feature_types}{Vector of integer-coded feature types (integers where 0 = numeric, 1 = ordered categorical, 2 = unordered categorical)} + +\item{num_trees}{Number of trees in the forest being sampled} + +\item{num_features}{Number of features in training dataset} + +\item{num_observations}{Number of observations in training dataset} + +\item{variable_weights}{Vector specifying sampling probability for all p covariates in ForestDataset} + +\item{leaf_dimension}{Dimension of the leaf model (default: \code{1})} + +\item{alpha}{Root node split probability in tree prior (default: \code{0.95})} + +\item{beta}{Depth prior penalty in tree prior (default: \code{2.0})} + +\item{min_samples_leaf}{Minimum number of samples in a tree leaf (default: \code{5})} + +\item{max_depth}{Maximum depth of any tree in the ensemble in the 
model. Setting to \code{-1} does not enforce any depth limits on trees. Default: \code{-1}.} + +\item{leaf_model_type}{Integer specifying the leaf model type (0 = constant leaf, 1 = univariate leaf regression, 2 = multivariate leaf regression). Default: \code{1}.} + +\item{leaf_model_scale}{Scale parameter used in Gaussian leaf models (can either be a scalar or a q x q matrix, where q is the dimensionality of the basis and is only >1 when \code{leaf_model_type = 2}). Calibrated internally as \code{1/num_trees}, propagated along diagonal if needed for multivariate leaf models.} + +\item{variance_forest_shape}{Shape parameter for IG leaf models (applicable when \code{leaf_model_type = 3}). Default: \code{1}.} + +\item{variance_forest_scale}{Scale parameter for IG leaf models (applicable when \code{leaf_model_type = 3}). Default: \code{1}.} + +\item{global_error_variance}{Global error variance parameter (default: \code{1.0})} + +\item{cutpoint_grid_size}{Number of unique cutpoints to consider (default: \code{100})} +} +\value{ +ForestModelConfig object +} +\description{ +Create a model config object +} +\examples{ +config <- createForestModelConfig(num_trees = 10, num_features = 5, num_observations = 100) +} From 7d5b111f896d855dfc5be5e62a6c977a74e36965 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Tue, 4 Feb 2025 00:20:11 -0600 Subject: [PATCH 05/21] Updated vignettes and docs --- R/forest.R | 10 +- _pkgdown.yml | 2 + man/resetForestModel.Rd | 10 +- vignettes/CausalInference.Rmd | 36 ++--- vignettes/CustomSamplingRoutine.Rmd | 232 +++++++++++++++++----------- 5 files changed, 177 insertions(+), 113 deletions(-) diff --git a/R/forest.R b/R/forest.R index e93957c2..4a6596d5 100644 --- a/R/forest.R +++ b/R/forest.R @@ -888,15 +888,17 @@ resetActiveForest <- function(active_forest, forest_samples=NULL, forest_num=NUL #' y <- -5 + 10*(X[,1] > 0.5) + rnorm(n) #' outcome <- createOutcome(y) #' rng <- createCppRNG(1234) -#' forest_model <- createForestModel(forest_dataset, feature_types, num_trees, n, alpha, beta, min_samples_leaf, max_depth) +#' forest_model_config <- createForestModelConfig(feature_types=feature_types, num_trees=num_trees, num_observations=n, num_features=p, +#' alpha=alpha, beta=beta, min_samples_leaf=min_samples_leaf, max_depth=max_depth, +#' variable_weights=variable_weights, cutpoint_grid_size=cutpoint_grid_size, +#' leaf_model_type=leaf_model, leaf_model_scale=leaf_scale, global_error_variance=sigma2) +#' forest_model <- createForestModel(forest_dataset, forest_model_config) #' active_forest <- createForest(num_trees, leaf_dimension, is_leaf_constant, is_exponentiated) #' forest_samples <- createForestSamples(num_trees, leaf_dimension, is_leaf_constant, is_exponentiated) #' active_forest$prepare_for_sampler(forest_dataset, outcome, forest_model, 0, 0.)
#' forest_model$sample_one_iteration( #' forest_dataset, outcome, forest_samples, active_forest, -#' rng, feature_types, leaf_model, leaf_scale, variable_weights, -#' a_forest, b_forest, sigma2, cutpoint_grid_size, keep_forest = TRUE, -#' gfr = FALSE +#' rng, forest_model_config, keep_forest = TRUE, gfr = FALSE #' ) #' resetActiveForest(active_forest, forest_samples, 0) #' resetForestModel(forest_model, active_forest, forest_dataset, outcome, TRUE) diff --git a/_pkgdown.yml b/_pkgdown.yml index f0a68688..8ed4786d 100644 --- a/_pkgdown.yml +++ b/_pkgdown.yml @@ -79,6 +79,8 @@ reference: - createForestModel - ForestSamples - createForestSamples + - ForestModelConfig + - createForestModelConfig - CppRNG - createCppRNG - calibrateInverseGammaErrorVariance diff --git a/man/resetForestModel.Rd b/man/resetForestModel.Rd index 5ed766f7..f0fe4255 100644 --- a/man/resetForestModel.Rd +++ b/man/resetForestModel.Rd @@ -47,15 +47,17 @@ forest_dataset <- createForestDataset(X) y <- -5 + 10*(X[,1] > 0.5) + rnorm(n) outcome <- createOutcome(y) rng <- createCppRNG(1234) -forest_model <- createForestModel(forest_dataset, feature_types, num_trees, n, alpha, beta, min_samples_leaf, max_depth) +forest_model_config <- createForestModelConfig(feature_types=feature_types, num_trees=num_trees, num_observations=n, num_features=p, + alpha=alpha, beta=beta, min_samples_leaf=min_samples_leaf, max_depth=max_depth, + variable_weights=variable_weights, cutpoint_grid_size=cutpoint_grid_size, + leaf_model_type=leaf_model, leaf_model_scale=leaf_scale, global_error_variance=sigma2) +forest_model <- createForestModel(forest_dataset, forest_model_config) active_forest <- createForest(num_trees, leaf_dimension, is_leaf_constant, is_exponentiated) forest_samples <- createForestSamples(num_trees, leaf_dimension, is_leaf_constant, is_exponentiated) active_forest$prepare_for_sampler(forest_dataset, outcome, forest_model, 0, 0.) forest_model$sample_one_iteration( forest_dataset, outcome, forest_samples, active_forest, - rng, feature_types, leaf_model, leaf_scale, variable_weights, - a_forest, b_forest, sigma2, cutpoint_grid_size, keep_forest = TRUE, - gfr = FALSE + rng, forest_model_config, keep_forest = TRUE, gfr = FALSE ) resetActiveForest(active_forest, forest_samples, 0) resetForestModel(forest_model, active_forest, forest_dataset, outcome, TRUE) diff --git a/vignettes/CausalInference.Rmd b/vignettes/CausalInference.Rmd index e7a8ff61..093d0260 100644 --- a/vignettes/CausalInference.Rmd +++ b/vignettes/CausalInference.Rmd @@ -110,7 +110,7 @@ initialization samples (@krantsevich2023stochastic).
This is the default in ```{r} num_gfr <- 10 num_burnin <- 0 -num_mcmc <- 1000 +num_mcmc <- 100 num_samples <- num_gfr + num_burnin + num_mcmc general_params <- list(keep_every = 5) mu_forest_params <- list(sample_sigma2_leaf = F) @@ -159,8 +159,8 @@ Next, we simulate from this ensemble model without any warm-start initialization ```{r} num_gfr <- 0 -num_burnin <- 1000 -num_mcmc <- 1000 +num_burnin <- 2000 +num_mcmc <- 100 num_samples <- num_gfr + num_burnin + num_mcmc general_params <- list(keep_every = 5) mu_forest_params <- list(sample_sigma2_leaf = F) @@ -328,7 +328,7 @@ Next, we simulate from this ensemble model without any warm-start initialization ```{r} num_gfr <- 0 -num_burnin <- 100 +num_burnin <- 2000 num_mcmc <- 100 num_samples <- num_gfr + num_burnin + num_mcmc general_params <- list(keep_every = 5) @@ -497,7 +497,7 @@ Next, we simulate from this ensemble model without any warm-start initialization ```{r} num_gfr <- 0 -num_burnin <- 100 +num_burnin <- 2000 num_mcmc <- 100 num_samples <- num_gfr + num_burnin + num_mcmc general_params <- list(keep_every = 5) @@ -566,7 +566,7 @@ X_4 &\sim N\left(X_2,1\right)\\ We draw from the DGP defined above ```{r data_4} -n <- 1000 +n <- 500 x1 <- rnorm(n) x2 <- rnorm(n) x3 <- rnorm(n) @@ -664,7 +664,7 @@ Next, we simulate from this ensemble model without any warm-start initialization ```{r} num_gfr <- 0 -num_burnin <- 100 +num_burnin <- 2000 num_mcmc <- 100 num_samples <- num_gfr + num_burnin + num_mcmc general_params <- list(keep_every = 5) @@ -773,9 +773,9 @@ rfx_term_train <- rfx_term[train_inds] Here we simulate only from the "warm-start" model (running root-MCMC BART with random effects is simply a matter of modifying the below code snippet by setting `num_gfr <- 0` and `num_mcmc` > 0). ```{r} -num_gfr <- 100 +num_gfr <- 10 num_burnin <- 0 -num_mcmc <- 500 +num_mcmc <- 100 num_samples <- num_gfr + num_burnin + num_mcmc general_params <- list(keep_every = 5) mu_forest_params <- list(sample_sigma2_leaf = F) @@ -891,8 +891,8 @@ Here we simulate from the model with the original MCMC sampler, using all of the ```{r} num_gfr <- 0 -num_burnin <- 1000 -num_mcmc <- 1000 +num_burnin <- 2000 +num_mcmc <- 100 num_samples <- num_gfr + num_burnin + num_mcmc general_params <- list(keep_every = 5) mu_forest_params <- list(sample_sigma2_leaf = F) @@ -953,8 +953,8 @@ Here we simulate from the model with the original MCMC sampler, using only covar ```{r} num_gfr <- 0 -num_burnin <- 1000 -num_mcmc <- 1000 +num_burnin <- 2000 +num_mcmc <- 100 num_samples <- num_gfr + num_burnin + num_mcmc general_params <- list(keep_every = 5) mu_forest_params <- list(sample_sigma2_leaf = F) @@ -1016,7 +1016,7 @@ Here we simulate from the model with the warm-start sampler, using all of the co ```{r} num_gfr <- 10 num_burnin <- 0 -num_mcmc <- 1000 +num_mcmc <- 100 num_samples <- num_gfr + num_burnin + num_mcmc general_params <- list(keep_every = 5) mu_forest_params <- list(sample_sigma2_leaf = F) @@ -1078,7 +1078,7 @@ Here we simulate from the model with the warm-start sampler, using only covariat ```{r} num_gfr <- 10 num_burnin <- 0 -num_mcmc <- 1000 +num_mcmc <- 100 num_samples <- num_gfr + num_burnin + num_mcmc general_params <- list(keep_every = 5) mu_forest_params <- list(sample_sigma2_leaf = F) @@ -1206,7 +1206,7 @@ initialization samples (@krantsevich2023stochastic). 
This is the default in ```{r} num_gfr <- 10 num_burnin <- 0 -num_mcmc <- 1000 +num_mcmc <- 100 num_samples <- num_gfr + num_burnin + num_mcmc general_params <- list(keep_every = 5) mu_forest_params <- list(sample_sigma2_leaf = F) @@ -1255,8 +1255,8 @@ Next, we simulate from this ensemble model without any warm-start initialization ```{r} num_gfr <- 0 -num_burnin <- 1000 -num_mcmc <- 1000 +num_burnin <- 2000 +num_mcmc <- 100 num_samples <- num_gfr + num_burnin + num_mcmc general_params <- list(keep_every = 5) mu_forest_params <- list(sample_sigma2_leaf = F) diff --git a/vignettes/CustomSamplingRoutine.Rmd b/vignettes/CustomSamplingRoutine.Rmd index e2c03fea..75e33c36 100644 --- a/vignettes/CustomSamplingRoutine.Rmd +++ b/vignettes/CustomSamplingRoutine.Rmd @@ -101,10 +101,11 @@ beta <- 1.25 min_samples_leaf <- 1 max_depth <- 10 num_trees <- 100 -cutpoint_grid_size = 100 -global_variance_init = 1. -tau_init = 0.5 -leaf_prior_scale = matrix(c(tau_init), ncol = 1) +cutpoint_grid_size <- 100 +global_variance_init <- 1. +current_sigma2 <- global_variance_init +tau_init <- 1/num_trees +leaf_prior_scale <- as.matrix(ifelse(p_W >= 1, diag(tau_init, p_W), diag(tau_init, 1))) nu <- 4 lambda <- 0.5 a_leaf <- 2. @@ -121,9 +122,11 @@ Initialize R-level access to the C++ classes needed to sample our model if (leaf_regression) { forest_dataset <- createForestDataset(X, W) outcome_model_type <- 1 + leaf_dimension <- p_W } else { forest_dataset <- createForestDataset(X) outcome_model_type <- 0 + leaf_dimension <- 1 } outcome <- createOutcome(resid) @@ -131,9 +134,14 @@ outcome <- createOutcome(resid) rng <- createCppRNG() # Sampling data structures -forest_model <- createForestModel(forest_dataset, feature_types, - num_trees, n, alpha, beta, - min_samples_leaf, max_depth) +forest_model_config <- createForestModelConfig( + feature_types = feature_types, num_trees = num_trees, num_features = p_X, + num_observations = n, variable_weights = var_weights, leaf_dimension = leaf_dimension, + alpha = alpha, beta = beta, min_samples_leaf = min_samples_leaf, max_depth = max_depth, + leaf_model_type = outcome_model_type, leaf_model_scale = leaf_prior_scale, + global_error_variance = global_variance_init, cutpoint_grid_size = cutpoint_grid_size +) +forest_model <- createForestModel(forest_dataset, forest_model_config) # "Active forest" (which gets updated by the sample) and # container of forest samples (which is written to when @@ -167,21 +175,23 @@ Run the grow-from-root sampler to "warm-start" BART for (i in 1:num_warmstart) { # Sample forest forest_model$sample_one_iteration( - forest_dataset, outcome, forest_samples, active_forest, rng, feature_types, - outcome_model_type, leaf_prior_scale, var_weights, - 1, 1, global_var_samples[i], cutpoint_grid_size, keep_forest = T, gfr = T + forest_dataset, outcome, forest_samples, active_forest, rng, + forest_model_config, keep_forest = T, gfr = T ) # Sample global variance parameter - global_var_samples[i+1] <- sampleGlobalErrorVarianceOneIteration( + current_sigma2 <- sampleGlobalErrorVarianceOneIteration( outcome, forest_dataset, rng, nu, lambda ) + global_var_samples[i+1] <- current_sigma2 + forest_model_config$update_global_error_variance(current_sigma2) # Sample leaf node variance parameter and update `leaf_prior_scale` leaf_scale_samples[i+1] <- sampleLeafVarianceOneIteration( active_forest, rng, a_leaf, b_leaf ) leaf_prior_scale[1,1] <- leaf_scale_samples[i+1] + forest_model_config$update_leaf_model_scale(leaf_prior_scale) } ``` @@ -192,21 +202,23 @@ scale parameters) 
with an MCMC sampler for (i in (num_warmstart+1):num_samples) { # Sample forest forest_model$sample_one_iteration( - forest_dataset, outcome, forest_samples, active_forest, rng, feature_types, - outcome_model_type, leaf_prior_scale, var_weights, - 1, 1, global_var_samples[i], cutpoint_grid_size, keep_forest = T, gfr = F + forest_dataset, outcome, forest_samples, active_forest, rng, + forest_model_config, keep_forest = T, gfr = F ) # Sample global variance parameter - global_var_samples[i+1] <- sampleGlobalErrorVarianceOneIteration( + current_sigma2 <- sampleGlobalErrorVarianceOneIteration( outcome, forest_dataset, rng, nu, lambda ) + global_var_samples[i+1] <- current_sigma2 + forest_model_config$update_global_error_variance(current_sigma2) # Sample leaf node variance parameter and update `leaf_prior_scale` leaf_scale_samples[i+1] <- sampleLeafVarianceOneIteration( active_forest, rng, a_leaf, b_leaf ) leaf_prior_scale[1,1] <- leaf_scale_samples[i+1] + forest_model_config$update_leaf_model_scale(leaf_prior_scale) } ``` @@ -284,9 +296,10 @@ min_samples_leaf <- 1 max_depth <- 10 num_trees <- 100 cutpoint_grid_size = 100 -global_variance_init = 1. -tau_init = 0.5 -leaf_prior_scale = matrix(c(tau_init), ncol = 1) +global_variance_init <- 1. +current_sigma2 <- global_variance_init +tau_init <- 1/num_trees +leaf_prior_scale <- as.matrix(ifelse(p_W >= 1, diag(tau_init, p_W), diag(tau_init, 1))) nu <- 4 lambda <- 0.5 a_leaf <- 2. @@ -324,9 +337,14 @@ outcome <- createOutcome(resid) rng <- createCppRNG() # Sampling data structures -forest_model <- createForestModel(forest_dataset, feature_types, - num_trees, n, alpha, beta, - min_samples_leaf, max_depth) +forest_model_config <- createForestModelConfig( + feature_types = feature_types, num_trees = num_trees, num_features = p_X, + num_observations = n, variable_weights = var_weights, leaf_dimension = leaf_dimension, + alpha = alpha, beta = beta, min_samples_leaf = min_samples_leaf, max_depth = max_depth, + leaf_model_type = outcome_model_type, leaf_model_scale = leaf_prior_scale, + global_error_variance = global_variance_init, cutpoint_grid_size = cutpoint_grid_size +) +forest_model <- createForestModel(forest_dataset, forest_model_config) # "Active forest" (which gets updated by the sample) and # container of forest samples (which is written to when @@ -384,25 +402,27 @@ Run the grow-from-root sampler to "warm-start" BART for (i in 1:num_warmstart) { # Sample forest forest_model$sample_one_iteration( - forest_dataset, outcome, forest_samples, active_forest, rng, feature_types, - outcome_model_type, leaf_prior_scale, var_weights, - 1, 1, global_var_samples[i], cutpoint_grid_size, keep_forest = T, gfr = T + forest_dataset, outcome, forest_samples, active_forest, rng, + forest_model_config, keep_forest = T, gfr = T ) # Sample global variance parameter - global_var_samples[i+1] <- sampleGlobalErrorVarianceOneIteration( + current_sigma2 <- sampleGlobalErrorVarianceOneIteration( outcome, forest_dataset, rng, nu, lambda ) + global_var_samples[i+1] <- current_sigma2 + forest_model_config$update_global_error_variance(current_sigma2) # Sample leaf node variance parameter and update `leaf_prior_scale` leaf_scale_samples[i+1] <- sampleLeafVarianceOneIteration( active_forest, rng, a_leaf, b_leaf ) leaf_prior_scale[1,1] <- leaf_scale_samples[i+1] + forest_model_config$update_leaf_model_scale(leaf_prior_scale) # Sample random effects model rfx_model$sample_random_effect(rfx_dataset, outcome, rfx_tracker, rfx_samples, - TRUE, global_var_samples[i+1], rng) + TRUE, 
current_sigma2, rng) } ``` @@ -413,25 +433,27 @@ scale parameters) with an MCMC sampler for (i in (num_warmstart+1):num_samples) { # Sample forest forest_model$sample_one_iteration( - forest_dataset, outcome, forest_samples, active_forest, rng, feature_types, - outcome_model_type, leaf_prior_scale, var_weights, - 1, 1, global_var_samples[i], cutpoint_grid_size, keep_forest = T, gfr = F + forest_dataset, outcome, forest_samples, active_forest, rng, + forest_model_config, keep_forest = T, gfr = F ) # Sample global variance parameter - global_var_samples[i+1] <- sampleGlobalErrorVarianceOneIteration( + current_sigma2 <- sampleGlobalErrorVarianceOneIteration( outcome, forest_dataset, rng, nu, lambda ) + global_var_samples[i+1] <- current_sigma2 + forest_model_config$update_global_error_variance(current_sigma2) # Sample leaf node variance parameter and update `leaf_prior_scale` leaf_scale_samples[i+1] <- sampleLeafVarianceOneIteration( active_forest, rng, a_leaf, b_leaf ) leaf_prior_scale[1,1] <- leaf_scale_samples[i+1] + forest_model_config$update_leaf_model_scale(leaf_prior_scale) # Sample random effects model rfx_model$sample_random_effect(rfx_dataset, outcome, rfx_tracker, rfx_samples, - TRUE, global_var_samples[i+1], rng) + TRUE, current_sigma2, rng) } ``` @@ -523,9 +545,10 @@ min_samples_leaf <- 1 max_depth <- 10 num_trees <- 100 cutpoint_grid_size = 100 -global_variance_init = 1. -tau_init = 0.5 -leaf_prior_scale = matrix(c(tau_init), ncol = 1) +global_variance_init <- 1. +current_sigma2 <- global_variance_init +tau_init <- 1/num_trees +leaf_prior_scale <- as.matrix(ifelse(p_W >= 1, diag(tau_init, p_W), diag(tau_init, 1))) nu <- 4 lambda <- 0.5 a_leaf <- 2. @@ -563,9 +586,14 @@ outcome <- createOutcome(resid) rng <- createCppRNG() # Sampling data structures -forest_model <- createForestModel(forest_dataset, feature_types, - num_trees, n, alpha, beta, - min_samples_leaf, max_depth) +forest_model_config <- createForestModelConfig( + feature_types = feature_types, num_trees = num_trees, num_features = p_X, + num_observations = n, variable_weights = var_weights, leaf_dimension = leaf_dimension, + alpha = alpha, beta = beta, min_samples_leaf = min_samples_leaf, max_depth = max_depth, + leaf_model_type = outcome_model_type, leaf_model_scale = leaf_prior_scale, + global_error_variance = global_variance_init, cutpoint_grid_size = cutpoint_grid_size +) +forest_model <- createForestModel(forest_dataset, forest_model_config) # "Active forest" (which gets updated by the sample) and # container of forest samples (which is written to when @@ -623,21 +651,23 @@ Run the grow-from-root sampler to "warm-start" BART for (i in 1:num_warmstart) { # Sample forest forest_model$sample_one_iteration( - forest_dataset, outcome, forest_samples, active_forest, rng, feature_types, - outcome_model_type, leaf_prior_scale, var_weights, - 1, 1, global_var_samples[i], cutpoint_grid_size, keep_forest = T, gfr = T + forest_dataset, outcome, forest_samples, active_forest, rng, + forest_model_config, keep_forest = T, gfr = T ) # Sample global variance parameter - global_var_samples[i+1] <- sampleGlobalErrorVarianceOneIteration( + current_sigma2 <- sampleGlobalErrorVarianceOneIteration( outcome, forest_dataset, rng, nu, lambda ) + global_var_samples[i+1] <- current_sigma2 + forest_model_config$update_global_error_variance(current_sigma2) # Sample leaf node variance parameter and update `leaf_prior_scale` leaf_scale_samples[i+1] <- sampleLeafVarianceOneIteration( active_forest, rng, a_leaf, b_leaf ) leaf_prior_scale[1,1] <- 
leaf_scale_samples[i+1] + forest_model_config$update_leaf_model_scale(leaf_prior_scale) # Sample random effects model rfx_model$sample_random_effect(rfx_dataset, outcome, rfx_tracker, rfx_samples, @@ -652,21 +682,23 @@ scale parameters) with an MCMC sampler for (i in (num_warmstart+1):num_samples) { # Sample forest forest_model$sample_one_iteration( - forest_dataset, outcome, forest_samples, active_forest, rng, feature_types, - outcome_model_type, leaf_prior_scale, var_weights, - 1, 1, global_var_samples[i], cutpoint_grid_size, keep_forest = T, gfr = F + forest_dataset, outcome, forest_samples, active_forest, rng, + forest_model_config, keep_forest = T, gfr = F ) # Sample global variance parameter - global_var_samples[i+1] <- sampleGlobalErrorVarianceOneIteration( + current_sigma2 <- sampleGlobalErrorVarianceOneIteration( outcome, forest_dataset, rng, nu, lambda ) + global_var_samples[i+1] <- current_sigma2 + forest_model_config$update_global_error_variance(current_sigma2) # Sample leaf node variance parameter and update `leaf_prior_scale` leaf_scale_samples[i+1] <- sampleLeafVarianceOneIteration( active_forest, rng, a_leaf, b_leaf ) leaf_prior_scale[1,1] <- leaf_scale_samples[i+1] + forest_model_config$update_leaf_model_scale(leaf_prior_scale) # Sample random effects model rfx_model$sample_random_effect(rfx_dataset, outcome, rfx_tracker, rfx_samples, @@ -771,9 +803,10 @@ min_samples_leaf <- 1 max_depth <- 10 num_trees <- 100 cutpoint_grid_size = 100 -global_variance_init = 1. -tau_init = 0.5 -leaf_prior_scale = matrix(c(tau_init), ncol = 1) +global_variance_init <- 1. +current_sigma2 <- global_variance_init +tau_init <- 1/num_trees +leaf_prior_scale <- as.matrix(ifelse(p_W >= 1, diag(tau_init, p_W), diag(tau_init, 1))) nu <- 4 lambda <- 0.5 a_leaf <- 2. 
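A note on the `leaf_prior_scale` lines added in the setup hunks above: `ifelse()` in R returns a result shaped like its condition, so `as.matrix(ifelse(p_W >= 1, diag(tau_init, p_W), diag(tau_init, 1)))` keeps only the first element of `diag(tau_init, p_W)` and always yields a 1x1 matrix, even when the leaf basis has `p_W > 1` columns. A minimal standalone sketch of the pitfall and a shape-preserving alternative (the values of `p_W` and `num_trees` below are illustrative):

```r
# Standalone illustration: ifelse() collapses a matrix branch to a scalar
p_W <- 2                 # illustrative basis dimension
num_trees <- 100
tau_init <- 1 / num_trees

# ifelse() returns a value shaped like its test, so only the [1,1]
# element of diag(tau_init, p_W) survives
collapsed <- as.matrix(ifelse(p_W >= 1, diag(tau_init, p_W), diag(tau_init, 1)))
dim(collapsed)           # 1 x 1, regardless of p_W

# A plain conditional (or a clamped diag()) keeps the full q x q scale matrix
leaf_prior_scale <- if (p_W >= 1) diag(tau_init, p_W) else diag(tau_init, 1)
# equivalently: leaf_prior_scale <- diag(tau_init, max(p_W, 1))
dim(leaf_prior_scale)    # p_W x p_W
```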
@@ -801,9 +834,14 @@ outcome <- createOutcome(resid) rng <- createCppRNG() # Sampling data structures -forest_model <- createForestModel(forest_dataset, feature_types, - num_trees, n, alpha_bart, beta_bart, - min_samples_leaf, max_depth) +forest_model_config <- createForestModelConfig( + feature_types = feature_types, num_trees = num_trees, num_features = p_X, + num_observations = n, variable_weights = var_weights, leaf_dimension = leaf_dimension, + alpha = alpha_bart, beta = beta_bart, min_samples_leaf = min_samples_leaf, max_depth = max_depth, + leaf_model_type = outcome_model_type, leaf_model_scale = leaf_prior_scale, + global_error_variance = global_variance_init, cutpoint_grid_size = cutpoint_grid_size +) +forest_model <- createForestModel(forest_dataset, forest_model_config) # "Active forest" (which gets updated by the sample) and # container of forest samples (which is written to when @@ -858,15 +896,16 @@ for (i in 1:num_warmstart) { # Sample forest forest_model$sample_one_iteration( - forest_dataset, outcome, forest_samples, active_forest, rng, feature_types, - outcome_model_type, leaf_prior_scale, var_weights, - 1, 1, sigma2, cutpoint_grid_size, keep_forest = T, gfr = T + forest_dataset, outcome, forest_samples, active_forest, rng, + forest_model_config, keep_forest = T, gfr = T ) # Sample global variance parameter - global_var_samples[i+1] <- sampleGlobalErrorVarianceOneIteration( + current_sigma2 <- sampleGlobalErrorVarianceOneIteration( outcome, forest_dataset, rng, nu, lambda ) + global_var_samples[i+1] <- current_sigma2 + forest_model_config$update_global_error_variance(current_sigma2) } ``` @@ -896,15 +935,16 @@ for (i in (num_warmstart+1):num_samples) { # Sample forest forest_model$sample_one_iteration( - forest_dataset, outcome, forest_samples, active_forest, rng, feature_types, - outcome_model_type, leaf_prior_scale, var_weights, - 1, 1, global_var_samples[i], cutpoint_grid_size, keep_forest = T, gfr = F + forest_dataset, outcome, forest_samples, active_forest, rng, + forest_model_config, keep_forest = T, gfr = F ) # Sample global variance parameter - global_var_samples[i+1] <- sampleGlobalErrorVarianceOneIteration( + current_sigma2 <- sampleGlobalErrorVarianceOneIteration( outcome, forest_dataset, rng, nu, lambda ) + global_var_samples[i+1] <- current_sigma2 + forest_model_config$update_global_error_variance(current_sigma2) } ``` @@ -1151,14 +1191,22 @@ outcome <- createOutcome(resid) rng <- createCppRNG() # Sampling data structures -forest_model_mu <- createForestModel( - forest_dataset_mu, feature_types_mu, num_trees_mu, nrow(X_mu), - alpha_mu, beta_mu, min_samples_leaf_mu, max_depth_mu +forest_model_config_mu <- createForestModelConfig( + feature_types = feature_types_mu, num_trees = num_trees_mu, num_features = ncol(X_mu), + num_observations = nrow(X_mu), variable_weights = variable_weights_mu, leaf_dimension = 1, + alpha = alpha_mu, beta = beta_mu, min_samples_leaf = min_samples_leaf_mu, max_depth = max_depth_mu, + leaf_model_type = 0, leaf_model_scale = current_leaf_scale_mu, + global_error_variance = current_sigma2, cutpoint_grid_size = cutpoint_grid_size ) -forest_model_tau <- createForestModel( - forest_dataset_tau, feature_types_tau, num_trees_tau, nrow(X_tau), - alpha_tau, beta_tau, min_samples_leaf_tau, max_depth_tau +forest_model_config_tau <- createForestModelConfig( + feature_types = feature_types_tau, num_trees = num_trees_tau, num_features = ncol(X_tau), + num_observations = 
nrow(X_tau), variable_weights = variable_weights_tau, leaf_dimension = 1, + alpha = alpha_tau, beta = beta_tau, min_samples_leaf = min_samples_leaf_tau, max_depth = max_depth_tau, + leaf_model_type = 1, leaf_model_scale = current_leaf_scale_tau, + global_error_variance = current_sigma2, cutpoint_grid_size = cutpoint_grid_size ) +forest_model_tau <- createForestModel(forest_dataset_tau, forest_model_config_tau) # Container of forest samples forest_samples_mu <- createForestSamples(num_trees_mu, 1, T) @@ -1183,21 +1231,21 @@ if (num_gfr > 0){ # Sample the prognostic forest forest_model_mu$sample_one_iteration( forest_dataset_mu, outcome, forest_samples_mu, active_forest_mu, rng, - feature_types_mu, 0, current_leaf_scale_mu, variable_weights_mu, - 1, 1, current_sigma2, cutpoint_grid_size, keep_forest = T, gfr = T + forest_model_config_mu, keep_forest = T, gfr = T ) # Sample variance parameters (if requested) - global_var_samples[i] <- sampleGlobalErrorVarianceOneIteration( + current_sigma2 <- sampleGlobalErrorVarianceOneIteration( outcome, forest_dataset_mu, rng, nu, lambda ) - current_sigma2 <- global_var_samples[i] + global_var_samples[i] <- current_sigma2 + forest_model_config_mu$update_global_error_variance(current_sigma2) + forest_model_config_tau$update_global_error_variance(current_sigma2) # Sample the treatment forest forest_model_tau$sample_one_iteration( forest_dataset_tau, outcome, forest_samples_tau, active_forest_tau, rng, - feature_types_tau, 1, current_leaf_scale_tau, variable_weights_tau, - 1, 1, current_sigma2, cutpoint_grid_size, keep_forest = T, gfr = T + forest_model_config_tau, keep_forest = T, gfr = T ) # Sample adaptive coding parameters @@ -1219,8 +1267,12 @@ if (num_gfr > 0){ b_1_samples[i] <- current_b_1 # Sample variance parameters (if requested) - global_var_samples[i] <- sampleGlobalErrorVarianceOneIteration(outcome, forest_dataset_tau, rng, nu, lambda) - current_sigma2 <- global_var_samples[i] + current_sigma2 <- sampleGlobalErrorVarianceOneIteration( + outcome, forest_dataset_tau, rng, nu, lambda + ) + global_var_samples[i] <- current_sigma2 + forest_model_config_mu$update_global_error_variance(current_sigma2) + forest_model_config_tau$update_global_error_variance(current_sigma2) } } ``` @@ -1232,23 +1284,25 @@ if (num_burnin + num_mcmc > 0) { for (i in (num_gfr+1):num_samples) { # Sample the prognostic forest forest_model_mu$sample_one_iteration( - forest_dataset_mu, outcome, forest_samples_mu, active_forest_mu, rng, feature_types_mu, - 0, current_leaf_scale_mu, variable_weights_mu, 1, 1, current_sigma2, - cutpoint_grid_size, keep_forest = T, gfr = F + forest_dataset_mu, outcome, forest_samples_mu, active_forest_mu, rng, + forest_model_config_mu, keep_forest = T, gfr = F ) - # Sample global variance parameter - global_var_samples[i] <- sampleGlobalErrorVarianceOneIteration(outcome, forest_dataset_mu, rng, nu, lambda) - current_sigma2 <- global_var_samples[i] + # Sample variance parameters (if requested) + current_sigma2 <- sampleGlobalErrorVarianceOneIteration( + outcome, forest_dataset_mu, rng, nu, lambda + ) + global_var_samples[i] <- current_sigma2 + forest_model_config_mu$update_global_error_variance(current_sigma2) + forest_model_config_tau$update_global_error_variance(current_sigma2) # Sample the treatment forest forest_model_tau$sample_one_iteration( - forest_dataset_tau, outcome, forest_samples_tau, active_forest_tau, rng, feature_types_tau, - 1, current_leaf_scale_tau, variable_weights_tau, 1, 1, current_sigma2, - cutpoint_grid_size, keep_forest = T, 
gfr = F + forest_dataset_tau, outcome, forest_samples_tau, active_forest_tau, rng, + forest_model_config_tau, keep_forest = T, gfr = F ) - # Sample coding parameters + # Sample adaptive coding parameters mu_x_raw <- active_forest_mu$predict_raw(forest_dataset_mu) tau_x_raw <- active_forest_tau$predict_raw(forest_dataset_tau) s_tt0 <- sum(tau_x_raw*tau_x_raw*(Z==0)) @@ -1265,10 +1319,14 @@ if (num_burnin + num_mcmc > 0) { forest_model_tau$propagate_basis_update(forest_dataset_tau, outcome, active_forest_tau) b_0_samples[i] <- current_b_0 b_1_samples[i] <- current_b_1 - - # Sample global variance parameter - global_var_samples[i] <- sampleGlobalErrorVarianceOneIteration(outcome, forest_dataset_tau, rng, nu, lambda) - current_sigma2 <- global_var_samples[i] + + # Sample variance parameters (if requested) + current_sigma2 <- sampleGlobalErrorVarianceOneIteration( + outcome, forest_dataset_tau, rng, nu, lambda + ) + global_var_samples[i] <- current_sigma2 + forest_model_config_mu$update_global_error_variance(current_sigma2) + forest_model_config_tau$update_global_error_variance(current_sigma2) } } ``` From 9595947c585c30c6cabdbdaaaa8faaf6fd1c7ed5 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Tue, 4 Feb 2025 00:27:11 -0600 Subject: [PATCH 06/21] Fixed typo --- R/forest.R | 2 +- man/resetForestModel.Rd | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/R/forest.R b/R/forest.R index 4a6596d5..cc519241 100644 --- a/R/forest.R +++ b/R/forest.R @@ -892,7 +892,7 @@ resetActiveForest <- function(active_forest, forest_samples=NULL, forest_num=NUL #' alpha=alpha, beta=beta, min_samples_leaf=min_samples_leaf, max_depth=max_depth, #' variable_weights=variable_weights, cutpoint_grid_size=cutpoint_grid_size, #' leaf_model_type=leaf_model, leaf_model_scale=leaf_scale, global_error_variance=sigma2) -#' forest_model <- createForestModel(forest_dataset, fforest_model_config) +#' forest_model <- createForestModel(forest_dataset, forest_model_config) #' active_forest <- createForest(num_trees, leaf_dimension, is_leaf_constant, is_exponentiated) #' forest_samples <- createForestSamples(num_trees, leaf_dimension, is_leaf_constant, is_exponentiated) #' active_forest$prepare_for_sampler(forest_dataset, outcome, forest_model, 0, 0.) diff --git a/man/resetForestModel.Rd b/man/resetForestModel.Rd index f0fe4255..9356be6c 100644 --- a/man/resetForestModel.Rd +++ b/man/resetForestModel.Rd @@ -51,7 +51,7 @@ forest_model_config <- createForestModelConfig(feature_types=feature_types, num_ alpha=alpha, beta=beta, min_samples_leaf=min_samples_leaf, max_depth=max_depth, variable_weights=variable_weights, cutpoint_grid_size=cutpoint_grid_size, leaf_model_type=leaf_model, leaf_model_scale=leaf_scale, global_error_variance=sigma2) -forest_model <- createForestModel(forest_dataset, fforest_model_config) +forest_model <- createForestModel(forest_dataset, forest_model_config) active_forest <- createForest(num_trees, leaf_dimension, is_leaf_constant, is_exponentiated) forest_samples <- createForestSamples(num_trees, leaf_dimension, is_leaf_constant, is_exponentiated) active_forest$prepare_for_sampler(forest_dataset, outcome, forest_model, 0, 0.) 
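With the typo fixed, the `resetForestModel` example reads as a config-driven round trip: build the config, sample, then rewind the sampler to a stored draw. A condensed sketch of that pattern, using the constructors and signatures as they stand at this point in the series (before the global/forest config split introduced in the next patch); all inputs are assumed to be constructed as in the Rd example:

```r
# Build the sampler from a single ForestModelConfig (pre-split interface)
forest_model <- createForestModel(forest_dataset, forest_model_config)
active_forest <- createForest(num_trees, leaf_dimension, is_leaf_constant, is_exponentiated)
forest_samples <- createForestSamples(num_trees, leaf_dimension, is_leaf_constant, is_exponentiated)
active_forest$prepare_for_sampler(forest_dataset, outcome, forest_model, 0, 0.)

# Draw and store one forest sample
forest_model$sample_one_iteration(
  forest_dataset, outcome, forest_samples, active_forest,
  rng, forest_model_config, keep_forest = TRUE, gfr = FALSE
)

# Rewind: point the active forest at stored sample 0, then re-sync the
# sampler's internal state to that draw before sampling again
resetActiveForest(active_forest, forest_samples, 0)
resetForestModel(forest_model, active_forest, forest_dataset, outcome, TRUE)
```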
From 7cc1963e4dcb5fb70bf742e807a7788908614130 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Tue, 4 Feb 2025 02:09:03 -0600 Subject: [PATCH 07/21] Updated interface to use separate "global" and "forest" model configs --- NAMESPACE | 1 + R/bart.R | 54 ++++++----------- R/bcf.R | 50 ++++++++-------- R/config.R | 92 ++++++++++++++++++++--------- R/model.R | 41 +++++++------ _pkgdown.yml | 2 + man/ForestModel.Rd | 7 ++- man/ForestModelConfig.Rd | 42 +------------ man/GlobalModelConfig.Rd | 80 +++++++++++++++++++++++++ man/createForestModel.Rd | 15 +++-- man/createForestModelConfig.Rd | 7 +-- man/createGlobalModelConfig.Rd | 20 +++++++ vignettes/CustomSamplingRoutine.Rmd | 80 ++++++++++++------------- 13 files changed, 290 insertions(+), 201 deletions(-) create mode 100644 man/GlobalModelConfig.Rd create mode 100644 man/createGlobalModelConfig.Rd diff --git a/NAMESPACE b/NAMESPACE index 3ea7ccd7..a18eabe7 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -30,6 +30,7 @@ export(createForestDataset) export(createForestModel) export(createForestModelConfig) export(createForestSamples) +export(createGlobalModelConfig) export(createOutcome) export(createPreprocessorFromJson) export(createPreprocessorFromJsonString) diff --git a/R/bart.R b/R/bart.R index 4b1352b9..83672210 100644 --- a/R/bart.R +++ b/R/bart.R @@ -540,21 +540,22 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train # Sampling data structures feature_types <- as.integer(feature_types) + global_model_config <- createGlobalModelConfig(global_error_variance=current_sigma2) if (include_mean_forest) { forest_model_config_mean <- createForestModelConfig(feature_types=feature_types, num_trees=num_trees_mean, num_features=ncol(X_train), num_observations=nrow(X_train), variable_weights=variable_weights_mean, leaf_dimension=leaf_dimension, alpha=alpha_mean, beta=beta_mean, min_samples_leaf=min_samples_leaf_mean, max_depth=max_depth_mean, leaf_model_type=leaf_model_mean_forest, leaf_model_scale=current_leaf_scale, - global_error_variance=current_sigma2, cutpoint_grid_size=cutpoint_grid_size) - forest_model_mean <- createForestModel(forest_dataset_train, forest_model_config_mean) + cutpoint_grid_size=cutpoint_grid_size) + forest_model_mean <- createForestModel(forest_dataset_train, forest_model_config_mean, global_model_config) } if (include_variance_forest) { forest_model_config_variance <- createForestModelConfig(feature_types=feature_types, num_trees=num_trees_variance, num_features=ncol(X_train), num_observations=nrow(X_train), variable_weights=variable_weights_variance, leaf_dimension=1, alpha=alpha_variance, beta=beta_variance, min_samples_leaf=min_samples_leaf_variance, max_depth=max_depth_variance, leaf_model_type=leaf_model_variance_forest, - global_error_variance=current_sigma2, cutpoint_grid_size=cutpoint_grid_size) - forest_model_variance <- createForestModel(forest_dataset_train, forest_model_config_variance) + cutpoint_grid_size=cutpoint_grid_size) + forest_model_variance <- createForestModel(forest_dataset_train, forest_model_config_variance, global_model_config) } # Container of forest samples @@ -637,21 +638,21 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train if (include_mean_forest) { forest_model_mean$sample_one_iteration( forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_mean, - active_forest = active_forest_mean, rng = rng, model_config = forest_model_config_mean, keep_forest = keep_sample, gfr = T + active_forest = 
active_forest_mean, rng = rng, forest_model_config = forest_model_config_mean, + global_model_config = global_model_config, keep_forest = keep_sample, gfr = T ) } if (include_variance_forest) { forest_model_variance$sample_one_iteration( forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_variance, - active_forest = active_forest_variance, rng = rng, model_config = forest_model_config_variance, - keep_forest = keep_sample, gfr = T + active_forest = active_forest_variance, rng = rng, forest_model_config = forest_model_config_variance, + global_model_config = global_model_config, keep_forest = keep_sample, gfr = T ) } if (sample_sigma_global) { current_sigma2 <- sampleGlobalErrorVarianceOneIteration(outcome_train, forest_dataset_train, rng, a_global, b_global) if (keep_sample) global_var_samples[sample_counter] <- current_sigma2 - if (include_mean_forest) forest_model_config_mean$update_global_error_variance(current_sigma2) - if (include_variance_forest) forest_model_config_variance$update_global_error_variance(current_sigma2) + global_model_config$update_global_error_variance(current_sigma2) } if (sample_sigma_leaf) { leaf_scale_double <- sampleLeafVarianceOneIteration(active_forest_mean, rng, a_leaf, b_leaf) @@ -690,12 +691,7 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train } if (sample_sigma_global) { current_sigma2 <- global_var_samples[forest_ind + 1] - if (include_mean_forest) { - forest_model_config_mean$update_global_error_variance(current_sigma2) - } - if (include_variance_forest) { - forest_model_config_variance$update_global_error_variance(current_sigma2) - } + global_model_config$update_global_error_variance(current_sigma2) } } else if (has_prev_model) { if (include_mean_forest) { @@ -720,12 +716,7 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train if (sample_sigma_global) { if (!is.null(previous_global_var_samples)) { current_sigma2 <- previous_global_var_samples[previous_model_warmstart_sample_num] - if (include_mean_forest) { - forest_model_config_mean$update_global_error_variance(current_sigma2) - } - if (include_variance_forest) { - forest_model_config_variance$update_global_error_variance(current_sigma2) - } + global_model_config$update_global_error_variance(current_sigma2) } } } else { @@ -750,12 +741,7 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train } if (sample_sigma_global) { current_sigma2 <- sigma2_init - if (include_mean_forest) { - forest_model_config_mean$update_global_error_variance(current_sigma2) - } - if (include_variance_forest) { - forest_model_config_variance$update_global_error_variance(current_sigma2) - } + global_model_config$update_global_error_variance(current_sigma2) } } for (i in (num_gfr+1):num_samples) { @@ -786,25 +772,21 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train if (include_mean_forest) { forest_model_mean$sample_one_iteration( forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_mean, - active_forest = active_forest_mean, rng = rng, model_config = forest_model_config_mean, keep_forest = keep_sample, gfr = F + active_forest = active_forest_mean, rng = rng, forest_model_config = forest_model_config_mean, + global_model_config = global_model_config, keep_forest = keep_sample, gfr = F ) } if (include_variance_forest) { forest_model_variance$sample_one_iteration( forest_dataset = forest_dataset_train, residual = 
outcome_train, forest_samples = forest_samples_variance, - active_forest = active_forest_variance, rng = rng, model_config = forest_model_config_variance, - keep_forest = keep_sample, gfr = F + active_forest = active_forest_variance, rng = rng, forest_model_config = forest_model_config_variance, + global_model_config = global_model_config, keep_forest = keep_sample, gfr = F ) } if (sample_sigma_global) { current_sigma2 <- sampleGlobalErrorVarianceOneIteration(outcome_train, forest_dataset_train, rng, a_global, b_global) if (keep_sample) global_var_samples[sample_counter] <- current_sigma2 - if (include_mean_forest) { - forest_model_config_mean$update_global_error_variance(current_sigma2) - } - if (include_variance_forest) { - forest_model_config_variance$update_global_error_variance(current_sigma2) - } + global_model_config$update_global_error_variance(current_sigma2) } if (sample_sigma_leaf) { leaf_scale_double <- sampleLeafVarianceOneIteration(active_forest_mean, rng, a_leaf, b_leaf) diff --git a/R/bcf.R b/R/bcf.R index a14c7dba..fc34e1c4 100644 --- a/R/bcf.R +++ b/R/bcf.R @@ -775,26 +775,26 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id rng <- createCppRNG(random_seed) # Sampling data structures + global_model_config <- createGlobalModelConfig(global_error_variance=current_sigma2) forest_model_config_mu <- createForestModelConfig(feature_types=feature_types, num_trees=num_trees_mu, num_features=ncol(X_train), num_observations=nrow(X_train), variable_weights=variable_weights_mu, leaf_dimension=leaf_dimension_mu_forest, alpha=alpha_mu, beta=beta_mu, min_samples_leaf=min_samples_leaf_mu, max_depth=max_depth_mu, leaf_model_type=leaf_model_mu_forest, leaf_model_scale=current_leaf_scale_mu, - global_error_variance=current_sigma2, cutpoint_grid_size=cutpoint_grid_size) + cutpoint_grid_size=cutpoint_grid_size) forest_model_config_tau <- createForestModelConfig(feature_types=feature_types, num_trees=num_trees_tau, num_features=ncol(X_train), num_observations=nrow(X_train), variable_weights=variable_weights_tau, leaf_dimension=leaf_dimension_tau_forest, alpha=alpha_tau, beta=beta_tau, min_samples_leaf=min_samples_leaf_tau, max_depth=max_depth_tau, leaf_model_type=leaf_model_tau_forest, leaf_model_scale=current_leaf_scale_tau, - global_error_variance=current_sigma2, cutpoint_grid_size=cutpoint_grid_size) - forest_model_mu <- createForestModel(forest_dataset_train, forest_model_config_mu) - forest_model_tau <- createForestModel(forest_dataset_train, forest_model_config_tau) + cutpoint_grid_size=cutpoint_grid_size) + forest_model_mu <- createForestModel(forest_dataset_train, forest_model_config_mu, global_model_config) + forest_model_tau <- createForestModel(forest_dataset_train, forest_model_config_tau, global_model_config) if (include_variance_forest) { forest_model_config_variance <- createForestModelConfig(feature_types=feature_types, num_trees=num_trees_variance, num_features=ncol(X_train), num_observations=nrow(X_train), variable_weights=variable_weights_variance, leaf_dimension=leaf_dimension_variance_forest, alpha=alpha_variance, beta=beta_variance, min_samples_leaf=min_samples_leaf_variance, max_depth=max_depth_variance, - leaf_model_type=leaf_model_variance_forest, global_error_variance=current_sigma2, - cutpoint_grid_size=cutpoint_grid_size) - forest_model_variance <- createForestModel(forest_dataset_train, forest_model_config_variance) + leaf_model_type=leaf_model_variance_forest, cutpoint_grid_size=cutpoint_grid_size) + forest_model_variance <- 
createForestModel(forest_dataset_train, forest_model_config_variance, global_model_config) } # Container of forest samples @@ -839,15 +839,14 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id # Sample the prognostic forest forest_model_mu$sample_one_iteration( forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_mu, - active_forest = active_forest_mu, rng = rng, model_config = forest_model_config_mu, keep_forest = keep_sample, gfr = T + active_forest = active_forest_mu, rng = rng, forest_model_config = forest_model_config_mu, + global_model_config = global_model_config, keep_forest = keep_sample, gfr = T ) # Sample variance parameters (if requested) if (sample_sigma_global) { current_sigma2 <- sampleGlobalErrorVarianceOneIteration(outcome_train, forest_dataset_train, rng, a_global, b_global) - forest_model_config_mu$update_global_error_variance(current_sigma2) - forest_model_config_tau$update_global_error_variance(current_sigma2) - if (include_variance_forest) forest_model_config_variance$update_global_error_variance(current_sigma2) + global_model_config$update_global_error_variance(current_sigma2) } if (sample_sigma_leaf_mu) { leaf_scale_mu_double <- sampleLeafVarianceOneIteration(active_forest_mu, rng, a_leaf_mu, b_leaf_mu) @@ -859,7 +858,8 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id # Sample the treatment forest forest_model_tau$sample_one_iteration( forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_tau, - active_forest = active_forest_tau, rng = rng, model_config = forest_model_config_tau, keep_forest = keep_sample, gfr = T + active_forest = active_forest_tau, rng = rng, forest_model_config = forest_model_config_tau, + global_model_config = global_model_config, keep_forest = keep_sample, gfr = T ) # Sample coding parameters (if requested) @@ -903,16 +903,14 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id if (include_variance_forest) { forest_model_variance$sample_one_iteration( forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_variance, - active_forest = active_forest_variance, rng = rng, model_config = forest_model_config_variance, - keep_forest = keep_sample, gfr = T + active_forest = active_forest_variance, rng = rng, forest_model_config = forest_model_config_variance, + global_model_config = global_model_config, keep_forest = keep_sample, gfr = T ) } if (sample_sigma_global) { current_sigma2 <- sampleGlobalErrorVarianceOneIteration(outcome_train, forest_dataset_train, rng, a_global, b_global) if (keep_sample) global_var_samples[sample_counter] <- current_sigma2 - forest_model_config_mu$update_global_error_variance(current_sigma2) - forest_model_config_tau$update_global_error_variance(current_sigma2) - if (include_variance_forest) forest_model_config_variance$update_global_error_variance(current_sigma2) + global_model_config$update_global_error_variance(current_sigma2) } if (sample_sigma_leaf_tau) { leaf_scale_tau_double <- sampleLeafVarianceOneIteration(active_forest_tau, rng, a_leaf_tau, b_leaf_tau) @@ -1098,15 +1096,14 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id # Sample the prognostic forest forest_model_mu$sample_one_iteration( forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_mu, - active_forest = active_forest_mu, rng = rng, model_config = 
forest_model_config_mu, keep_forest = keep_sample, gfr = F + active_forest = active_forest_mu, rng = rng, forest_model_config = forest_model_config_mu, + global_model_config = global_model_config, keep_forest = keep_sample, gfr = F ) # Sample variance parameters (if requested) if (sample_sigma_global) { current_sigma2 <- sampleGlobalErrorVarianceOneIteration(outcome_train, forest_dataset_train, rng, a_global, b_global) - forest_model_config_mu$update_global_error_variance(current_sigma2) - forest_model_config_tau$update_global_error_variance(current_sigma2) - if (include_variance_forest) forest_model_config_variance$update_global_error_variance(current_sigma2) + global_model_config$update_global_error_variance(current_sigma2) } if (sample_sigma_leaf_mu) { leaf_scale_mu_double <- sampleLeafVarianceOneIteration(active_forest_mu, rng, a_leaf_mu, b_leaf_mu) @@ -1118,7 +1115,8 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id # Sample the treatment forest forest_model_tau$sample_one_iteration( forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_tau, - active_forest = active_forest_tau, rng = rng, model_config = forest_model_config_tau, keep_forest = keep_sample, gfr = F + active_forest = active_forest_tau, rng = rng, forest_model_config = forest_model_config_tau, + global_model_config = global_model_config, keep_forest = keep_sample, gfr = F ) # Sample coding parameters (if requested) @@ -1162,16 +1160,14 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id if (include_variance_forest) { forest_model_variance$sample_one_iteration( forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_variance, - active_forest = active_forest_variance, rng = rng, model_config = forest_model_config_variance, - keep_forest = keep_sample, gfr = F + active_forest = active_forest_variance, rng = rng, forest_model_config = forest_model_config_variance, + global_model_config = global_model_config, keep_forest = keep_sample, gfr = F ) } if (sample_sigma_global) { current_sigma2 <- sampleGlobalErrorVarianceOneIteration(outcome_train, forest_dataset_train, rng, a_global, b_global) if (keep_sample) global_var_samples[sample_counter] <- current_sigma2 - forest_model_config_mu$update_global_error_variance(current_sigma2) - forest_model_config_tau$update_global_error_variance(current_sigma2) - if (include_variance_forest) forest_model_config_variance$update_global_error_variance(current_sigma2) + global_model_config$update_global_error_variance(current_sigma2) } if (sample_sigma_leaf_tau) { leaf_scale_tau_double <- sampleLeafVarianceOneIteration(active_forest_tau, rng, a_leaf_tau, b_leaf_tau) diff --git a/R/config.R b/R/config.R index edd1830d..d7b1c52b 100644 --- a/R/config.R +++ b/R/config.R @@ -1,5 +1,5 @@ -#' Dataset used to get / set parameters and other model configuration options -#' for the "low-level" stochtree interface +#' Object used to get / set parameters and other model configuration options +#' for a forest model in the "low-level" stochtree interface #' #' @description #' The "low-level" stochtree interface enables a high degreee of sampler #' customization, in which users employ R wrappers around C++ objects #' like ForestDataset, Outcome, CppRng, and ForestModel to run the #' Gibbs sampler of a BART model with custom modifications. #' ForestModelConfig allows users to specify / query the parameters of a -#' tree model they wish to run. +#' forest model they wish to run.
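The practical effect of this split is that one `GlobalModelConfig` can be shared by several forest samplers, so a global parameter such as the error variance is updated in a single place rather than once per forest config. A minimal sketch of the post-split calling convention (names like `p`, `n`, `nu`, `lambda`, and `num_samples` are illustrative placeholders; the constructors and method signatures are those defined in this patch):

```r
# One shared global config, one config per forest
global_model_config <- createGlobalModelConfig(global_error_variance = 1.0)
forest_model_config <- createForestModelConfig(
  feature_types = feature_types, num_trees = num_trees,
  num_features = p, num_observations = n,
  leaf_model_type = 1, cutpoint_grid_size = 100
)
forest_model <- createForestModel(forest_dataset, forest_model_config, global_model_config)

for (i in 1:num_samples) {
  # The forest update reads sigma^2 from the shared global config
  forest_model$sample_one_iteration(
    forest_dataset, outcome, forest_samples, active_forest,
    rng, forest_model_config, global_model_config, keep_forest = TRUE, gfr = FALSE
  )
  # One update is visible to every sampler holding this config object
  current_sigma2 <- sampleGlobalErrorVarianceOneIteration(
    outcome, forest_dataset, rng, nu, lambda
  )
  global_model_config$update_global_error_variance(current_sigma2)
}
```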
ForestModelConfig <- R6::R6Class( classname = "ForestModelConfig", @@ -56,9 +56,6 @@ ForestModelConfig <- R6::R6Class( #' @field variance_forest_scale Scale parameter for IG leaf models (applicable when `leaf_model_type = 3`) variance_forest_scale = NULL, - #' @field global_error_variance Global error variance parameter - global_error_variance = NULL, - #' @field cutpoint_grid_size Number of unique cutpoints to consider cutpoint_grid_size = NULL, @@ -78,7 +75,6 @@ ForestModelConfig <- R6::R6Class( #' @param leaf_model_scale Scale parameter used in Gaussian leaf models (can either be a scalar or a q x q matrix, where q is the dimensionality of the basis and is only >1 when `leaf_model_int = 2`). Calibrated internally as `1/num_trees`, propagated along diagonal if needed for multivariate leaf models. #' @param variance_forest_shape Shape parameter for IG leaf models (applicable when `leaf_model_type = 3`). Default: `1`. #' @param variance_forest_scale Scale parameter for IG leaf models (applicable when `leaf_model_type = 3`). Default: `1`. - #' @param global_error_variance Global error variance parameter (default: `1.0`) #' @param cutpoint_grid_size Number of unique cutpoints to consider (default: `100`) #' #' @return A new ForestModelConfig object. @@ -86,7 +82,7 @@ ForestModelConfig <- R6::R6Class( num_observations = NULL, variable_weights = NULL, leaf_dimension = 1, alpha = 0.95, beta = 2.0, min_samples_leaf = 5, max_depth = -1, leaf_model_type = 1, leaf_model_scale = NULL, variance_forest_shape = 1.0, - variance_forest_scale = 1.0, global_error_variance = 1.0, cutpoint_grid_size = 100) { + variance_forest_scale = 1.0, cutpoint_grid_size = 100) { if (is.null(feature_types)) { if (is.null(num_features)) { stop("Neither of `num_features` nor `feature_types` (a vector from which `num_features` can be inferred) was provided. 
Please provide at least one of these inputs when creating a ForestModelConfig object.") @@ -120,7 +116,6 @@ ForestModelConfig <- R6::R6Class( self$max_depth <- max_depth self$variance_forest_shape <- variance_forest_shape self$variance_forest_scale <- variance_forest_scale - self$global_error_variance <- global_error_variance self$cutpoint_grid_size <- cutpoint_grid_size if (!(as.integer(leaf_model_type) == leaf_model_type)) { @@ -227,13 +222,6 @@ ForestModelConfig <- R6::R6Class( self$variance_forest_scale <- variance_forest_scale }, - #' @description - #' Update global error variance parameter - #' @param global_error_variance Global error variance parameter - update_global_error_variance = function(global_error_variance) { - self$global_error_variance <- global_error_variance - }, - #' @description #' Update number of unique cutpoints to consider #' @param cutpoint_grid_size Number of unique cutpoints to consider @@ -304,13 +292,6 @@ ForestModelConfig <- R6::R6Class( return(self$variance_forest_scale) }, - #' @description - #' Query global error variance parameter for this ForestModelConfig object - #' @returns Global error variance parameter - get_global_error_variance = function() { - return(self$global_error_variance) - }, - #' @description #' Query number of unique cutpoints to consider for this ForestModelConfig object #' @returns Number of unique cutpoints to consider @@ -320,7 +301,51 @@ ForestModelConfig <- R6::R6Class( ) ) -#' Create an model config object +#' Object used to get / set global parameters and other global model +#' configuration options in the "low-level" stochtree interface +#' +#' @description +#' The "low-level" stochtree interface enables a high degreee of sampler +#' customization, in which users employ R wrappers around C++ objects +#' like ForestDataset, Outcome, CppRng, and ForestModel to run the +#' Gibbs sampler of a BART model with custom modifications. +#' GlobalModelConfig allows users to specify / query the global parameters +#' of a model they wish to run. + +GlobalModelConfig <- R6::R6Class( + classname = "GlobalModelConfig", + cloneable = FALSE, + public = list( + + #' @field global_error_variance Global error variance parameter + global_error_variance = NULL, + + #' Create a new GlobalModelConfig object. + #' + #' @param global_error_variance Global error variance parameter (default: `1.0`) + #' + #' @return A new GlobalModelConfig object. + initialize = function(global_error_variance = 1.0) { + self$global_error_variance <- global_error_variance + }, + + #' @description + #' Update global error variance parameter + #' @param global_error_variance Global error variance parameter + update_global_error_variance = function(global_error_variance) { + self$global_error_variance <- global_error_variance + }, + + #' @description + #' Query global error variance parameter for this GlobalModelConfig object + #' @returns Global error variance parameter + get_global_error_variance = function() { + return(self$global_error_variance) + } + ) +) + +#' Create a forest model config object #' #' @param feature_types Vector of integer-coded feature types (integers where 0 = numeric, 1 = ordered categorical, 2 = unordered categorical) #' @param num_trees Number of trees in the forest being sampled @@ -336,7 +361,6 @@ ForestModelConfig <- R6::R6Class( #' @param leaf_model_scale Scale parameter used in Gaussian leaf models (can either be a scalar or a q x q matrix, where q is the dimensionality of the basis and is only >1 when `leaf_model_int = 2`). 
Calibrated internally as `1/num_trees`, propagated along diagonal if needed for multivariate leaf models. #' @param variance_forest_shape Shape parameter for IG leaf models (applicable when `leaf_model_type = 3`). Default: `1`. #' @param variance_forest_scale Scale parameter for IG leaf models (applicable when `leaf_model_type = 3`). Default: `1`. -#' @param global_error_variance Global error variance parameter (default: `1.0`) #' @param cutpoint_grid_size Number of unique cutpoints to consider (default: `100`) #' @return ForestModelConfig object #' @export @@ -347,11 +371,25 @@ createForestModelConfig <- function(feature_types = NULL, num_trees = NULL, num_ num_observations = NULL, variable_weights = NULL, leaf_dimension = 1, alpha = 0.95, beta = 2.0, min_samples_leaf = 5, max_depth = -1, leaf_model_type = 1, leaf_model_scale = NULL, variance_forest_shape = 1.0, - variance_forest_scale = 1.0, global_error_variance = 1.0, cutpoint_grid_size = 100){ + variance_forest_scale = 1.0, cutpoint_grid_size = 100){ return(invisible(( ForestModelConfig$new(feature_types, num_trees, num_features, num_observations, variable_weights, leaf_dimension, alpha, beta, min_samples_leaf, max_depth, leaf_model_type, leaf_model_scale, variance_forest_shape, - variance_forest_scale, global_error_variance, cutpoint_grid_size) + variance_forest_scale, cutpoint_grid_size) ))) } + +#' Create a global model config object +#' +#' @param global_error_variance Global error variance parameter (default: `1.0`) +#' @return GlobalModelConfig object +#' @export +#' +#' @examples +#' config <- createGlobalModelConfig(global_error_variance = 100) +createGlobalModelConfig <- function(global_error_variance = 1.0){ + return(invisible(( + GlobalModelConfig$new(global_error_variance) + ))) +} diff --git a/R/model.R b/R/model.R index 3a4db170..7cde7274 100644 --- a/R/model.R +++ b/R/model.R @@ -65,24 +65,25 @@ ForestModel <- R6::R6Class( #' @param forest_samples Container of forest samples #' @param active_forest "Active" forest updated by the sampler in each iteration #' @param rng Wrapper around C++ random number generator - #' @param model_config ModelConfig object containing forest model parameters and settings + #' @param forest_model_config ForestModelConfig object containing forest model parameters and settings + #' @param global_model_config GlobalModelConfig object containing global model parameters and settings #' @param keep_forest (Optional) Whether the updated forest sample should be saved to `forest_samples`. Default: `TRUE`. #' @param gfr (Optional) Whether or not the forest should be sampled using the "grow-from-root" (GFR) algorithm. Default: `TRUE`. sample_one_iteration = function(forest_dataset, residual, forest_samples, active_forest, - rng, model_config, keep_forest = T, gfr = T) { + rng, forest_model_config, global_model_config, keep_forest = T, gfr = T) { if (active_forest$is_empty()) { stop("`active_forest` has not yet been initialized, which is necessary to run the sampler.
Please set constant values for `active_forest`'s leaves using either the `set_root_leaves` or `prepare_for_sampler` methods.") } # Unpack parameters from model config object - feature_types <- model_config$feature_types - leaf_model_int <- model_config$leaf_model_type - leaf_model_scale <- model_config$leaf_model_scale - variable_weights <- model_config$variable_weights - a_forest <- model_config$variance_forest_shape - b_forest <- model_config$variance_forest_scale - global_scale <- model_config$global_error_variance - cutpoint_grid_size <- model_config$cutpoint_grid_size + feature_types <- forest_model_config$feature_types + leaf_model_int <- forest_model_config$leaf_model_type + leaf_model_scale <- forest_model_config$leaf_model_scale + variable_weights <- forest_model_config$variable_weights + a_forest <- forest_model_config$variance_forest_shape + b_forest <- forest_model_config$variance_forest_scale + global_scale <- global_model_config$global_error_variance + cutpoint_grid_size <- forest_model_config$cutpoint_grid_size if (gfr) { sample_gfr_one_iteration_cpp( @@ -187,7 +188,8 @@ createCppRNG <- function(random_seed = -1){ #' Create a forest model object #' #' @param forest_dataset ForestDataset object, used to initialize forest sampling data structures -#' @param model_config ModelConfig object containing forest model parameters and settings +#' @param forest_model_config ForestModelConfig object containing forest model parameters and settings +#' @param global_model_config GlobalModelConfig object containing global model parameters and settings #' #' @return `ForestModel` object #' @export @@ -203,15 +205,16 @@ createCppRNG <- function(random_seed = -1){ #' feature_types <- as.integer(rep(0, p)) #' X <- matrix(runif(n*p), ncol = p) #' forest_dataset <- createForestDataset(X) -#' model_config <- createModelConfig(feature_types=feature_types, num_trees=num_trees, num_features=p, -#' num_observations=n, alpha=alpha, beta=beta, min_samples_leaf=min_samples_leaf, -#' max_depth=max_depth, leaf_model_type=1) -#' forest_model <- createForestModel(forest_dataset, model_config) -createForestModel <- function(forest_dataset, model_config) { +#' forest_model_config <- createForestModelConfig(feature_types=feature_types, num_trees=num_trees, num_features=p, +#' num_observations=n, alpha=alpha, beta=beta, min_samples_leaf=min_samples_leaf, +#' max_depth=max_depth, leaf_model_type=1) +#' global_model_config <- createGlobalModelConfig(global_error_variance=1.0) +#' forest_model <- createForestModel(forest_dataset, forest_model_config, global_model_config) +createForestModel <- function(forest_dataset, forest_model_config, global_model_config) { return(invisible(( - ForestModel$new(forest_dataset, model_config$feature_types, model_config$num_trees, - model_config$num_observations, model_config$alpha, model_config$beta, - model_config$min_samples_leaf, model_config$max_depth) + ForestModel$new(forest_dataset, forest_model_config$feature_types, forest_model_config$num_trees, + forest_model_config$num_observations, forest_model_config$alpha, forest_model_config$beta, + forest_model_config$min_samples_leaf, forest_model_config$max_depth) ))) } diff --git a/_pkgdown.yml b/_pkgdown.yml index 8ed4786d..43b7c995 100644 --- a/_pkgdown.yml +++ b/_pkgdown.yml @@ -81,6 +81,8 @@ reference: - createForestSamples - ForestModelConfig - createForestModelConfig + - GlobalModelConfig + - createGlobalModelConfig - CppRNG - createCppRNG - calibrateInverseGammaErrorVariance diff --git a/man/ForestModel.Rd 
b/man/ForestModel.Rd index f9abb045..d317da72 100644 --- a/man/ForestModel.Rd +++ b/man/ForestModel.Rd @@ -85,7 +85,8 @@ Run a single iteration of the forest sampling algorithm (MCMC or GFR) forest_samples, active_forest, rng, - model_config, + forest_model_config, + global_model_config, keep_forest = T, gfr = T )}\if{html}{\out{}} @@ -104,7 +105,9 @@ Run a single iteration of the forest sampling algorithm (MCMC or GFR) \item{\code{rng}}{Wrapper around C++ random number generator} -\item{\code{model_config}}{ModelConfig object containing forest model parameters and settings} +\item{\code{forest_model_config}}{ForestModelConfig object containing forest model parameters and settings} + +\item{\code{global_model_config}}{GlobalModelConfig object containing global model parameters and settings} \item{\code{keep_forest}}{(Optional) Whether the updated forest sample should be saved to \code{forest_samples}. Default: \code{TRUE}.} diff --git a/man/ForestModelConfig.Rd b/man/ForestModelConfig.Rd index 0c843d38..e899c8b1 100644 --- a/man/ForestModelConfig.Rd +++ b/man/ForestModelConfig.Rd @@ -2,8 +2,8 @@ % Please edit documentation in R/config.R \name{ForestModelConfig} \alias{ForestModelConfig} -\title{Dataset used to get / set parameters and other model configuration options -for the "low-level" stochtree interface} +\title{Object used to get / set parameters and other model configuration options +for a forest model in the "low-level" stochtree interface} \value{ Vector of integer-coded feature types (integers where 0 = numeric, 1 = ordered categorical, 2 = unordered categorical) @@ -23,8 +23,6 @@ Shape parameter for IG leaf models Scale parameter for IG leaf models -Global error variance parameter - Number of unique cutpoints to consider } \description{ @@ -33,7 +31,7 @@ customization, in which users employ R wrappers around C++ objects like ForestDataset, Outcome, CppRng, and ForestModel to run the Gibbs sampler of a BART model with custom modifications. ForestModelConfig allows users to specify / query the parameters of a -tree model they wish to run. +forest model they wish to run. } \section{Public fields}{ \if{html}{\out{
}} @@ -66,8 +64,6 @@ tree model they wish to run. \item{\code{variance_forest_scale}}{Scale parameter for IG leaf models (applicable when \code{leaf_model_type = 3})} -\item{\code{global_error_variance}}{Global error variance parameter} - \item{\code{cutpoint_grid_size}}{Number of unique cutpoints to consider Create a new ForestModelConfig object.} } @@ -86,7 +82,6 @@ Create a new ForestModelConfig object.} \item \href{#method-ForestModelConfig-update_leaf_model_scale}{\code{ForestModelConfig$update_leaf_model_scale()}} \item \href{#method-ForestModelConfig-update_variance_forest_shape}{\code{ForestModelConfig$update_variance_forest_shape()}} \item \href{#method-ForestModelConfig-update_variance_forest_scale}{\code{ForestModelConfig$update_variance_forest_scale()}} -\item \href{#method-ForestModelConfig-update_global_error_variance}{\code{ForestModelConfig$update_global_error_variance()}} \item \href{#method-ForestModelConfig-update_cutpoint_grid_size}{\code{ForestModelConfig$update_cutpoint_grid_size()}} \item \href{#method-ForestModelConfig-get_feature_types}{\code{ForestModelConfig$get_feature_types()}} \item \href{#method-ForestModelConfig-get_variable_weights}{\code{ForestModelConfig$get_variable_weights()}} @@ -97,7 +92,6 @@ Create a new ForestModelConfig object.} \item \href{#method-ForestModelConfig-get_leaf_model_scale}{\code{ForestModelConfig$get_leaf_model_scale()}} \item \href{#method-ForestModelConfig-get_variance_forest_shape}{\code{ForestModelConfig$get_variance_forest_shape()}} \item \href{#method-ForestModelConfig-get_variance_forest_scale}{\code{ForestModelConfig$get_variance_forest_scale()}} -\item \href{#method-ForestModelConfig-get_global_error_variance}{\code{ForestModelConfig$get_global_error_variance()}} \item \href{#method-ForestModelConfig-get_cutpoint_grid_size}{\code{ForestModelConfig$get_cutpoint_grid_size()}} } } @@ -121,7 +115,6 @@ Create a new ForestModelConfig object.} leaf_model_scale = NULL, variance_forest_shape = 1, variance_forest_scale = 1, - global_error_variance = 1, cutpoint_grid_size = 100 )}\if{html}{\out{
}} } @@ -157,8 +150,6 @@ Create a new ForestModelConfig object.} \item{\code{variance_forest_scale}}{Scale parameter for IG leaf models (applicable when \code{leaf_model_type = 3}). Default: \code{1}.} -\item{\code{global_error_variance}}{Global error variance parameter (default: \code{1.0})} - \item{\code{cutpoint_grid_size}}{Number of unique cutpoints to consider (default: \code{100})} } \if{html}{\out{}} @@ -321,23 +312,6 @@ Update scale parameter for IG leaf models } } \if{html}{\out{
}} -\if{html}{\out{}} -\if{latex}{\out{\hypertarget{method-ForestModelConfig-update_global_error_variance}{}}} -\subsection{Method \code{update_global_error_variance()}}{ -Update global error variance parameter -\subsection{Usage}{ -\if{html}{\out{
}}\preformatted{ForestModelConfig$update_global_error_variance(global_error_variance)}\if{html}{\out{
}} -} - -\subsection{Arguments}{ -\if{html}{\out{
}} -\describe{ -\item{\code{global_error_variance}}{Global error variance parameter} -} -\if{html}{\out{
}} -} -} -\if{html}{\out{
}} \if{html}{\out{}} \if{latex}{\out{\hypertarget{method-ForestModelConfig-update_cutpoint_grid_size}{}}} \subsection{Method \code{update_cutpoint_grid_size()}}{ @@ -443,16 +417,6 @@ Query scale parameter for IG leaf models for this ForestModelConfig object \if{html}{\out{
}}\preformatted{ForestModelConfig$get_variance_forest_scale()}\if{html}{\out{
}} } -} -\if{html}{\out{
}} -\if{html}{\out{}} -\if{latex}{\out{\hypertarget{method-ForestModelConfig-get_global_error_variance}{}}} -\subsection{Method \code{get_global_error_variance()}}{ -Query global error variance parameter for this ForestModelConfig object -\subsection{Usage}{ -\if{html}{\out{
}}\preformatted{ForestModelConfig$get_global_error_variance()}\if{html}{\out{
}} -} - } \if{html}{\out{
}} \if{html}{\out{}}
diff --git a/man/GlobalModelConfig.Rd b/man/GlobalModelConfig.Rd
new file mode 100644
index 00000000..fa28e635
--- /dev/null
+++ b/man/GlobalModelConfig.Rd
@@ -0,0 +1,80 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/config.R
+\name{GlobalModelConfig}
+\alias{GlobalModelConfig}
+\title{Object used to get / set global parameters and other global model
+configuration options in the "low-level" stochtree interface}
+\value{
+Global error variance parameter
+}
+\description{
+The "low-level" stochtree interface enables a high degree of sampler
+customization, in which users employ R wrappers around C++ objects
+like ForestDataset, Outcome, CppRng, and ForestModel to run the
+Gibbs sampler of a BART model with custom modifications.
+GlobalModelConfig allows users to specify / query the global parameters
+of a model they wish to run.
+}
+\section{Public fields}{
+\if{html}{\out{
}} +\describe{ +\item{\code{global_error_variance}}{Global error variance parameter +Create a new GlobalModelConfig object.} +} +\if{html}{\out{
}} +} +\section{Methods}{ +\subsection{Public methods}{ +\itemize{ +\item \href{#method-GlobalModelConfig-new}{\code{GlobalModelConfig$new()}} +\item \href{#method-GlobalModelConfig-update_global_error_variance}{\code{GlobalModelConfig$update_global_error_variance()}} +\item \href{#method-GlobalModelConfig-get_global_error_variance}{\code{GlobalModelConfig$get_global_error_variance()}} +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-GlobalModelConfig-new}{}}} +\subsection{Method \code{new()}}{ +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{GlobalModelConfig$new(global_error_variance = 1)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{global_error_variance}}{Global error variance parameter (default: \code{1.0})} +} +\if{html}{\out{
}} +} +\subsection{Returns}{ +A new GlobalModelConfig object. +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-GlobalModelConfig-update_global_error_variance}{}}} +\subsection{Method \code{update_global_error_variance()}}{ +Update global error variance parameter +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{GlobalModelConfig$update_global_error_variance(global_error_variance)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{global_error_variance}}{Global error variance parameter} +} +\if{html}{\out{
}} +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-GlobalModelConfig-get_global_error_variance}{}}} +\subsection{Method \code{get_global_error_variance()}}{ +Query global error variance parameter for this GlobalModelConfig object +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{GlobalModelConfig$get_global_error_variance()}\if{html}{\out{
}} +} + +} +} diff --git a/man/createForestModel.Rd b/man/createForestModel.Rd index 9cf02945..836f627e 100644 --- a/man/createForestModel.Rd +++ b/man/createForestModel.Rd @@ -4,12 +4,14 @@ \alias{createForestModel} \title{Create a forest model object} \usage{ -createForestModel(forest_dataset, model_config) +createForestModel(forest_dataset, forest_model_config, global_model_config) } \arguments{ \item{forest_dataset}{ForestDataset object, used to initialize forest sampling data structures} -\item{model_config}{ModelConfig object containing forest model parameters and settings} +\item{forest_model_config}{ForestModelConfig object containing forest model parameters and settings} + +\item{global_model_config}{GlobalModelConfig object containing global model parameters and settings} } \value{ \code{ForestModel} object @@ -28,8 +30,9 @@ max_depth <- 10 feature_types <- as.integer(rep(0, p)) X <- matrix(runif(n*p), ncol = p) forest_dataset <- createForestDataset(X) -model_config <- createModelConfig(feature_types=feature_types, num_trees=num_trees, num_features=p, - num_observations=n, alpha=alpha, beta=beta, min_samples_leaf=min_samples_leaf, - max_depth=max_depth, leaf_model_type=1) -forest_model <- createForestModel(forest_dataset, model_config) +forest_model_config <- createForestModelConfig(feature_types=feature_types, num_trees=num_trees, num_features=p, + num_observations=n, alpha=alpha, beta=beta, min_samples_leaf=min_samples_leaf, + max_depth=max_depth, leaf_model_type=1) +global_model_config <- createGlobalModelConfig(global_error_variance=1.0) +forest_model <- createForestModel(forest_dataset, forest_model_config, global_model_config) } diff --git a/man/createForestModelConfig.Rd b/man/createForestModelConfig.Rd index 07173606..90de767c 100644 --- a/man/createForestModelConfig.Rd +++ b/man/createForestModelConfig.Rd @@ -2,7 +2,7 @@ % Please edit documentation in R/config.R \name{createForestModelConfig} \alias{createForestModelConfig} -\title{Create an model config object} +\title{Create a forest model config object} \usage{ createForestModelConfig( feature_types = NULL, @@ -19,7 +19,6 @@ createForestModelConfig( leaf_model_scale = NULL, variance_forest_shape = 1, variance_forest_scale = 1, - global_error_variance = 1, cutpoint_grid_size = 100 ) } @@ -52,15 +51,13 @@ createForestModelConfig( \item{variance_forest_scale}{Scale parameter for IG leaf models (applicable when \code{leaf_model_type = 3}). 
Default: \code{1}.} -\item{global_error_variance}{Global error variance parameter (default: \code{1.0})} - \item{cutpoint_grid_size}{Number of unique cutpoints to consider (default: \code{100})} } \value{ ForestModelConfig object } \description{ -Create an model config object +Create a forest model config object } \examples{ config <- createForestModelConfig(num_trees = 10, num_features = 5, num_observations = 100) diff --git a/man/createGlobalModelConfig.Rd b/man/createGlobalModelConfig.Rd new file mode 100644 index 00000000..7741a096 --- /dev/null +++ b/man/createGlobalModelConfig.Rd @@ -0,0 +1,20 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/config.R +\name{createGlobalModelConfig} +\alias{createGlobalModelConfig} +\title{Create a global model config object} +\usage{ +createGlobalModelConfig(global_error_variance = 1) +} +\arguments{ +\item{global_error_variance}{Global error variance parameter (default: \code{1.0})} +} +\value{ +GlobalModelConfig object +} +\description{ +Create a global model config object +} +\examples{ +config <- createGlobalModelConfig( = 100) +} diff --git a/vignettes/CustomSamplingRoutine.Rmd b/vignettes/CustomSamplingRoutine.Rmd index 75e33c36..6f086ef7 100644 --- a/vignettes/CustomSamplingRoutine.Rmd +++ b/vignettes/CustomSamplingRoutine.Rmd @@ -139,9 +139,10 @@ forest_model_config <- createForestModelConfig( num_observations = n, variable_weights = var_weights, leaf_dimension = leaf_dimension, alpha = alpha, beta = beta, min_samples_leaf = min_samples_leaf, max_depth = max_depth, leaf_model_type = outcome_model_type, leaf_model_scale = leaf_prior_scale, - global_error_variance = global_variance_init, cutpoint_grid_size = cutpoint_grid_size + cutpoint_grid_size = cutpoint_grid_size ) -forest_model <- createForestModel(forest_dataset, forest_model_config) +global_model_config <- createGlobalModelConfig(global_error_variance = global_variance_init) +forest_model <- createForestModel(forest_dataset, forest_model_config, global_model_config) # "Active forest" (which gets updated by the sample) and # container of forest samples (which is written to when @@ -176,7 +177,7 @@ for (i in 1:num_warmstart) { # Sample forest forest_model$sample_one_iteration( forest_dataset, outcome, forest_samples, active_forest, rng, - forest_model_config, keep_forest = T, gfr = T + forest_model_config, global_model_config, keep_forest = T, gfr = T ) # Sample global variance parameter @@ -184,7 +185,7 @@ for (i in 1:num_warmstart) { outcome, forest_dataset, rng, nu, lambda ) global_var_samples[i+1] <- current_sigma2 - forest_model_config$update_global_error_variance(current_sigma2) + global_model_config$update_global_error_variance(current_sigma2) # Sample leaf node variance parameter and update `leaf_prior_scale` leaf_scale_samples[i+1] <- sampleLeafVarianceOneIteration( @@ -203,7 +204,7 @@ for (i in (num_warmstart+1):num_samples) { # Sample forest forest_model$sample_one_iteration( forest_dataset, outcome, forest_samples, active_forest, rng, - forest_model_config, keep_forest = T, gfr = F + forest_model_config, global_model_config, keep_forest = T, gfr = F ) # Sample global variance parameter @@ -211,7 +212,7 @@ for (i in (num_warmstart+1):num_samples) { outcome, forest_dataset, rng, nu, lambda ) global_var_samples[i+1] <- current_sigma2 - forest_model_config$update_global_error_variance(current_sigma2) + global_model_config$update_global_error_variance(current_sigma2) # Sample leaf node variance parameter and update `leaf_prior_scale` 
leaf_scale_samples[i+1] <- sampleLeafVarianceOneIteration( @@ -342,9 +343,10 @@ forest_model_config <- createForestModelConfig( num_observations = n, variable_weights = var_weights, leaf_dimension = leaf_dimension, alpha = alpha, beta = beta, min_samples_leaf = min_samples_leaf, max_depth = max_depth, leaf_model_type = outcome_model_type, leaf_model_scale = leaf_prior_scale, - global_error_variance = global_variance_init, cutpoint_grid_size = cutpoint_grid_size + cutpoint_grid_size = cutpoint_grid_size ) -forest_model <- createForestModel(forest_dataset, forest_model_config) +global_model_config <- createGlobalModelConfig(global_error_variance = global_variance_init) +forest_model <- createForestModel(forest_dataset, forest_model_config, global_model_config) # "Active forest" (which gets updated by the sample) and # container of forest samples (which is written to when @@ -403,7 +405,7 @@ for (i in 1:num_warmstart) { # Sample forest forest_model$sample_one_iteration( forest_dataset, outcome, forest_samples, active_forest, rng, - forest_model_config, keep_forest = T, gfr = T + forest_model_config, global_model_config, keep_forest = T, gfr = T ) # Sample global variance parameter @@ -411,7 +413,7 @@ for (i in 1:num_warmstart) { outcome, forest_dataset, rng, nu, lambda ) global_var_samples[i+1] <- current_sigma2 - forest_model_config$update_global_error_variance(current_sigma2) + global_model_config$update_global_error_variance(current_sigma2) # Sample leaf node variance parameter and update `leaf_prior_scale` leaf_scale_samples[i+1] <- sampleLeafVarianceOneIteration( @@ -434,7 +436,7 @@ for (i in (num_warmstart+1):num_samples) { # Sample forest forest_model$sample_one_iteration( forest_dataset, outcome, forest_samples, active_forest, rng, - forest_model_config, keep_forest = T, gfr = F + forest_model_config, global_model_config, keep_forest = T, gfr = F ) # Sample global variance parameter @@ -442,7 +444,7 @@ for (i in (num_warmstart+1):num_samples) { outcome, forest_dataset, rng, nu, lambda ) global_var_samples[i+1] <- current_sigma2 - forest_model_config$update_global_error_variance(current_sigma2) + global_model_config$update_global_error_variance(current_sigma2) # Sample leaf node variance parameter and update `leaf_prior_scale` leaf_scale_samples[i+1] <- sampleLeafVarianceOneIteration( @@ -591,9 +593,10 @@ forest_model_config <- createForestModelConfig( num_observations = n, variable_weights = var_weights, leaf_dimension = leaf_dimension, alpha = alpha, beta = beta, min_samples_leaf = min_samples_leaf, max_depth = max_depth, leaf_model_type = outcome_model_type, leaf_model_scale = leaf_prior_scale, - global_error_variance = global_variance_init, cutpoint_grid_size = cutpoint_grid_size + cutpoint_grid_size = cutpoint_grid_size ) -forest_model <- createForestModel(forest_dataset, forest_model_config) +global_model_config <- createGlobalModelConfig(global_error_variance = global_variance_init) +forest_model <- createForestModel(forest_dataset, forest_model_config, global_model_config) # "Active forest" (which gets updated by the sample) and # container of forest samples (which is written to when @@ -652,7 +655,7 @@ for (i in 1:num_warmstart) { # Sample forest forest_model$sample_one_iteration( forest_dataset, outcome, forest_samples, active_forest, rng, - forest_model_config, keep_forest = T, gfr = T + forest_model_config, global_model_config, keep_forest = T, gfr = T ) # Sample global variance parameter @@ -660,7 +663,7 @@ for (i in 1:num_warmstart) { outcome, forest_dataset, rng, nu, 
lambda ) global_var_samples[i+1] <- current_sigma2 - forest_model_config$update_global_error_variance(current_sigma2) + global_model_config$update_global_error_variance(current_sigma2) # Sample leaf node variance parameter and update `leaf_prior_scale` leaf_scale_samples[i+1] <- sampleLeafVarianceOneIteration( @@ -683,7 +686,7 @@ for (i in (num_warmstart+1):num_samples) { # Sample forest forest_model$sample_one_iteration( forest_dataset, outcome, forest_samples, active_forest, rng, - forest_model_config, keep_forest = T, gfr = F + forest_model_config, global_model_config, keep_forest = T, gfr = F ) # Sample global variance parameter @@ -691,7 +694,7 @@ for (i in (num_warmstart+1):num_samples) { outcome, forest_dataset, rng, nu, lambda ) global_var_samples[i+1] <- current_sigma2 - forest_model_config$update_global_error_variance(current_sigma2) + global_model_config$update_global_error_variance(current_sigma2) # Sample leaf node variance parameter and update `leaf_prior_scale` leaf_scale_samples[i+1] <- sampleLeafVarianceOneIteration( @@ -839,9 +842,10 @@ forest_model_config <- createForestModelConfig( num_observations = n, variable_weights = var_weights, leaf_dimension = leaf_dimension, alpha = alpha, beta = beta, min_samples_leaf = min_samples_leaf, max_depth = max_depth, leaf_model_type = outcome_model_type, leaf_model_scale = leaf_prior_scale, - global_error_variance = global_variance_init, cutpoint_grid_size = cutpoint_grid_size + cutpoint_grid_size = cutpoint_grid_size ) -forest_model <- createForestModel(forest_dataset, forest_model_config) +global_model_config <- createGlobalModelConfig(global_error_variance = global_variance_init) +forest_model <- createForestModel(forest_dataset, forest_model_config, global_model_config) # "Active forest" (which gets updated by the sample) and # container of forest samples (which is written to when @@ -897,7 +901,7 @@ for (i in 1:num_warmstart) { # Sample forest forest_model$sample_one_iteration( forest_dataset, outcome, forest_samples, active_forest, rng, - forest_model_config, keep_forest = T, gfr = T + forest_model_config, global_model_config, keep_forest = T, gfr = T ) # Sample global variance parameter @@ -905,7 +909,7 @@ for (i in 1:num_warmstart) { outcome, forest_dataset, rng, nu, lambda ) global_var_samples[i+1] <- current_sigma2 - forest_model_config$update_global_error_variance(current_sigma2) + global_model_config$update_global_error_variance(current_sigma2) } ``` @@ -936,7 +940,7 @@ for (i in (num_warmstart+1):num_samples) { # Sample forest forest_model$sample_one_iteration( forest_dataset, outcome, forest_samples, active_forest, rng, - forest_model_config, keep_forest = T, gfr = F + forest_model_config, global_model_config, keep_forest = T, gfr = F ) # Sample global variance parameter @@ -944,7 +948,7 @@ for (i in (num_warmstart+1):num_samples) { outcome, forest_dataset, rng, nu, lambda ) global_var_samples[i+1] <- current_sigma2 - forest_model_config$update_global_error_variance(current_sigma2) + global_model_config$update_global_error_variance(current_sigma2) } ``` @@ -1191,14 +1195,14 @@ outcome <- createOutcome(resid) rng <- createCppRNG() # Sampling data structures +global_model_config <- createGlobalModelConfig(global_error_variance = current_sigma2) forest_model_config_mu <- createForestModelConfig( feature_types = feature_types_mu, num_trees = num_trees_mu, num_features = ncol(X_mu), num_observations = nrow(X_mu), variable_weights = variable_weights_mu, leaf_dimension = 1, alpha = alpha_mu, beta = beta_mu, min_samples_leaf = 
min_samples_leaf_mu, max_depth = max_depth_mu, - leaf_model_type = 0, leaf_model_scale = current_leaf_scale_mu, - global_error_variance = current_sigma2, cutpoint_grid_size = cutpoint_grid_size + leaf_model_type = 0, leaf_model_scale = current_leaf_scale_mu, cutpoint_grid_size = cutpoint_grid_size ) -forest_model_mu <- createForestModel(forest_dataset_mu, forest_model_config_mu) +forest_model_mu <- createForestModel(forest_dataset_mu, forest_model_config_mu, global_model_config) forest_model_config_tau <- createForestModelConfig( feature_types = feature_types_tau, num_trees = num_trees_tau, num_features = ncol(X_tau), num_observations = nrow(X_tau), variable_weights = variable_weights_tau, leaf_dimension = 1, @@ -1206,7 +1210,7 @@ forest_model_config_tau <- createForestModelConfig( leaf_model_type = 1, leaf_model_scale = current_leaf_scale_tau, global_error_variance = current_sigma2, cutpoint_grid_size = cutpoint_grid_size ) -forest_model_tau <- createForestModel(forest_dataset_tau, forest_model_config_tau) +forest_model_tau <- createForestModel(forest_dataset_tau, forest_model_config_tau, global_model_config) # Container of forest samples forest_samples_mu <- createForestSamples(num_trees_mu, 1, T) @@ -1231,7 +1235,7 @@ if (num_gfr > 0){ # Sample the prognostic forest forest_model_mu$sample_one_iteration( forest_dataset_mu, outcome, forest_samples_mu, active_forest_mu, rng, - forest_model_config_mu, keep_forest = T, gfr = T + forest_model_config_mu, global_model_config, keep_forest = T, gfr = T ) # Sample variance parameters (if requested) @@ -1239,13 +1243,12 @@ if (num_gfr > 0){ outcome, forest_dataset_mu, rng, nu, lambda ) global_var_samples[i] <- current_sigma2 - forest_model_config_mu$update_global_error_variance(current_sigma2) - forest_model_config_tau$update_global_error_variance(current_sigma2) + global_model_config$update_global_error_variance(current_sigma2) # Sample the treatment forest forest_model_tau$sample_one_iteration( forest_dataset_tau, outcome, forest_samples_tau, active_forest_tau, rng, - forest_model_config_tau, keep_forest = T, gfr = T + forest_model_config_tau, global_model_config, keep_forest = T, gfr = T ) # Sample adaptive coding parameters @@ -1271,8 +1274,7 @@ if (num_gfr > 0){ outcome, forest_dataset_tau, rng, nu, lambda ) global_var_samples[i] <- current_sigma2 - forest_model_config_mu$update_global_error_variance(current_sigma2) - forest_model_config_tau$update_global_error_variance(current_sigma2) + global_model_config$update_global_error_variance(current_sigma2) } } ``` @@ -1285,7 +1287,7 @@ if (num_burnin + num_mcmc > 0) { # Sample the prognostic forest forest_model_mu$sample_one_iteration( forest_dataset_mu, outcome, forest_samples_mu, active_forest_mu, rng, - forest_model_config_mu, keep_forest = T, gfr = F + forest_model_config_mu, global_model_config, keep_forest = T, gfr = F ) # Sample variance parameters (if requested) @@ -1293,13 +1295,12 @@ if (num_burnin + num_mcmc > 0) { outcome, forest_dataset_mu, rng, nu, lambda ) global_var_samples[i] <- current_sigma2 - forest_model_config_mu$update_global_error_variance(current_sigma2) - forest_model_config_tau$update_global_error_variance(current_sigma2) + global_model_config$update_global_error_variance(current_sigma2) # Sample the treatment forest forest_model_tau$sample_one_iteration( forest_dataset_tau, outcome, forest_samples_tau, active_forest_tau, rng, - forest_model_config_tau, keep_forest = T, gfr = F + forest_model_config_tau, global_model_config, keep_forest = T, gfr = F ) # Sample adaptive 
coding parameters @@ -1325,8 +1326,7 @@ if (num_burnin + num_mcmc > 0) { outcome, forest_dataset_tau, rng, nu, lambda ) global_var_samples[i] <- current_sigma2 - forest_model_config_mu$update_global_error_variance(current_sigma2) - forest_model_config_tau$update_global_error_variance(current_sigma2) + global_model_config$update_global_error_variance(current_sigma2) } } ``` From 86dd8d8eadcb8fcd9a8a4e97a6c9bc395981ebea Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Tue, 4 Feb 2025 02:11:11 -0600 Subject: [PATCH 08/21] Fixed example code --- R/forest.R | 5 +++-- man/resetForestModel.Rd | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/R/forest.R b/R/forest.R index cc519241..5ae8e85b 100644 --- a/R/forest.R +++ b/R/forest.R @@ -888,11 +888,12 @@ resetActiveForest <- function(active_forest, forest_samples=NULL, forest_num=NUL #' y <- -5 + 10*(X[,1] > 0.5) + rnorm(n) #' outcome <- createOutcome(y) #' rng <- createCppRNG(1234) +#' global_model_config <- createGlobalModelConfig(global_error_variance=sigma2) #' forest_model_config <- createForestModelConfig(feature_types=feature_types, num_trees=num_trees, num_observations=n, num_features=p, #' alpha=alpha, beta=beta, min_samples_leaf=min_samples_leaf, max_depth=max_depth, #' variable_weights=variable_weights, cutpoint_grid_size=cutpoint_grid_size, -#' leaf_model_type=leaf_model, leaf_model_scale=leaf_scale, global_error_variance=sigma2) -#' forest_model <- createForestModel(forest_dataset, forest_model_config) +#' leaf_model_type=leaf_model, leaf_model_scale=leaf_scale) +#' forest_model <- createForestModel(forest_dataset, forest_model_config, global_model_config) #' active_forest <- createForest(num_trees, leaf_dimension, is_leaf_constant, is_exponentiated) #' forest_samples <- createForestSamples(num_trees, leaf_dimension, is_leaf_constant, is_exponentiated) #' active_forest$prepare_for_sampler(forest_dataset, outcome, forest_model, 0, 0.) diff --git a/man/resetForestModel.Rd b/man/resetForestModel.Rd index 9356be6c..e300765f 100644 --- a/man/resetForestModel.Rd +++ b/man/resetForestModel.Rd @@ -47,11 +47,12 @@ forest_dataset <- createForestDataset(X) y <- -5 + 10*(X[,1] > 0.5) + rnorm(n) outcome <- createOutcome(y) rng <- createCppRNG(1234) +global_model_config <- createGlobalModelConfig(global_error_variance=sigma2) forest_model_config <- createForestModelConfig(feature_types=feature_types, num_trees=num_trees, num_observations=n, num_features=p, alpha=alpha, beta=beta, min_samples_leaf=min_samples_leaf, max_depth=max_depth, variable_weights=variable_weights, cutpoint_grid_size=cutpoint_grid_size, - leaf_model_type=leaf_model, leaf_model_scale=leaf_scale, global_error_variance=sigma2) -forest_model <- createForestModel(forest_dataset, forest_model_config) + leaf_model_type=leaf_model, leaf_model_scale=leaf_scale) +forest_model <- createForestModel(forest_dataset, forest_model_config, global_model_config) active_forest <- createForest(num_trees, leaf_dimension, is_leaf_constant, is_exponentiated) forest_samples <- createForestSamples(num_trees, leaf_dimension, is_leaf_constant, is_exponentiated) active_forest$prepare_for_sampler(forest_dataset, outcome, forest_model, 0, 0.) 
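Taken together, the patches above replace `sample_one_iteration()`'s long argument list with two configuration objects: forest-level settings live in a `ForestModelConfig`, the global error variance lives in a `GlobalModelConfig`, and both are passed to `createForestModel()` and `sample_one_iteration()`. The minimal sketch below (not part of the patch series) strings these pieces together on simulated data; the tree count, seed, and the two trailing arguments to the variance sampler (standing in for the `nu` and `lambda` hyperparameters) are illustrative placeholders, and only functions that appear in the diffs in this series are used.

```r
library(stochtree)

# Simulated data (illustrative)
n <- 100
p <- 5
X <- matrix(runif(n*p), ncol = p)
y <- -5 + 10*(X[,1] > 0.5) + rnorm(n)
feature_types <- as.integer(rep(0, p))  # all numeric features
num_trees <- 50

# C++-backed data, outcome, and RNG wrappers
forest_dataset <- createForestDataset(X)
outcome <- createOutcome(y)
rng <- createCppRNG(1234)

# Forest-level settings in one config object, global parameters in another;
# priors not set here fall back to the config defaults
forest_model_config <- createForestModelConfig(
    feature_types = feature_types, num_trees = num_trees,
    num_observations = n, num_features = p, leaf_model_type = 0
)
global_model_config <- createGlobalModelConfig(global_error_variance = 1.0)
forest_model <- createForestModel(forest_dataset, forest_model_config, global_model_config)

# "Active forest" plus a container for retained samples
active_forest <- createForest(num_trees, 1, F)
forest_samples <- createForestSamples(num_trees, 1, F)
active_forest$prepare_for_sampler(forest_dataset, outcome, forest_model, 0, 0.)

# One MCMC sweep: both config objects are passed to the sampler
forest_model$sample_one_iteration(
    forest_dataset, outcome, forest_samples, active_forest, rng,
    forest_model_config, global_model_config, keep_forest = TRUE, gfr = FALSE
)

# Global error variance updates now flow through the global config
current_sigma2 <- sampleGlobalErrorVarianceOneIteration(outcome, forest_dataset, rng, 3, 1)
global_model_config$update_global_error_variance(current_sigma2)
```

Because models that share a global parameter (such as the mu and tau forests in BCF) can now share a single `GlobalModelConfig`, a custom sampling loop needs only one `update_global_error_variance()` call per iteration, which is exactly the simplification the BCF diffs elsewhere in this series make.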
From 95610df6e3c3b203a103c8b65025f5c08e11207a Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Tue, 4 Feb 2025 02:14:23 -0600 Subject: [PATCH 09/21] Fixed bug in code example --- R/forest.R | 2 +- man/resetForestModel.Rd | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/R/forest.R b/R/forest.R index 5ae8e85b..06491f7d 100644 --- a/R/forest.R +++ b/R/forest.R @@ -899,7 +899,7 @@ resetActiveForest <- function(active_forest, forest_samples=NULL, forest_num=NUL #' active_forest$prepare_for_sampler(forest_dataset, outcome, forest_model, 0, 0.) #' forest_model$sample_one_iteration( #' forest_dataset, outcome, forest_samples, active_forest, -#' rng, forest_model_config, keep_forest = TRUE, gfr = FALSE +#' rng, forest_model_config, global_model_config, keep_forest = TRUE, gfr = FALSE #' ) #' resetActiveForest(active_forest, forest_samples, 0) #' resetForestModel(forest_model, active_forest, forest_dataset, outcome, TRUE) diff --git a/man/resetForestModel.Rd b/man/resetForestModel.Rd index e300765f..06e1721e 100644 --- a/man/resetForestModel.Rd +++ b/man/resetForestModel.Rd @@ -58,7 +58,7 @@ forest_samples <- createForestSamples(num_trees, leaf_dimension, is_leaf_constan active_forest$prepare_for_sampler(forest_dataset, outcome, forest_model, 0, 0.) forest_model$sample_one_iteration( forest_dataset, outcome, forest_samples, active_forest, - rng, forest_model_config, keep_forest = TRUE, gfr = FALSE + rng, forest_model_config, global_model_config, keep_forest = TRUE, gfr = FALSE ) resetActiveForest(active_forest, forest_samples, 0) resetForestModel(forest_model, active_forest, forest_dataset, outcome, TRUE) From 0c7c0e19579835cbf25ba8dc5d80568e69ee9846 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Tue, 4 Feb 2025 02:24:57 -0600 Subject: [PATCH 10/21] Updated BCF's use of the new interface --- R/bcf.R | 20 ++++---------------- 1 file changed, 4 insertions(+), 16 deletions(-) diff --git a/R/bcf.R b/R/bcf.R index fc34e1c4..1db9c24c 100644 --- a/R/bcf.R +++ b/R/bcf.R @@ -967,11 +967,7 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id } if (sample_sigma_global) { current_sigma2 <- global_var_samples[forest_ind + 1] - forest_model_config_mu$update_global_error_variance(current_sigma2) - forest_model_config_tau$update_global_error_variance(current_sigma2) - if (include_variance_forest) { - forest_model_config_variance$update_global_error_variance(current_sigma2) - } + global_model_config$update_global_error_variance(current_sigma2) } } else if (has_prev_model) { resetActiveForest(active_forest_mu, previous_forest_samples_mu, previous_model_warmstart_sample_num - 1) @@ -1017,11 +1013,7 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id if (!is.null(previous_global_var_samples)) { current_sigma2 <- previous_global_var_samples[previous_model_warmstart_sample_num] } - forest_model_config_mu$update_global_error_variance(current_sigma2) - forest_model_config_tau$update_global_error_variance(current_sigma2) - if (include_variance_forest) { - forest_model_config_variance$update_global_error_variance(current_sigma2) - } + global_model_config$update_global_error_variance(current_sigma2) } } else { resetActiveForest(active_forest_mu) @@ -1061,11 +1053,7 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id } if (sample_sigma_global) { current_sigma2 <- sigma2_init - forest_model_config_mu$update_global_error_variance(current_sigma2) - 
forest_model_config_tau$update_global_error_variance(current_sigma2) - if (include_variance_forest) { - forest_model_config_variance$update_global_error_variance(current_sigma2) - } + global_model_config$update_global_error_variance(current_sigma2) } } for (i in (num_gfr+1):num_samples) { @@ -1096,7 +1084,7 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id # Sample the prognostic forest forest_model_mu$sample_one_iteration( forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_mu, - active_forest = active_forest_mu, rng = rng, model_config = forest_model_config_mu, + active_forest = active_forest_mu, rng = rng, forest_model_config = forest_model_config_mu, global_model_config = global_model_config, keep_forest = keep_sample, gfr = F ) From 8ea864f972a11cd62081bc9265fe5ced8a755b41 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Tue, 4 Feb 2025 02:42:16 -0600 Subject: [PATCH 11/21] Updated documentation and vignettes --- R/config.R | 2 +- man/createGlobalModelConfig.Rd | 2 +- vignettes/CustomSamplingRoutine.Rmd | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/R/config.R b/R/config.R index d7b1c52b..08693674 100644 --- a/R/config.R +++ b/R/config.R @@ -387,7 +387,7 @@ createForestModelConfig <- function(feature_types = NULL, num_trees = NULL, num_ #' @export #' #' @examples -#' config <- createGlobalModelConfig( = 100) +#' config <- createGlobalModelConfig(global_error_variance = 100) createGlobalModelConfig <- function(global_error_variance = 1.0){ return(invisible(( GlobalModelConfig$new(global_error_variance) diff --git a/man/createGlobalModelConfig.Rd b/man/createGlobalModelConfig.Rd index 7741a096..59225789 100644 --- a/man/createGlobalModelConfig.Rd +++ b/man/createGlobalModelConfig.Rd @@ -16,5 +16,5 @@ GlobalModelConfig object Create a global model config object } \examples{ -config <- createGlobalModelConfig( = 100) +config <- createGlobalModelConfig(global_error_variance = 100) } diff --git a/vignettes/CustomSamplingRoutine.Rmd b/vignettes/CustomSamplingRoutine.Rmd index 6f086ef7..22399325 100644 --- a/vignettes/CustomSamplingRoutine.Rmd +++ b/vignettes/CustomSamplingRoutine.Rmd @@ -1208,7 +1208,7 @@ forest_model_config_tau <- createForestModelConfig( num_observations = nrow(X_tau), variable_weights = variable_weights_tau, leaf_dimension = 1, alpha = alpha_tau, beta = beta_tau, min_samples_leaf = min_samples_leaf_tau, max_depth = max_depth_tau, leaf_model_type = 1, leaf_model_scale = current_leaf_scale_tau, - global_error_variance = current_sigma2, cutpoint_grid_size = cutpoint_grid_size + cutpoint_grid_size = cutpoint_grid_size ) forest_model_tau <- createForestModel(forest_dataset_tau, forest_model_config_tau, global_model_config) From 49885cf1182ea22da3f81898caf3a96be816243f Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Tue, 4 Feb 2025 13:02:58 -0600 Subject: [PATCH 12/21] Updated unit tests --- test/R/testthat/test-residual.R | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/test/R/testthat/test-residual.R b/test/R/testthat/test-residual.R index 04165271..eef4f731 100644 --- a/test/R/testthat/test-residual.R +++ b/test/R/testthat/test-residual.R @@ -36,7 +36,14 @@ test_that("Residual updates correctly propagated after forest sampling step", { cpp_rng = createCppRNG(-1) # Create forest sampler and forest container - forest_model = createForestModel(forest_dataset, feature_types, num_trees, n, alpha, beta, min_samples_leaf, max_depth) + 
global_model_config = createGlobalModelConfig(global_error_variance=current_sigma2) + forest_model_config = createForestModelConfig(feature_types=feature_types, num_trees=num_trees, + num_observations=n, alpha=alpha, beta=beta, + min_samples_leaf=min_samples_leaf, max_depth=max_depth, + leaf_model_type=0, leaf_model_scale=current_leaf_scale, + variable_weights=variable_weights, variance_forest_shape=a_forest, + variance_forest_scale=b_forest, cutpoint_grid_size=cutpoint_grid_size) + forest_model = createForestModel(forest_dataset, forest_model_config, global_model_config) forest_samples = createForestSamples(num_trees, 1, F) active_forest = createForest(num_trees, 1, F) @@ -47,8 +54,7 @@ test_that("Residual updates correctly propagated after forest sampling step", { # Run the forest sampling algorithm for a single iteration forest_model$sample_one_iteration( forest_dataset, residual, forest_samples, active_forest, - cpp_rng, feature_types, 0, current_leaf_scale, variable_weights, a_forest, b_forest, - current_sigma2, cutpoint_grid_size, keep_forest = T, gfr = T, pre_initialized = T + cpp_rng, forest_model_config, global_model_config, keep_forest = T, gfr = T ) # Get the current residual after running the sampler From d3ae584b4cd38fe9564b6d76b5eef955907f110a Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Tue, 4 Feb 2025 17:37:24 -0600 Subject: [PATCH 13/21] Add BCF unit tests and update BCF and BART warm start logic --- R/bart.R | 14 ++- R/bcf.R | 20 +++- test/R/testthat/test-bcf.R | 183 +++++++++++++++++++++++++++++++++++++ 3 files changed, 207 insertions(+), 10 deletions(-) create mode 100644 test/R/testthat/test-bcf.R diff --git a/R/bart.R b/R/bart.R index 83672210..4952f96d 100644 --- a/R/bart.R +++ b/R/bart.R @@ -222,6 +222,10 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train if (previous_bart_model$model_params$has_rfx) { previous_rfx_samples <- previous_bart_model$rfx_samples } else previous_rfx_samples <- NULL + previous_model_num_samples <- previous_bart_model$model_params$num_samples + if (previous_model_warmstart_sample_num >= previous_model_num_samples) { + stop("`previous_model_warmstart_sample_num` exceeds the number of samples in `previous_model_json`") + } } else { previous_y_bar <- NULL previous_y_scale <- NULL @@ -230,6 +234,7 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train previous_rfx_samples <- NULL previous_forest_samples_mean <- NULL previous_forest_samples_variance <- NULL + previous_model_num_samples <- 0 } # Determine whether conditional mean, variance, or both will be modeled @@ -708,11 +713,10 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train resetForestModel(forest_model_variance, active_forest_variance, forest_dataset_train, outcome_train, FALSE) } # TODO: also initialize from previous RFX samples - # if (has_rfx) { - # rootResetRandomEffectsModel(rfx_model, alpha_init, xi_init, sigma_alpha_init, - # sigma_xi_init, sigma_xi_shape, sigma_xi_scale) - # rootResetRandomEffectsTracker(rfx_tracker_train, rfx_model, rfx_dataset_train, outcome_train) - # } + if (has_rfx) { + resetRandomEffectsModel(rfx_model, rfx_samples, forest_ind, sigma_alpha_init) + resetRandomEffectsTracker(rfx_tracker_train, rfx_model, rfx_dataset_train, outcome_train, rfx_samples) + } if (sample_sigma_global) { if (!is.null(previous_global_var_samples)) { current_sigma2 <- previous_global_var_samples[previous_model_warmstart_sample_num] diff --git a/R/bcf.R b/R/bcf.R index 1db9c24c..fc7ff109 
100644
--- a/R/bcf.R
+++ b/R/bcf.R
@@ -263,6 +263,13 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
   b_forest <- variance_forest_params_updated$var_forest_prior_scale
   keep_vars_variance <- variance_forest_params_updated$keep_vars
   drop_vars_variance <- variance_forest_params_updated$drop_vars
+  
+  # Check if there are enough GFR samples to seed num_chains samplers
+  if (num_gfr > 0) {
+    if (num_chains > num_gfr) {
+      stop("num_chains > num_gfr, meaning we do not have enough GFR samples to seed num_chains distinct MCMC chains")
+    }
+  }
 
   # Override keep_gfr if there are no MCMC samples
   if (num_mcmc == 0) keep_gfr <- T
@@ -300,6 +307,10 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
       previous_b_1_samples <- NULL
       previous_b_0_samples <- NULL
     }
+    previous_model_num_samples <- previous_bcf_model$model_params$num_samples
+    if (previous_model_warmstart_sample_num >= previous_model_num_samples) {
+      stop("`previous_model_warmstart_sample_num` exceeds the number of samples in `previous_model_json`")
+    }
   } else {
     previous_y_bar <- NULL
     previous_y_scale <- NULL
@@ -1004,11 +1015,10 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
       }
       forest_model_tau$propagate_basis_update(forest_dataset_train, outcome_train, active_forest_tau)
     }
-    # TODO: also initialize from previous RFX samples
-    # if (has_rfx) {
-    #     rootResetRandomEffectsModel(rfx_model, alpha_init, xi_init, sigma_alpha_init, 
-    #                                 sigma_xi_init, sigma_xi_shape, sigma_xi_scale)
-    #     rootResetRandomEffectsTracker(rfx_tracker_train, rfx_model, rfx_dataset_train, outcome_train)
-    # }
+    if (has_rfx) {
+      resetRandomEffectsModel(rfx_model, rfx_samples, forest_ind, sigma_alpha_init)
+      resetRandomEffectsTracker(rfx_tracker_train, rfx_model, rfx_dataset_train, outcome_train, rfx_samples)
+    }
     if (sample_sigma_global) {
       if (!is.null(previous_global_var_samples)) {
         current_sigma2 <- previous_global_var_samples[previous_model_warmstart_sample_num]
diff --git a/test/R/testthat/test-bcf.R b/test/R/testthat/test-bcf.R
new file mode 100644
index 00000000..b4eb329d
--- /dev/null
+++ b/test/R/testthat/test-bcf.R
@@ -0,0 +1,183 @@
+test_that("MCMC BCF", {
+    skip_on_cran()
+    
+    # Generate simulated data
+    n <- 100
+    p <- 5
+    X <- matrix(runif(n*p), ncol = p)
+    mu_X <- (
+        ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + 
+        ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + 
+        ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + 
+        ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5)
+    )
+    pi_X <- (
+        ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + 
+        ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + 
+        ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + 
+        ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8)
+    )
+    tau_X <- (
+        ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + 
+        ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + 
+        ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + 
+        ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0)
+    )
+    Z <- rbinom(n, 1, pi_X)
+    noise_sd <- 1
+    y <- mu_X + tau_X*Z + rnorm(n, 0, noise_sd)
+    test_set_pct <- 0.2
+    n_test <- round(test_set_pct*n)
+    n_train <- n - n_test
+    test_inds <- sort(sample(1:n, n_test, replace = FALSE))
+    train_inds <- (1:n)[!((1:n) %in% test_inds)]
+    X_test <- X[test_inds,]
+    X_train <- X[train_inds,]
+    Z_test <- Z[test_inds]
+    Z_train <- Z[train_inds]
+    pi_test <- pi_X[test_inds]
+    pi_train <- pi_X[train_inds]
+    mu_test <- mu_X[test_inds]
+    mu_train <- mu_X[train_inds]
+    tau_test <- tau_X[test_inds]
+    tau_train <- tau_X[train_inds]
+    y_test <- y[test_inds]
+    y_train <- y[train_inds]
+    
+    # 1 chain, no thinning
+    general_param_list
<- list(num_chains = 1, keep_every = 1)
+    expect_no_error(
+        bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train, 
+                         propensity_train = pi_train, X_test = X_test, Z_test = Z_test, 
+                         propensity_test = pi_test, num_gfr = 0, num_burnin = 10, 
+                         num_mcmc = 10, general_params = general_param_list)
+    )
+    
+    # 3 chains, no thinning
+    general_param_list <- list(num_chains = 3, keep_every = 1)
+    expect_no_error(
+        bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train, 
+                         propensity_train = pi_train, X_test = X_test, Z_test = Z_test, 
+                         propensity_test = pi_test, num_gfr = 0, num_burnin = 10, 
+                         num_mcmc = 10, general_params = general_param_list)
+    )
+    
+    # 1 chain, thinning
+    general_param_list <- list(num_chains = 1, keep_every = 5)
+    expect_no_error(
+        bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train, 
+                         propensity_train = pi_train, X_test = X_test, Z_test = Z_test, 
+                         propensity_test = pi_test, num_gfr = 0, num_burnin = 10, 
+                         num_mcmc = 10, general_params = general_param_list)
+    )
+    
+    # 3 chains, thinning
+    general_param_list <- list(num_chains = 3, keep_every = 5)
+    expect_no_error(
+        bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train, 
+                         propensity_train = pi_train, X_test = X_test, Z_test = Z_test, 
+                         propensity_test = pi_test, num_gfr = 0, num_burnin = 10, 
+                         num_mcmc = 10, general_params = general_param_list)
+    )
+})
+
+test_that("GFR BCF", {
+    skip_on_cran()
+    
+    # Generate simulated data
+    n <- 100
+    p <- 5
+    X <- matrix(runif(n*p), ncol = p)
+    mu_X <- (
+        ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + 
+        ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + 
+        ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + 
+        ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5)
+    )
+    pi_X <- (
+        ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + 
+        ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + 
+        ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + 
+        ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8)
+    )
+    tau_X <- (
+        ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + 
+        ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + 
+        ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + 
+        ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0)
+    )
+    Z <- rbinom(n, 1, pi_X)
+    noise_sd <- 1
+    y <- mu_X + tau_X*Z + rnorm(n, 0, noise_sd)
+    test_set_pct <- 0.2
+    n_test <- round(test_set_pct*n)
+    n_train <- n - n_test
+    test_inds <- sort(sample(1:n, n_test, replace = FALSE))
+    train_inds <- (1:n)[!((1:n) %in% test_inds)]
+    X_test <- X[test_inds,]
+    X_train <- X[train_inds,]
+    Z_test <- Z[test_inds]
+    Z_train <- Z[train_inds]
+    pi_test <- pi_X[test_inds]
+    pi_train <- pi_X[train_inds]
+    mu_test <- mu_X[test_inds]
+    mu_train <- mu_X[train_inds]
+    tau_test <- tau_X[test_inds]
+    tau_train <- tau_X[train_inds]
+    y_test <- y[test_inds]
+    y_train <- y[train_inds]
+    
+    # 1 chain, no thinning
+    general_param_list <- list(num_chains = 1, keep_every = 1)
+    expect_no_error(
+        bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train, 
+                         propensity_train = pi_train, X_test = X_test, Z_test = Z_test, 
+                         propensity_test = pi_test, num_gfr = 10, num_burnin = 10, 
+                         num_mcmc = 10, general_params = general_param_list)
+    )
+    
+    # 3 chains, no thinning
+    general_param_list <- list(num_chains = 3, keep_every = 1)
+    expect_no_error(
+        bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train, 
+                         propensity_train = pi_train, X_test = X_test, Z_test = Z_test, 
+                         propensity_test = pi_test, num_gfr = 10, num_burnin = 10, 
+                         num_mcmc = 10, general_params = general_param_list)
+    )
+    
+    # 1 chain, thinning
+    general_param_list <- list(num_chains = 1, keep_every =
5) + expect_no_error( + bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train, + propensity_train = pi_train, X_test = X_test, Z_test = Z_test, + propensity_test = pi_test, num_gfr = 10, num_burnin = 10, + num_mcmc = 10, general_params = general_param_list) + ) + + # 3 chains, thinning + general_param_list <- list(num_chains = 3, keep_every = 5) + expect_no_error( + bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train, + propensity_train = pi_train, X_test = X_test, Z_test = Z_test, + propensity_test = pi_test, num_gfr = 10, num_burnin = 10, + num_mcmc = 10, general_params = general_param_list) + ) + + # Check for error when more chains than GFR forests + general_param_list <- list(num_chains = 11, keep_every = 1) + expect_error( + bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train, + propensity_train = pi_train, X_test = X_test, Z_test = Z_test, + propensity_test = pi_test, num_gfr = 10, num_burnin = 10, + num_mcmc = 10, general_params = general_param_list) + ) + + # Check for error when more chains than GFR forests + general_param_list <- list(num_chains = 11, keep_every = 5) + expect_error( + bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train, + propensity_train = pi_train, X_test = X_test, Z_test = Z_test, + propensity_test = pi_test, num_gfr = 10, num_burnin = 10, + num_mcmc = 10, general_params = general_param_list) + ) +}) From b23f0cb3beaf72322e5ddd97bb4636bc15f41a7f Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Wed, 5 Feb 2025 00:44:23 -0600 Subject: [PATCH 14/21] Fixed remaining bug in BART / BCF R interface --- .github/workflows/r-test.yml | 2 +- R/bart.R | 13 +++-- R/bcf.R | 13 +++-- cran-bootstrap.R | 43 +++++++++------ test/R/testthat/test-bart.R | 103 +++++++++++++++++++++++++++++++++++ 5 files changed, 148 insertions(+), 26 deletions(-) diff --git a/.github/workflows/r-test.yml b/.github/workflows/r-test.yml index cf2fb148..0a8464bc 100644 --- a/.github/workflows/r-test.yml +++ b/.github/workflows/r-test.yml @@ -39,7 +39,7 @@ jobs: - name: Create a CRAN-ready version of the R package run: | - Rscript cran-bootstrap.R 0 0 + Rscript cran-bootstrap.R 0 0 1 - uses: r-lib/actions/check-r-package@v2 with: diff --git a/R/bart.R b/R/bart.R index 4952f96d..fca79606 100644 --- a/R/bart.R +++ b/R/bart.R @@ -199,7 +199,6 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train if (num_mcmc == 0) keep_gfr <- T # Check if previous model JSON is provided and parse it if so - # TODO: check that `previous_model_warmstart_sample_num` is <= the number of samples in this previous model has_prev_model <- !is.null(previous_model_json) if (has_prev_model) { previous_bart_model <- createBARTModelFromJsonString(previous_model_json) @@ -712,10 +711,16 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train resetActiveForest(active_forest_variance, previous_forest_samples_variance, previous_model_warmstart_sample_num - 1) resetForestModel(forest_model_variance, active_forest_variance, forest_dataset_train, outcome_train, FALSE) } - # TODO: also initialize from previous RFX samples if (has_rfx) { - resetRandomEffectsModel(rfx_model, rfx_samples, forest_ind, sigma_alpha_init) - resetRandomEffectsTracker(rfx_tracker_train, rfx_model, rfx_dataset_train, outcome_train, rfx_samples) + if (is.null(previous_rfx_samples)) { + warning("`previous_model_json` did not have any random effects samples, so the RFX sampler will be run from scratch while the forests and any other 
parameters are warm started") + rootResetRandomEffectsModel(rfx_model, alpha_init, xi_init, sigma_alpha_init, + sigma_xi_init, sigma_xi_shape, sigma_xi_scale) + rootResetRandomEffectsTracker(rfx_tracker_train, rfx_model, rfx_dataset_train, outcome_train) + } else { + resetRandomEffectsModel(rfx_model, previous_rfx_samples, previous_model_warmstart_sample_num - 1, sigma_alpha_init) + resetRandomEffectsTracker(rfx_tracker_train, rfx_model, rfx_dataset_train, outcome_train, rfx_samples) + } } if (sample_sigma_global) { if (!is.null(previous_global_var_samples)) { diff --git a/R/bcf.R b/R/bcf.R index fc7ff109..d6f7a933 100644 --- a/R/bcf.R +++ b/R/bcf.R @@ -275,7 +275,6 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id if (num_mcmc == 0) keep_gfr <- T # Check if previous model JSON is provided and parse it if so - # TODO: check that `previous_model_warmstart_sample_num` is <= the number of samples in this previous model has_prev_model <- !is.null(previous_model_json) if (has_prev_model) { previous_bcf_model <- createBCFModelFromJsonString(previous_model_json) @@ -1014,10 +1013,16 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id } forest_model_tau$propagate_basis_update(forest_dataset_train, outcome_train, active_forest_tau) } - # TODO: also initialize from previous RFX samples if (has_rfx) { - resetRandomEffectsModel(rfx_model, rfx_samples, forest_ind, sigma_alpha_init) - resetRandomEffectsTracker(rfx_tracker_train, rfx_model, rfx_dataset_train, outcome_train, rfx_samples) + if (is.null(previous_rfx_samples)) { + warning("`previous_model_json` did not have any random effects samples, so the RFX sampler will be run from scratch while the forests and any other parameters are warm started") + rootResetRandomEffectsModel(rfx_model, alpha_init, xi_init, sigma_alpha_init, + sigma_xi_init, sigma_xi_shape, sigma_xi_scale) + rootResetRandomEffectsTracker(rfx_tracker_train, rfx_model, rfx_dataset_train, outcome_train) + } else { + resetRandomEffectsModel(rfx_model, previous_rfx_samples, previous_model_warmstart_sample_num - 1, sigma_alpha_init) + resetRandomEffectsTracker(rfx_tracker_train, rfx_model, rfx_dataset_train, outcome_train, rfx_samples) + } } if (sample_sigma_global) { if (!is.null(previous_global_var_samples)) { diff --git a/cran-bootstrap.R b/cran-bootstrap.R index 3a47f19f..f8c22e80 100644 --- a/cran-bootstrap.R +++ b/cran-bootstrap.R @@ -14,25 +14,28 @@ # include_vignettes : 1 to include the vignettes folder in the R package subfolder # 0 to exclude vignettes (overriden to 1 if pkgdown_build = 1 below) # -# pkgdown_build : 1 to include pkgdown specific files (R_Readm) -# 0 to exclude vignettes +# pkgdown_build : 1 to include pkgdown specific files (R_README.md, _pkgdown.yml) +# 0 to exclude pkgdown specific files +# +# include_tests : 1 to include unit tests +# 0 to exclude unit tests # # Run this script from the command line via # -# Explicitly include vignettes and build pkgdown site -# --------------------------------------------------- -# Rscript cran-bootstrap.R 1 1 +# Explicitly include vignettes and unit tests and build pkgdown site +# ------------------------------------------------------------------ +# Rscript cran-bootstrap.R 1 1 1 # -# Explicitly include vignettes but don't build pkgdown site -# --------------------------------------------------------- -# Rscript cran-bootstrap.R 1 0 +# Explicitly include vignettes and unit tests but don't build pkgdown site +# 
------------------------------------------------------------------------ +# Rscript cran-bootstrap.R 1 0 1 # -# Explicitly exclude vignettes and don't build pkgdown site -# --------------------------------------------------------- -# Rscript cran-bootstrap.R 0 0 +# Explicitly exclude vignettes and unit tests and don't build pkgdown site +# ------------------------------------------------------------------------ +# Rscript cran-bootstrap.R 0 0 0 # -# Exclude vignettes and pkgdown by default -# ---------------------------------------- +# Exclude vignettes, unit tests, and pkgdown by default +# ----------------------------------------------------- # Rscript cran-bootstrap.R # Unpack command line arguments @@ -40,9 +43,11 @@ args <- commandArgs(trailingOnly = T) if (length(args) > 0){ include_vignettes <- as.logical(as.integer(args[1])) pkgdown_build <- as.logical(as.integer(args[2])) + include_tests <- as.logical(as.integer(args[3])) } else{ include_vignettes <- F pkgdown_build <- F + include_tests <- F } # Create the stochtree_cran folder @@ -95,10 +100,14 @@ if (pkgdown_build) { } # Handle tests separately (move from test/R/ folder to tests/ folder) -test_files_src <- list.files("test/R", recursive = TRUE, full.names = TRUE) -test_files_dst <- file.path(cran_dir, gsub("test/R", "tests", test_files_src)) -pkg_core_files <- c(pkg_core_files, test_files_src) -pkg_core_files_dst <- c(pkg_core_files_dst, test_files_dst) +if (include_tests) { + test_files_src <- list.files("test/R", recursive = TRUE, full.names = TRUE) + test_files_dst <- file.path(cran_dir, gsub("test/R", "tests", test_files_src)) + pkg_core_files <- c(pkg_core_files, test_files_src) + pkg_core_files_dst <- c(pkg_core_files_dst, test_files_dst) +} + +# Copy over all core package files if (all(file.exists(pkg_core_files))) { n_removed <- suppressWarnings(sum(file.remove(pkg_core_files_dst))) if (n_removed > 0) { diff --git a/test/R/testthat/test-bart.R b/test/R/testthat/test-bart.R index b0372236..9a885b91 100644 --- a/test/R/testthat/test-bart.R +++ b/test/R/testthat/test-bart.R @@ -129,3 +129,106 @@ test_that("GFR BART", { general_params = general_param_list) ) }) + +test_that("Warmstart BART", { + skip_on_cran() + + # Generate simulated data + n <- 100 + p <- 5 + X <- matrix(runif(n*p), ncol = p) + f_XW <- ( + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) + ) + noise_sd <- 1 + y <- f_XW + rnorm(n, 0, noise_sd) + test_set_pct <- 0.2 + n_test <- round(test_set_pct*n) + n_train <- n - n_test + test_inds <- sort(sample(1:n, n_test, replace = FALSE)) + train_inds <- (1:n)[!((1:n) %in% test_inds)] + X_test <- X[test_inds,] + X_train <- X[train_inds,] + y_test <- y[test_inds] + y_train <- y[train_inds] + + # Run a BART model with only GFR + general_param_list <- list(num_chains = 1, keep_every = 1) + bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, + num_gfr = 10, num_burnin = 0, num_mcmc = 0, + general_params = general_param_list) + + # Save to JSON string + bart_model_json_string <- saveBARTModelToJsonString(bart_model) + + # Run a new BART chain from the existing (X)BART model + general_param_list <- list(num_chains = 3, keep_every = 5) + expect_no_error( + bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, + num_gfr = 0, num_burnin = 10, num_mcmc = 10, + previous_model_json = bart_model_json_string, + previous_model_warmstart_sample_num = 1, + 
general_params = general_param_list) + + ) + + # Generate simulated data with random effects + n <- 100 + p <- 5 + X <- matrix(runif(n*p), ncol = p) + f_XW <- ( + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) + ) + rfx_group_ids <- sample(1:2, size = n, replace = T) + rfx_basis <- rep(1, n) + rfx_coefs <- c(-5, 5) + rfx_term <- rfx_coefs[rfx_group_ids] * rfx_basis + noise_sd <- 1 + y <- f_XW + rfx_term + rnorm(n, 0, noise_sd) + test_set_pct <- 0.2 + n_test <- round(test_set_pct*n) + n_train <- n - n_test + test_inds <- sort(sample(1:n, n_test, replace = FALSE)) + train_inds <- (1:n)[!((1:n) %in% test_inds)] + X_test <- X[test_inds,] + X_train <- X[train_inds,] + rfx_group_ids_test <- rfx_group_ids[test_inds] + rfx_group_ids_train <- rfx_group_ids[train_inds] + rfx_basis_test <- rfx_basis[test_inds] + rfx_basis_train <- rfx_basis[train_inds] + y_test <- y[test_inds] + y_train <- y[train_inds] + + # Run a BART model with only GFR + general_param_list <- list(num_chains = 1, keep_every = 1) + bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, + rfx_group_ids_train = rfx_group_ids_train, + rfx_group_ids_test = rfx_group_ids_test, + rfx_basis_train = rfx_basis_train, + rfx_basis_test = rfx_basis_test, + num_gfr = 10, num_burnin = 0, num_mcmc = 0, + general_params = general_param_list) + + # Save to JSON string + bart_model_json_string <- saveBARTModelToJsonString(bart_model) + + # Run a new BART chain from the existing (X)BART model + general_param_list <- list(num_chains = 4, keep_every = 5) + expect_no_error( + bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, + rfx_group_ids_train = rfx_group_ids_train, + rfx_group_ids_test = rfx_group_ids_test, + rfx_basis_train = rfx_basis_train, + rfx_basis_test = rfx_basis_test, + num_gfr = 0, num_burnin = 10, num_mcmc = 10, + previous_model_json = bart_model_json_string, + previous_model_warmstart_sample_num = 1, + general_params = general_param_list) + ) +}) From e64a1043c9416cbeb13a8b03f5f98be7ddc3eb3d Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Wed, 5 Feb 2025 01:09:15 -0600 Subject: [PATCH 15/21] Updated copyright details --- inst/COPYRIGHTS | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/inst/COPYRIGHTS b/inst/COPYRIGHTS index 6d246808..71270347 100644 --- a/inst/COPYRIGHTS +++ b/inst/COPYRIGHTS @@ -1,15 +1,34 @@ stochtree Copyright 2023-2025 stochtree contributors +Several stochtree C++ header and source files include or are inspired by code +in several open-source decision tree libraries: xgboost, LightGBM, and treelite. +Copyright and license information for each of these three projects are detailed +further below and in comments in each of the files. +File: src/include/stochtree/category_tracker.h [xgboost] +File: src/include/stochtree/common.h [xgboost] +File: src/include/stochtree/ensemble.h [xgboost] +File: src/include/stochtree/io.h [LightGBM] +File: src/include/stochtree/log.h [LightGBM] +File: src/include/stochtree/meta.h [LightGBM] +File: src/include/stochtree/partition_tracker.h [LightGBM, xgboost] +File: src/include/stochtree/tree.h [xgboost, treelite] + This project includes software from the xgboost project (Apache, 2.0). * Copyright 2015-2024, XGBoost Contributors This project includes software from the LightGBM project (MIT). 
* Copyright (c) 2016 Microsoft Corporation +This project includes software from the treelite project (Apache, 2.0). +* Copyright (c) 2017-2023 by [treelite] Contributors + This project includes software from the fast_double_parser project (Apache, 2.0). * Copyright (c) Daniel Lemire +This project includes software from the JSON for Modern C++ project (MIT). +* Copyright © 2013-2025 Niels Lohmann + This project includes software from the Eigen project (MPL, 2.0), whose headers carry the following copyrights: File: Eigen/Core Copyright (C) 2008 Gael Guennebaud From a65f6066186576e9f5cd0b07440e6fd45dcefdf8 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Wed, 5 Feb 2025 01:23:16 -0600 Subject: [PATCH 16/21] Updated docs and incorporated CRAN check feedback --- R/bart.R | 10 +++- R/bcf.R | 53 ++++++++++++++------- R/forest.R | 19 +++++--- R/kernel.R | 12 ++--- R/model.R | 6 ++- man/createBCFModelFromCombinedJson.Rd | 6 ++- man/createBCFModelFromCombinedJsonString.Rd | 6 ++- man/createBCFModelFromJson.Rd | 6 ++- man/createBCFModelFromJsonFile.Rd | 6 ++- man/createBCFModelFromJsonString.Rd | 6 ++- man/createForestModel.Rd | 6 ++- man/getRandomEffectSamples.bartmodel.Rd | 6 ++- man/getRandomEffectSamples.bcfmodel.Rd | 6 ++- man/predict.bcfmodel.Rd | 3 +- man/resetForestModel.Rd | 19 +++++--- man/saveBCFModelToJson.Rd | 6 ++- man/saveBCFModelToJsonFile.Rd | 6 ++- 17 files changed, 120 insertions(+), 62 deletions(-) diff --git a/R/bart.R b/R/bart.R index fca79606..2e261d47 100644 --- a/R/bart.R +++ b/R/bart.R @@ -1133,8 +1133,10 @@ predict.bartmodel <- function(object, X, leaf_basis = NULL, rfx_group_ids = NULL #' rfx_term_test <- rfx_term[test_inds] #' rfx_term_train <- rfx_term[train_inds] #' bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, -#' rfx_group_ids_train = rfx_group_ids_train, rfx_group_ids_test = rfx_group_ids_test, -#' rfx_basis_train = rfx_basis_train, rfx_basis_test = rfx_basis_test, +#' rfx_group_ids_train = rfx_group_ids_train, +#' rfx_group_ids_test = rfx_group_ids_test, +#' rfx_basis_train = rfx_basis_train, +#' rfx_basis_test = rfx_basis_test, #' num_gfr = 100, num_burnin = 0, num_mcmc = 100) #' rfx_samples <- getRandomEffectSamples(bart_model) getRandomEffectSamples.bartmodel <- function(object, ...){ @@ -1190,6 +1192,10 @@ getRandomEffectSamples.bartmodel <- function(object, ...){ saveBARTModelToJson <- function(object){ jsonobj <- createCppJson() + if (!inherits(object, "bartmodel")) { + stop("`object` must be a BART model") + } + if (is.null(object$model_params)) { stop("This BCF model has not yet been sampled") } diff --git a/R/bcf.R b/R/bcf.R index d6f7a933..5df8ca26 100644 --- a/R/bcf.R +++ b/R/bcf.R @@ -1409,7 +1409,8 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id #' mu_train <- mu_x[train_inds] #' tau_test <- tau_x[test_inds] #' tau_train <- tau_x[train_inds] -#' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, propensity_train = pi_train) +#' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, +#' propensity_train = pi_train) #' preds <- predict(bcf_model, X_test, Z_test, pi_test) #' plot(rowMeans(preds$mu_hat), mu_test, xlab = "predicted", #' ylab = "actual", main = "Prognostic function") @@ -1597,9 +1598,11 @@ predict.bcfmodel <- function(object, X, Z, propensity = NULL, rfx_group_ids = NU #' mu_params <- list(sample_sigma_leaf = TRUE) #' tau_params <- list(sample_sigma_leaf = FALSE) #' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, 
-#' propensity_train = pi_train, rfx_group_ids_train = rfx_group_ids_train, +#' propensity_train = pi_train, +#' rfx_group_ids_train = rfx_group_ids_train, #' rfx_basis_train = rfx_basis_train, X_test = X_test, -#' Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test, +#' Z_test = Z_test, propensity_test = pi_test, +#' rfx_group_ids_test = rfx_group_ids_test, #' rfx_basis_test = rfx_basis_test, #' num_gfr = 100, num_burnin = 0, num_mcmc = 100, #' mu_forest_params = mu_params, @@ -1686,9 +1689,11 @@ getRandomEffectSamples.bcfmodel <- function(object, ...){ #' mu_params <- list(sample_sigma_leaf = TRUE) #' tau_params <- list(sample_sigma_leaf = FALSE) #' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, -#' propensity_train = pi_train, rfx_group_ids_train = rfx_group_ids_train, +#' propensity_train = pi_train, +#' rfx_group_ids_train = rfx_group_ids_train, #' rfx_basis_train = rfx_basis_train, X_test = X_test, -#' Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test, +#' Z_test = Z_test, propensity_test = pi_test, +#' rfx_group_ids_test = rfx_group_ids_test, #' rfx_basis_test = rfx_basis_test, #' num_gfr = 100, num_burnin = 0, num_mcmc = 100, #' mu_forest_params = mu_params, @@ -1697,7 +1702,7 @@ getRandomEffectSamples.bcfmodel <- function(object, ...){ saveBCFModelToJson <- function(object){ jsonobj <- createCppJson() - if (class(object) != "bcfmodel") { + if (!inherits(object, "bcfmodel")) { stop("`object` must be a BCF model") } @@ -1849,9 +1854,11 @@ saveBCFModelToJson <- function(object){ #' mu_params <- list(sample_sigma_leaf = TRUE) #' tau_params <- list(sample_sigma_leaf = FALSE) #' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, -#' propensity_train = pi_train, rfx_group_ids_train = rfx_group_ids_train, +#' propensity_train = pi_train, +#' rfx_group_ids_train = rfx_group_ids_train, #' rfx_basis_train = rfx_basis_train, X_test = X_test, -#' Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test, +#' Z_test = Z_test, propensity_test = pi_test, +#' rfx_group_ids_test = rfx_group_ids_test, #' rfx_basis_test = rfx_basis_test, #' num_gfr = 100, num_burnin = 0, num_mcmc = 100, #' mu_forest_params = mu_params, @@ -2003,9 +2010,11 @@ saveBCFModelToJsonString <- function(object){ #' mu_params <- list(sample_sigma_leaf = TRUE) #' tau_params <- list(sample_sigma_leaf = FALSE) #' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, -#' propensity_train = pi_train, rfx_group_ids_train = rfx_group_ids_train, +#' propensity_train = pi_train, +#' rfx_group_ids_train = rfx_group_ids_train, #' rfx_basis_train = rfx_basis_train, X_test = X_test, -#' Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test, +#' Z_test = Z_test, propensity_test = pi_test, +#' rfx_group_ids_test = rfx_group_ids_test, #' rfx_basis_test = rfx_basis_test, #' num_gfr = 100, num_burnin = 0, num_mcmc = 100, #' mu_forest_params = mu_params, @@ -2166,9 +2175,11 @@ createBCFModelFromJson <- function(json_object){ #' mu_params <- list(sample_sigma_leaf = TRUE) #' tau_params <- list(sample_sigma_leaf = FALSE) #' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, -#' propensity_train = pi_train, rfx_group_ids_train = rfx_group_ids_train, +#' propensity_train = pi_train, +#' rfx_group_ids_train = rfx_group_ids_train, #' rfx_basis_train = rfx_basis_train, X_test = X_test, -#' Z_test = Z_test, propensity_test = pi_test, 
rfx_group_ids_test = rfx_group_ids_test, +#' Z_test = Z_test, propensity_test = pi_test, +#' rfx_group_ids_test = rfx_group_ids_test, #' rfx_basis_test = rfx_basis_test, #' num_gfr = 100, num_burnin = 0, num_mcmc = 100, #' mu_forest_params = mu_params, @@ -2245,9 +2256,11 @@ createBCFModelFromJsonFile <- function(json_filename){ #' rfx_term_test <- rfx_term[test_inds] #' rfx_term_train <- rfx_term[train_inds] #' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, -#' propensity_train = pi_train, rfx_group_ids_train = rfx_group_ids_train, +#' propensity_train = pi_train, +#' rfx_group_ids_train = rfx_group_ids_train, #' rfx_basis_train = rfx_basis_train, X_test = X_test, -#' Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test, +#' Z_test = Z_test, propensity_test = pi_test, +#' rfx_group_ids_test = rfx_group_ids_test, #' rfx_basis_test = rfx_basis_test, #' num_gfr = 100, num_burnin = 0, num_mcmc = 100) #' # bcf_json <- saveBCFModelToJsonString(bcf_model) @@ -2323,9 +2336,11 @@ createBCFModelFromJsonString <- function(json_string){ #' rfx_term_test <- rfx_term[test_inds] #' rfx_term_train <- rfx_term[train_inds] #' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, -#' propensity_train = pi_train, rfx_group_ids_train = rfx_group_ids_train, +#' propensity_train = pi_train, +#' rfx_group_ids_train = rfx_group_ids_train, #' rfx_basis_train = rfx_basis_train, X_test = X_test, -#' Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test, +#' Z_test = Z_test, propensity_test = pi_test, +#' rfx_group_ids_test = rfx_group_ids_test, #' rfx_basis_test = rfx_basis_test, #' num_gfr = 100, num_burnin = 0, num_mcmc = 100) #' # bcf_json_list <- list(saveBCFModelToJson(bcf_model)) @@ -2533,9 +2548,11 @@ createBCFModelFromCombinedJson <- function(json_object_list){ #' rfx_term_test <- rfx_term[test_inds] #' rfx_term_train <- rfx_term[train_inds] #' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, -#' propensity_train = pi_train, rfx_group_ids_train = rfx_group_ids_train, +#' propensity_train = pi_train, +#' rfx_group_ids_train = rfx_group_ids_train, #' rfx_basis_train = rfx_basis_train, X_test = X_test, -#' Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test, +#' Z_test = Z_test, propensity_test = pi_test, +#' rfx_group_ids_test = rfx_group_ids_test, #' rfx_basis_test = rfx_basis_test, #' num_gfr = 100, num_burnin = 0, num_mcmc = 100) #' # bcf_json_string_list <- list(saveBCFModelToJsonString(bcf_model)) diff --git a/R/forest.R b/R/forest.R index 06491f7d..6af83bbe 100644 --- a/R/forest.R +++ b/R/forest.R @@ -889,17 +889,24 @@ resetActiveForest <- function(active_forest, forest_samples=NULL, forest_num=NUL #' outcome <- createOutcome(y) #' rng <- createCppRNG(1234) #' global_model_config <- createGlobalModelConfig(global_error_variance=sigma2) -#' forest_model_config <- createForestModelConfig(feature_types=feature_types, num_trees=num_trees, num_observations=n, num_features=p, -#' alpha=alpha, beta=beta, min_samples_leaf=min_samples_leaf, max_depth=max_depth, -#' variable_weights=variable_weights, cutpoint_grid_size=cutpoint_grid_size, -#' leaf_model_type=leaf_model, leaf_model_scale=leaf_scale) +#' forest_model_config <- createForestModelConfig(feature_types=feature_types, +#' num_trees=num_trees, num_observations=n, +#' num_features=p, alpha=alpha, beta=beta, +#' min_samples_leaf=min_samples_leaf, +#' max_depth=max_depth, +#' 
variable_weights=variable_weights, +#' cutpoint_grid_size=cutpoint_grid_size, +#' leaf_model_type=leaf_model, +#' leaf_model_scale=leaf_scale) #' forest_model <- createForestModel(forest_dataset, forest_model_config, global_model_config) #' active_forest <- createForest(num_trees, leaf_dimension, is_leaf_constant, is_exponentiated) -#' forest_samples <- createForestSamples(num_trees, leaf_dimension, is_leaf_constant, is_exponentiated) +#' forest_samples <- createForestSamples(num_trees, leaf_dimension, +#' is_leaf_constant, is_exponentiated) #' active_forest$prepare_for_sampler(forest_dataset, outcome, forest_model, 0, 0.) #' forest_model$sample_one_iteration( #' forest_dataset, outcome, forest_samples, active_forest, -#' rng, forest_model_config, global_model_config, keep_forest = TRUE, gfr = FALSE +#' rng, forest_model_config, global_model_config, +#' keep_forest = TRUE, gfr = FALSE #' ) #' resetActiveForest(active_forest, forest_samples, 0) #' resetForestModel(forest_model, active_forest, forest_dataset, outcome, TRUE) diff --git a/R/kernel.R b/R/kernel.R index becbb43b..74eedd64 100644 --- a/R/kernel.R +++ b/R/kernel.R @@ -48,8 +48,7 @@ #' computeForestLeafIndices(bart_model, X, "mean", c(1,3,9)) computeForestLeafIndices <- function(model_object, covariates, forest_type=NULL, forest_inds=NULL) { # Extract relevant forest container - object_name <- class(model_object)[1] - stopifnot(object_name %in% c("bartmodel", "bcfmodel", "ForestSamples")) + stopifnot(any(c(inherits(model_object, "bartmodel"), inherits(model_object, "bcfmodel"), inherits(model_object, "ForestSamples")))) model_type <- ifelse(object_name=="bartmodel", "bart", ifelse(object_name=="bcfmodel", "bcf", "forest_samples")) if (model_type == "bart") { stopifnot(forest_type %in% c("mean", "variance")) @@ -143,8 +142,8 @@ computeForestLeafIndices <- function(model_object, covariates, forest_type=NULL, #' computeForestLeafVariances(bart_model, "mean", c(1,3,5)) computeForestLeafVariances <- function(model_object, forest_type, forest_inds=NULL) { # Extract relevant forest container - stopifnot(class(model_object) %in% c("bartmodel", "bcfmodel")) - model_type <- ifelse(class(model_object)=="bartmodel", "bart", "bcf") + stopifnot(any(c(inherits(model_object, "bartmodel"), inherits(model_object, "bcfmodel")))) + model_type <- ifelse(inherits(model_object, "bartmodel"), "bart", "bcf") if (model_type == "bart") { stopifnot(forest_type %in% c("mean", "variance")) if (forest_type=="mean") { @@ -234,9 +233,8 @@ computeForestLeafVariances <- function(model_object, forest_type, forest_inds=NU #' computeForestMaxLeafIndex(bart_model, X, "mean", c(1,3,9)) computeForestMaxLeafIndex <- function(model_object, covariates, forest_type=NULL, forest_inds=NULL) { # Extract relevant forest container - object_name <- class(model_object)[1] - stopifnot(object_name %in% c("bartmodel", "bcfmodel", "ForestSamples")) - model_type <- ifelse(object_name=="bartmodel", "bart", ifelse(object_name=="bcfmodel", "bcf", "forest_samples")) + stopifnot(any(c(inherits(model_object, "bartmodel"), inherits(model_object, "bcfmodel"), inherits(model_object, "ForestSamples")))) + model_type <- ifelse(inherits(model_object, "bartmodel"), "bart", ifelse(inherits(model_object, "bcfmodel"), "bcf", "forest_samples")) if (model_type == "bart") { stopifnot(forest_type %in% c("mean", "variance")) if (forest_type=="mean") { diff --git a/R/model.R b/R/model.R index 7cde7274..b8da4cf2 100644 --- a/R/model.R +++ b/R/model.R @@ -205,8 +205,10 @@ createCppRNG <- function(random_seed = 
-1){ #' feature_types <- as.integer(rep(0, p)) #' X <- matrix(runif(n*p), ncol = p) #' forest_dataset <- createForestDataset(X) -#' forest_model_config <- createForestModelConfig(feature_types=feature_types, num_trees=num_trees, num_features=p, -#' num_observations=n, alpha=alpha, beta=beta, min_samples_leaf=min_samples_leaf, +#' forest_model_config <- createForestModelConfig(feature_types=feature_types, +#' num_trees=num_trees, num_features=p, +#' num_observations=n, alpha=alpha, beta=beta, +#' min_samples_leaf=min_samples_leaf, #' max_depth=max_depth, leaf_model_type=1) #' global_model_config <- createGlobalModelConfig(global_error_variance=1.0) #' forest_model <- createForestModel(forest_dataset, forest_model_config, global_model_config) diff --git a/man/createBCFModelFromCombinedJson.Rd b/man/createBCFModelFromCombinedJson.Rd index e374c311..b1fb9ac9 100644 --- a/man/createBCFModelFromCombinedJson.Rd +++ b/man/createBCFModelFromCombinedJson.Rd @@ -70,9 +70,11 @@ rfx_basis_train <- rfx_basis[train_inds,] rfx_term_test <- rfx_term[test_inds] rfx_term_train <- rfx_term[train_inds] bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, - propensity_train = pi_train, rfx_group_ids_train = rfx_group_ids_train, + propensity_train = pi_train, + rfx_group_ids_train = rfx_group_ids_train, rfx_basis_train = rfx_basis_train, X_test = X_test, - Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test, + Z_test = Z_test, propensity_test = pi_test, + rfx_group_ids_test = rfx_group_ids_test, rfx_basis_test = rfx_basis_test, num_gfr = 100, num_burnin = 0, num_mcmc = 100) # bcf_json_list <- list(saveBCFModelToJson(bcf_model)) diff --git a/man/createBCFModelFromCombinedJsonString.Rd b/man/createBCFModelFromCombinedJsonString.Rd index f1853d7f..988c7346 100644 --- a/man/createBCFModelFromCombinedJsonString.Rd +++ b/man/createBCFModelFromCombinedJsonString.Rd @@ -70,9 +70,11 @@ rfx_basis_train <- rfx_basis[train_inds,] rfx_term_test <- rfx_term[test_inds] rfx_term_train <- rfx_term[train_inds] bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, - propensity_train = pi_train, rfx_group_ids_train = rfx_group_ids_train, + propensity_train = pi_train, + rfx_group_ids_train = rfx_group_ids_train, rfx_basis_train = rfx_basis_train, X_test = X_test, - Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test, + Z_test = Z_test, propensity_test = pi_test, + rfx_group_ids_test = rfx_group_ids_test, rfx_basis_test = rfx_basis_test, num_gfr = 100, num_burnin = 0, num_mcmc = 100) # bcf_json_string_list <- list(saveBCFModelToJsonString(bcf_model)) diff --git a/man/createBCFModelFromJson.Rd b/man/createBCFModelFromJson.Rd index 602db813..2cde726a 100644 --- a/man/createBCFModelFromJson.Rd +++ b/man/createBCFModelFromJson.Rd @@ -71,9 +71,11 @@ rfx_term_train <- rfx_term[train_inds] mu_params <- list(sample_sigma_leaf = TRUE) tau_params <- list(sample_sigma_leaf = FALSE) bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, - propensity_train = pi_train, rfx_group_ids_train = rfx_group_ids_train, + propensity_train = pi_train, + rfx_group_ids_train = rfx_group_ids_train, rfx_basis_train = rfx_basis_train, X_test = X_test, - Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test, + Z_test = Z_test, propensity_test = pi_test, + rfx_group_ids_test = rfx_group_ids_test, rfx_basis_test = rfx_basis_test, num_gfr = 100, num_burnin = 0, num_mcmc = 100, mu_forest_params = mu_params, diff 
--git a/man/createBCFModelFromJsonFile.Rd b/man/createBCFModelFromJsonFile.Rd index 2f9be821..cb83403e 100644 --- a/man/createBCFModelFromJsonFile.Rd +++ b/man/createBCFModelFromJsonFile.Rd @@ -71,9 +71,11 @@ rfx_term_train <- rfx_term[train_inds] mu_params <- list(sample_sigma_leaf = TRUE) tau_params <- list(sample_sigma_leaf = FALSE) bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, - propensity_train = pi_train, rfx_group_ids_train = rfx_group_ids_train, + propensity_train = pi_train, + rfx_group_ids_train = rfx_group_ids_train, rfx_basis_train = rfx_basis_train, X_test = X_test, - Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test, + Z_test = Z_test, propensity_test = pi_test, + rfx_group_ids_test = rfx_group_ids_test, rfx_basis_test = rfx_basis_test, num_gfr = 100, num_burnin = 0, num_mcmc = 100, mu_forest_params = mu_params, diff --git a/man/createBCFModelFromJsonString.Rd b/man/createBCFModelFromJsonString.Rd index 7e27f9bb..1cd567ca 100644 --- a/man/createBCFModelFromJsonString.Rd +++ b/man/createBCFModelFromJsonString.Rd @@ -69,9 +69,11 @@ rfx_basis_train <- rfx_basis[train_inds,] rfx_term_test <- rfx_term[test_inds] rfx_term_train <- rfx_term[train_inds] bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, - propensity_train = pi_train, rfx_group_ids_train = rfx_group_ids_train, + propensity_train = pi_train, + rfx_group_ids_train = rfx_group_ids_train, rfx_basis_train = rfx_basis_train, X_test = X_test, - Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test, + Z_test = Z_test, propensity_test = pi_test, + rfx_group_ids_test = rfx_group_ids_test, rfx_basis_test = rfx_basis_test, num_gfr = 100, num_burnin = 0, num_mcmc = 100) # bcf_json <- saveBCFModelToJsonString(bcf_model) diff --git a/man/createForestModel.Rd b/man/createForestModel.Rd index 836f627e..d9000925 100644 --- a/man/createForestModel.Rd +++ b/man/createForestModel.Rd @@ -30,8 +30,10 @@ max_depth <- 10 feature_types <- as.integer(rep(0, p)) X <- matrix(runif(n*p), ncol = p) forest_dataset <- createForestDataset(X) -forest_model_config <- createForestModelConfig(feature_types=feature_types, num_trees=num_trees, num_features=p, - num_observations=n, alpha=alpha, beta=beta, min_samples_leaf=min_samples_leaf, +forest_model_config <- createForestModelConfig(feature_types=feature_types, + num_trees=num_trees, num_features=p, + num_observations=n, alpha=alpha, beta=beta, + min_samples_leaf=min_samples_leaf, max_depth=max_depth, leaf_model_type=1) global_model_config <- createGlobalModelConfig(global_error_variance=1.0) forest_model <- createForestModel(forest_dataset, forest_model_config, global_model_config) diff --git a/man/getRandomEffectSamples.bartmodel.Rd b/man/getRandomEffectSamples.bartmodel.Rd index 2ff00687..9f273732 100644 --- a/man/getRandomEffectSamples.bartmodel.Rd +++ b/man/getRandomEffectSamples.bartmodel.Rd @@ -52,8 +52,10 @@ rfx_basis_train <- rfx_basis[train_inds,] rfx_term_test <- rfx_term[test_inds] rfx_term_train <- rfx_term[train_inds] bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, - rfx_group_ids_train = rfx_group_ids_train, rfx_group_ids_test = rfx_group_ids_test, - rfx_basis_train = rfx_basis_train, rfx_basis_test = rfx_basis_test, + rfx_group_ids_train = rfx_group_ids_train, + rfx_group_ids_test = rfx_group_ids_test, + rfx_basis_train = rfx_basis_train, + rfx_basis_test = rfx_basis_test, num_gfr = 100, num_burnin = 0, num_mcmc = 100) rfx_samples <- 
getRandomEffectSamples(bart_model) } diff --git a/man/getRandomEffectSamples.bcfmodel.Rd b/man/getRandomEffectSamples.bcfmodel.Rd index ca03ffe4..410f44c4 100644 --- a/man/getRandomEffectSamples.bcfmodel.Rd +++ b/man/getRandomEffectSamples.bcfmodel.Rd @@ -73,9 +73,11 @@ rfx_term_train <- rfx_term[train_inds] mu_params <- list(sample_sigma_leaf = TRUE) tau_params <- list(sample_sigma_leaf = FALSE) bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, - propensity_train = pi_train, rfx_group_ids_train = rfx_group_ids_train, + propensity_train = pi_train, + rfx_group_ids_train = rfx_group_ids_train, rfx_basis_train = rfx_basis_train, X_test = X_test, - Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test, + Z_test = Z_test, propensity_test = pi_test, + rfx_group_ids_test = rfx_group_ids_test, rfx_basis_test = rfx_basis_test, num_gfr = 100, num_burnin = 0, num_mcmc = 100, mu_forest_params = mu_params, diff --git a/man/predict.bcfmodel.Rd b/man/predict.bcfmodel.Rd index 3fd2f1a4..c0b14eb5 100644 --- a/man/predict.bcfmodel.Rd +++ b/man/predict.bcfmodel.Rd @@ -78,7 +78,8 @@ mu_test <- mu_x[test_inds] mu_train <- mu_x[train_inds] tau_test <- tau_x[test_inds] tau_train <- tau_x[train_inds] -bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, propensity_train = pi_train) +bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, + propensity_train = pi_train) preds <- predict(bcf_model, X_test, Z_test, pi_test) plot(rowMeans(preds$mu_hat), mu_test, xlab = "predicted", ylab = "actual", main = "Prognostic function") diff --git a/man/resetForestModel.Rd b/man/resetForestModel.Rd index 06e1721e..f0fec6ca 100644 --- a/man/resetForestModel.Rd +++ b/man/resetForestModel.Rd @@ -48,17 +48,24 @@ y <- -5 + 10*(X[,1] > 0.5) + rnorm(n) outcome <- createOutcome(y) rng <- createCppRNG(1234) global_model_config <- createGlobalModelConfig(global_error_variance=sigma2) -forest_model_config <- createForestModelConfig(feature_types=feature_types, num_trees=num_trees, num_observations=n, num_features=p, - alpha=alpha, beta=beta, min_samples_leaf=min_samples_leaf, max_depth=max_depth, - variable_weights=variable_weights, cutpoint_grid_size=cutpoint_grid_size, - leaf_model_type=leaf_model, leaf_model_scale=leaf_scale) +forest_model_config <- createForestModelConfig(feature_types=feature_types, + num_trees=num_trees, num_observations=n, + num_features=p, alpha=alpha, beta=beta, + min_samples_leaf=min_samples_leaf, + max_depth=max_depth, + variable_weights=variable_weights, + cutpoint_grid_size=cutpoint_grid_size, + leaf_model_type=leaf_model, + leaf_model_scale=leaf_scale) forest_model <- createForestModel(forest_dataset, forest_model_config, global_model_config) active_forest <- createForest(num_trees, leaf_dimension, is_leaf_constant, is_exponentiated) -forest_samples <- createForestSamples(num_trees, leaf_dimension, is_leaf_constant, is_exponentiated) +forest_samples <- createForestSamples(num_trees, leaf_dimension, + is_leaf_constant, is_exponentiated) active_forest$prepare_for_sampler(forest_dataset, outcome, forest_model, 0, 0.) 
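# A hedged sketch, not part of the original example: the call below draws one
# sweep over all trees, reading tree-level settings from forest_model_config
# and global parameters from global_model_config. A custom sampling routine
# would typically wrap it in a loop (num_samples is a placeholder count), e.g.
#   for (i in 1:num_samples) {
#     forest_model$sample_one_iteration(forest_dataset, outcome, forest_samples,
#                                       active_forest, rng, forest_model_config,
#                                       global_model_config,
#                                       keep_forest = TRUE, gfr = FALSE)
#     # ...then draw any global parameters (e.g. the error variance) and
#     # refresh global_model_config before the next sweep...
#   }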
forest_model$sample_one_iteration( forest_dataset, outcome, forest_samples, active_forest, - rng, forest_model_config, global_model_config, keep_forest = TRUE, gfr = FALSE + rng, forest_model_config, global_model_config, + keep_forest = TRUE, gfr = FALSE ) resetActiveForest(active_forest, forest_samples, 0) resetForestModel(forest_model, active_forest, forest_dataset, outcome, TRUE) diff --git a/man/saveBCFModelToJson.Rd b/man/saveBCFModelToJson.Rd index 89598334..171d0c53 100644 --- a/man/saveBCFModelToJson.Rd +++ b/man/saveBCFModelToJson.Rd @@ -69,9 +69,11 @@ rfx_term_train <- rfx_term[train_inds] mu_params <- list(sample_sigma_leaf = TRUE) tau_params <- list(sample_sigma_leaf = FALSE) bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, - propensity_train = pi_train, rfx_group_ids_train = rfx_group_ids_train, + propensity_train = pi_train, + rfx_group_ids_train = rfx_group_ids_train, rfx_basis_train = rfx_basis_train, X_test = X_test, - Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test, + Z_test = Z_test, propensity_test = pi_test, + rfx_group_ids_test = rfx_group_ids_test, rfx_basis_test = rfx_basis_test, num_gfr = 100, num_burnin = 0, num_mcmc = 100, mu_forest_params = mu_params, diff --git a/man/saveBCFModelToJsonFile.Rd b/man/saveBCFModelToJsonFile.Rd index 14417564..2c8ea980 100644 --- a/man/saveBCFModelToJsonFile.Rd +++ b/man/saveBCFModelToJsonFile.Rd @@ -71,9 +71,11 @@ rfx_term_train <- rfx_term[train_inds] mu_params <- list(sample_sigma_leaf = TRUE) tau_params <- list(sample_sigma_leaf = FALSE) bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, - propensity_train = pi_train, rfx_group_ids_train = rfx_group_ids_train, + propensity_train = pi_train, + rfx_group_ids_train = rfx_group_ids_train, rfx_basis_train = rfx_basis_train, X_test = X_test, - Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test, + Z_test = Z_test, propensity_test = pi_test, + rfx_group_ids_test = rfx_group_ids_test, rfx_basis_test = rfx_basis_test, num_gfr = 100, num_burnin = 0, num_mcmc = 100, mu_forest_params = mu_params, From 5fbad919015942912fcb0b5a461a3e3ac7a58d7f Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Wed, 5 Feb 2025 01:28:48 -0600 Subject: [PATCH 17/21] Fixed typo --- R/kernel.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/kernel.R b/R/kernel.R index 74eedd64..0e79b47e 100644 --- a/R/kernel.R +++ b/R/kernel.R @@ -49,7 +49,7 @@ computeForestLeafIndices <- function(model_object, covariates, forest_type=NULL, forest_inds=NULL) { # Extract relevant forest container stopifnot(any(c(inherits(model_object, "bartmodel"), inherits(model_object, "bcfmodel"), inherits(model_object, "ForestSamples")))) - model_type <- ifelse(object_name=="bartmodel", "bart", ifelse(object_name=="bcfmodel", "bcf", "forest_samples")) + model_type <- ifelse(inherits(model_object, "bartmodel"), "bart", ifelse(inherits(model_object, "bcfmodel"), "bcf", "forest_samples")) if (model_type == "bart") { stopifnot(forest_type %in% c("mean", "variance")) if (forest_type=="mean") { From 2337034c9040c5aee6f0501c2d1d849c35e1bf24 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Wed, 5 Feb 2025 01:39:47 -0600 Subject: [PATCH 18/21] Updated docs and package details --- DESCRIPTION | 2 +- LICENSE | 4 ++-- LICENSE.md | 2 +- R/bcf.R | 6 ++++-- man/saveBCFModelToJsonString.Rd | 6 ++++-- man/stochtree-package.Rd | 2 +- 6 files changed, 13 insertions(+), 9 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index 
6b3277dc..83ee1632 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,5 +1,5 @@ Package: stochtree -Title: Stochastic tree ensembles (XBART and BART) for supervised learning and causal inference +Title: Stochastic tree Ensembles (XBART and BART) for Supervised Learning and Causal Inference Version: 0.1.0 Authors@R: c( diff --git a/LICENSE b/LICENSE index 163f07f2..1941a4d0 100644 --- a/LICENSE +++ b/LICENSE @@ -1,2 +1,2 @@ -YEAR: 2024 -COPYRIGHT HOLDER: stochtree authors \ No newline at end of file +YEAR: 2025 +COPYRIGHT HOLDER: stochtree contributors \ No newline at end of file diff --git a/LICENSE.md b/LICENSE.md index 3c81b245..5e7f9a94 100644 --- a/LICENSE.md +++ b/LICENSE.md @@ -1,6 +1,6 @@ # MIT License -Copyright (c) 2024 stochtree authors +Copyright (c) 2023-2025 stochtree authors Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/R/bcf.R b/R/bcf.R index 5df8ca26..22ddb7b0 100644 --- a/R/bcf.R +++ b/R/bcf.R @@ -1932,9 +1932,11 @@ saveBCFModelToJsonFile <- function(object, filename){ #' mu_params <- list(sample_sigma_leaf = TRUE) #' tau_params <- list(sample_sigma_leaf = FALSE) #' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, -#' propensity_train = pi_train, rfx_group_ids_train = rfx_group_ids_train, +#' propensity_train = pi_train, +#' rfx_group_ids_train = rfx_group_ids_train, #' rfx_basis_train = rfx_basis_train, X_test = X_test, -#' Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test, +#' Z_test = Z_test, propensity_test = pi_test, +#' rfx_group_ids_test = rfx_group_ids_test, #' rfx_basis_test = rfx_basis_test, #' num_gfr = 100, num_burnin = 0, num_mcmc = 100, #' mu_forest_params = mu_params, diff --git a/man/saveBCFModelToJsonString.Rd b/man/saveBCFModelToJsonString.Rd index e1d6769c..3c0bdee1 100644 --- a/man/saveBCFModelToJsonString.Rd +++ b/man/saveBCFModelToJsonString.Rd @@ -69,9 +69,11 @@ rfx_term_train <- rfx_term[train_inds] mu_params <- list(sample_sigma_leaf = TRUE) tau_params <- list(sample_sigma_leaf = FALSE) bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, - propensity_train = pi_train, rfx_group_ids_train = rfx_group_ids_train, + propensity_train = pi_train, + rfx_group_ids_train = rfx_group_ids_train, rfx_basis_train = rfx_basis_train, X_test = X_test, - Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test, + Z_test = Z_test, propensity_test = pi_test, + rfx_group_ids_test = rfx_group_ids_test, rfx_basis_test = rfx_basis_test, num_gfr = 100, num_burnin = 0, num_mcmc = 100, mu_forest_params = mu_params, diff --git a/man/stochtree-package.Rd b/man/stochtree-package.Rd index 0377fb91..49553eca 100644 --- a/man/stochtree-package.Rd +++ b/man/stochtree-package.Rd @@ -4,7 +4,7 @@ \name{stochtree-package} \alias{stochtree} \alias{stochtree-package} -\title{stochtree: Stochastic tree ensembles (XBART and BART) for supervised learning and causal inference} +\title{stochtree: Stochastic tree Ensembles (XBART and BART) for Supervised Learning and Causal Inference} \description{ Stochastic tree ensembles (XBART and BART) for supervised learning and causal inference. 
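The class checks standardized in the two patches above replace `class(x) == "..."` comparisons with `inherits()`, which stays correct when an object carries more than one S3 class. A minimal sketch (the toy two-class object below is hypothetical):

    obj <- structure(list(), class = c("bcfmodel", "stochtree_model"))
    class(obj) == "bcfmodel"   # c(TRUE, FALSE) -- a length-2 logical vector
    inherits(obj, "bcfmodel")  # TRUE -- always a single logical
    # The old pattern errors on multi-class objects even when the class is present:
    # stopifnot(class(obj) == "bcfmodel") would fail here, while the form below does not
    stopifnot(inherits(obj, "bcfmodel"))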
} From 1ccdd1bc7bf3262867de8ab4b98d3edf7e471160 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Wed, 5 Feb 2025 02:30:43 -0600 Subject: [PATCH 19/21] Fixed bug and updated DESCRIPTION --- DESCRIPTION | 6 +++++- R/bart.R | 2 +- R/bcf.R | 2 +- cran-bootstrap.R | 17 +++++++++++++++++ man/stochtree-package.Rd | 2 +- 5 files changed, 25 insertions(+), 4 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index 83ee1632..7d3bdc53 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -10,7 +10,11 @@ Authors@R: person("Jingyu", "He", role = "aut"), person("stochtree contributors", role = c("cph")) ) -Description: Stochastic tree ensembles (XBART and BART) for supervised learning and causal inference. +Description: Flexible stochastic tree ensemble software. Robust implementations of + Bayesian Additive Regression Trees (Chipman, George, McCulloch (2010) <doi:10.1214/09-AOAS285>) + for supervised learning and (Bayesian Causal Forests (BCF) Hahn, Murray, Carvalho (2020) <doi:10.1214/19-BA1195>) + for causal inference. Enables model serialization and parallel sampling + and provides a low-level interface for custom stochastic forest samplers. License: MIT + file LICENSE Encoding: UTF-8 Roxygen: list(markdown = TRUE) diff --git a/R/bart.R b/R/bart.R index 2e261d47..ba33ae84 100644 --- a/R/bart.R +++ b/R/bart.R @@ -222,7 +222,7 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train previous_rfx_samples <- previous_bart_model$rfx_samples } else previous_rfx_samples <- NULL previous_model_num_samples <- previous_bart_model$model_params$num_samples - if (previous_model_warmstart_sample_num >= previous_model_num_samples) { + if (previous_model_warmstart_sample_num > previous_model_num_samples) { stop("`previous_model_warmstart_sample_num` exceeds the number of samples in `previous_model_json`") } } else { diff --git a/R/bcf.R b/R/bcf.R index 22ddb7b0..a4e58e49 100644 --- a/R/bcf.R +++ b/R/bcf.R @@ -307,7 +307,7 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id previous_b_0_samples <- NULL } previous_model_num_samples <- previous_bcf_model$model_params$num_samples - if (previous_model_warmstart_sample_num >= previous_model_num_samples) { + if (previous_model_warmstart_sample_num > previous_model_num_samples) { stop("`previous_model_warmstart_sample_num` exceeds the number of samples in `previous_model_json`") } } else { diff --git a/cran-bootstrap.R b/cran-bootstrap.R index f8c22e80..e06eba11 100644 --- a/cran-bootstrap.R +++ b/cran-bootstrap.R @@ -151,6 +151,23 @@ if (!include_vignettes) { writeLines(description_lines, cran_description) } +# Remove testthat deps from DESCRIPTION if no tests +if (!include_tests) { + cran_description <- file.path(cran_dir, "DESCRIPTION") + description_lines <- readLines(cran_description) + if (include_vignettes) { + suggestion_match <- grep("testthat (>= 3.0.0)", description_lines, fixed = TRUE) + suggestion_lines <- suggestion_match + } else { + suggestion_begin <- grep("Suggests:", description_lines) + suggestion_end <- grep("SystemRequirements:", description_lines) - 1 + suggestion_lines <- suggestion_begin:suggestion_end + } + testthat_config_line <- grep("Config/testthat/edition:", description_lines) + description_lines <- description_lines[-c(suggestion_lines, testthat_config_line)] + writeLines(description_lines, cran_description) +} + # Remove vignettes from _pkgdown.yml if no vignettes if ((!include_vignettes) & (pkgdown_build)) { pkgdown_yml <- file.path(cran_dir, "_pkgdown.yml") diff --git a/man/stochtree-package.Rd b/man/stochtree-package.Rd index 49553eca..6d82b32a 
100644 --- a/man/stochtree-package.Rd +++ b/man/stochtree-package.Rd @@ -6,7 +6,7 @@ \alias{stochtree-package} \title{stochtree: Stochastic tree Ensembles (XBART and BART) for Supervised Learning and Causal Inference} \description{ -Stochastic tree ensembles (XBART and BART) for supervised learning and causal inference. +Flexible stochastic tree ensemble software. Robust implementations of Bayesian Additive Regression Trees (Chipman, George, McCulloch (2010) \doi{10.1214/09-AOAS285}) for supervised learning and (Bayesian Causal Forests (BCF) Hahn, Murray, Carvalho (2020) \doi{10.1214/19-BA1195}) for causal inference. Enables model serialization and parallel sampling and provides a low-level interface for custom stochastic forest samplers. } \seealso{ Useful links: From f2bfe219b4021bce91927795c2c40c1ed08c5efd Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Wed, 5 Feb 2025 15:49:30 -0600 Subject: [PATCH 20/21] Fixed minor BCF bug --- R/bcf.R | 5 +- test/R/testthat/test-bcf.R | 168 +++++++++++++++++++++++++-- test/R/testthat/test-serialization.R | 6 +- 3 files changed, 163 insertions(+), 16 deletions(-) diff --git a/R/bcf.R b/R/bcf.R index a4e58e49..9118333c 100644 --- a/R/bcf.R +++ b/R/bcf.R @@ -286,8 +286,8 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id previous_forest_samples_variance <- previous_bcf_model$forests_variance } else previous_forest_samples_variance <- NULL if (previous_bcf_model$model_params$sample_sigma_global) { - previous_global_var_samples <- previous_bcf_model$sigma2_samples*( - previous_var_scale / (previous_y_scale*previous_y_scale) + previous_global_var_samples <- previous_bcf_model$sigma2_samples / ( + previous_y_scale*previous_y_scale ) } else previous_global_var_samples <- NULL if (previous_bcf_model$model_params$sample_sigma_leaf_mu) { @@ -313,7 +313,6 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id } else { previous_y_bar <- NULL previous_y_scale <- NULL - previous_var_scale <- NULL previous_global_var_samples <- NULL previous_leaf_var_mu_samples <- NULL previous_leaf_var_tau_samples <- NULL diff --git a/test/R/testthat/test-bcf.R b/test/R/testthat/test-bcf.R index b4eb329d..0a34c37c 100644 --- a/test/R/testthat/test-bcf.R +++ b/test/R/testthat/test-bcf.R @@ -81,7 +81,7 @@ test_that("MCMC BCF", { ) }) -test_that("GFR BART", { +test_that("GFR BCF", { skip_on_cran() # Generate simulated data @@ -90,21 +90,21 @@ test_that("GFR BART", { X <- matrix(runif(n*p), ncol = p) mu_X <- ( ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + - ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) pi_X <- ( ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + - ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + + ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) ) tau_X <- ( ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + - ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + - ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + - ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) + ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + + ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + + ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) ) Z <- rbinom(n, 1, pi_X) noise_sd <- 1 @@ -181,3 +181,151 @@ 
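The bcf.R change above corrects how stored global variance draws are put back on the sampler's internal scale: sigma2 samples are saved on the original outcome scale, and the sampler standardizes the outcome using previous_y_bar and previous_y_scale, so the conversion divides by the squared scale alone. A standalone sketch with made-up numbers, assuming the usual (y - y_bar) / y_scale standardization implied by those variables:

    y <- rnorm(100, mean = 10, sd = 2)
    y_bar <- mean(y)
    y_scale <- sd(y)
    sigma2_original_scale <- 4.0  # hypothetical draw stored on the scale of y
    # Var((y - y_bar) / y_scale) = Var(y) / y_scale^2, hence:
    sigma2_standardized <- sigma2_original_scale / (y_scale * y_scale)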
test_that("GFR BART", { num_mcmc = 10, general_params = general_param_list) ) }) + +test_that("Warmstart BCF", { + skip_on_cran() + + # Generate simulated data + n <- 100 + p <- 5 + X <- matrix(runif(n*p), ncol = p) + mu_X <- ( + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) + ) + pi_X <- ( + ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + + ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) + ) + tau_X <- ( + ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + + ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + + ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + + ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) + ) + Z <- rbinom(n, 1, pi_X) + noise_sd <- 1 + y <- mu_X + tau_X*Z + rnorm(n, 0, noise_sd) + test_set_pct <- 0.2 + n_test <- round(test_set_pct*n) + n_train <- n - n_test + test_inds <- sort(sample(1:n, n_test, replace = FALSE)) + train_inds <- (1:n)[!((1:n) %in% test_inds)] + X_test <- X[test_inds,] + X_train <- X[train_inds,] + Z_test <- Z[test_inds] + Z_train <- Z[train_inds] + pi_test <- pi[test_inds] + pi_train <- pi[train_inds] + mu_test <- mu_X[test_inds] + mu_train <- mu_X[train_inds] + tau_test <- tau_X[test_inds] + tau_train <- tau_X[train_inds] + y_test <- y[test_inds] + y_train <- y[train_inds] + + # Run a BCF model with only GFR + general_param_list <- list(num_chains = 1, keep_every = 1) + bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train, + propensity_train = pi_train, X_test = X_test, Z_test = Z_test, + propensity_test = pi_test, num_gfr = 10, num_burnin = 0, + num_mcmc = 0, general_params = general_param_list) + + # Save to JSON string + bcf_model_json_string <- saveBCFModelToJsonString(bcf_model) + + # Run a new BCF chain from the existing (X)BCF model + general_param_list <- list(num_chains = 3, keep_every = 5) + expect_no_error( + bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train, + propensity_train = pi_train, X_test = X_test, Z_test = Z_test, + propensity_test = pi_test, num_gfr = 0, num_burnin = 10, + num_mcmc = 10, previous_model_json = bcf_model_json_string, + previous_model_warmstart_sample_num = 1, + general_params = general_param_list) + ) + + # Generate simulated data with random effects + n <- 100 + p <- 5 + X <- matrix(runif(n*p), ncol = p) + mu_X <- ( + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) + ) + pi_X <- ( + ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + + ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) + ) + tau_X <- ( + ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + + ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + + ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + + ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) + ) + Z <- rbinom(n, 1, pi_X) + rfx_group_ids <- sample(1:2, size = n, replace = T) + rfx_basis <- rep(1, n) + rfx_coefs <- c(-5, 5) + rfx_term <- rfx_coefs[rfx_group_ids] * rfx_basis + noise_sd <- 1 + y <- mu_X + tau_X*Z + rfx_term + rnorm(n, 0, noise_sd) + test_set_pct <- 0.2 + n_test <- round(test_set_pct*n) + n_train <- n - n_test + test_inds <- sort(sample(1:n, n_test, replace = FALSE)) + train_inds <- (1:n)[!((1:n) %in% test_inds)] + X_test <- X[test_inds,] + X_train <- X[train_inds,] + Z_test <- Z[test_inds] + Z_train 
<- Z[train_inds] + pi_test <- pi[test_inds] + pi_train <- pi[train_inds] + mu_test <- mu_X[test_inds] + mu_train <- mu_X[train_inds] + tau_test <- tau_X[test_inds] + tau_train <- tau_X[train_inds] + rfx_group_ids_test <- rfx_group_ids[test_inds] + rfx_group_ids_train <- rfx_group_ids[train_inds] + rfx_basis_test <- rfx_basis[test_inds] + rfx_basis_train <- rfx_basis[train_inds] + y_test <- y[test_inds] + y_train <- y[train_inds] + + # Run a BCF model with only GFR + general_param_list <- list(num_chains = 1, keep_every = 1) + bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train, + propensity_train = pi_train, X_test = X_test, Z_test = Z_test, + rfx_group_ids_train = rfx_group_ids_train, + rfx_group_ids_test = rfx_group_ids_test, + rfx_basis_train = rfx_basis_train, + rfx_basis_test = rfx_basis_test, + propensity_test = pi_test, num_gfr = 10, num_burnin = 0, + num_mcmc = 0, general_params = general_param_list) + + # Save to JSON string + bcf_model_json_string <- saveBCFModelToJsonString(bcf_model) + + # Run a new BCF chain from the existing (X)BCF model + general_param_list <- list(num_chains = 3, keep_every = 5) + expect_no_error( + bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train, + propensity_train = pi_train, X_test = X_test, Z_test = Z_test, + rfx_group_ids_train = rfx_group_ids_train, + rfx_group_ids_test = rfx_group_ids_test, + rfx_basis_train = rfx_basis_train, + rfx_basis_test = rfx_basis_test, + propensity_test = pi_test, num_gfr = 0, num_burnin = 10, + num_mcmc = 10, previous_model_json = bcf_model_json_string, + previous_model_warmstart_sample_num = 1, + general_params = general_param_list) + ) +}) diff --git a/test/R/testthat/test-serialization.R b/test/R/testthat/test-serialization.R index e640d3f8..0d78957f 100644 --- a/test/R/testthat/test-serialization.R +++ b/test/R/testthat/test-serialization.R @@ -7,9 +7,9 @@ test_that("BART Serialization", { X <- matrix(runif(n*p), ncol = p) f_XW <- ( ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + - ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) noise_sd <- 1 y <- f_XW + rnorm(n, 0, noise_sd) From 1e73f92a19e20610fa86692d9a71b81152b6e628 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Wed, 5 Feb 2025 16:02:23 -0600 Subject: [PATCH 21/21] Updated Changelog --- NEWS.md | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/NEWS.md b/NEWS.md index aa0f54d0..397c25ab 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,3 +1,16 @@ # stochtree 0.1.0 -* Initial CRAN submission. +* Initial release on CRAN. 
+* Support for sampling stochastic tree ensembles using two algorithms: MCMC and Grow-From-Root (GFR) +* High-level model types supported: + * Supervised learning with constant leaves or user-specified leaf regression models + * Causal effect estimation with binary, continuous, or multivariate treatments +* Additional high-level modeling features: + * Forest-based variance function estimation (heteroskedasticity) + * Additive (univariate or multivariate) group random effects + * Multi-chain sampling and support for parallelism + * "Warm-start" initialization of MCMC forest samplers via the Grow-From-Root (GFR) algorithm + * Automated preprocessing / handling of categorical variables +* Low-level interface: + * Ability to combine a forest sampler with other (additive) model terms, without using C++ + * Combine and sample an arbitrary number of forests or random effects terms
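Taken together with the warm-start bound corrected in PATCH 19 and the test added in PATCH 20, the end-to-end workflow is: sample with GFR only, serialize to a JSON string, then launch MCMC chains from a kept sample. A compact sketch using only the functions exercised in the tests above (the data-generating code is simplified; per the corrected bound, any 1-based sample index up to the number of retained samples is valid):

    library(stochtree)
    n <- 100; p <- 5
    X <- matrix(runif(n * p), ncol = p)
    pi_X <- 0.2 + 0.6 * X[, 1]
    Z <- rbinom(n, 1, pi_X)
    y <- 2 * X[, 2] + Z * (0.5 + X[, 2]) + rnorm(n)
    # Phase 1: Grow-From-Root only, single chain
    xbcf_model <- bcf(X_train = X, Z_train = Z, y_train = y,
                      propensity_train = pi_X, num_gfr = 10,
                      num_burnin = 0, num_mcmc = 0,
                      general_params = list(num_chains = 1, keep_every = 1))
    # Phase 2: serialize, then warm-start several MCMC chains from a GFR draw
    xbcf_json <- saveBCFModelToJsonString(xbcf_model)
    bcf_model <- bcf(X_train = X, Z_train = Z, y_train = y,
                     propensity_train = pi_X, num_gfr = 0,
                     num_burnin = 10, num_mcmc = 10,
                     previous_model_json = xbcf_json,
                     previous_model_warmstart_sample_num = 1,
                     general_params = list(num_chains = 3, keep_every = 5))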