Adding Probit link to BART and BCF #164

Merged · 14 commits · May 6, 2025
38 changes: 38 additions & 0 deletions CHANGELOG.md
@@ -0,0 +1,38 @@
# Changelog

# stochtree 0.1.2

## New Features

* Support for binary outcomes in BART and BCF with a probit link ([#164](https://github.com/StochasticTree/stochtree/pull/164))

## Bug Fixes

* Fixed indexing bug in cleanup of grow-from-root (GFR) samples in BART and BCF models
* Avoid using covariate preprocessor in `computeForestLeafIndices` R function when a `ForestSamples` object is provided (instead of a `bartmodel` or `bcfmodel` object)

# stochtree 0.1.1

## Bug Fixes

* Fixed initialization bug in several R package code examples for random effects models

# stochtree 0.1.0

Initial "alpha" release

## New Features

* Support for sampling stochastic tree ensembles using two algorithms: MCMC and Grow-From-Root (GFR)
* High-level model types supported:
* Supervised learning with constant leaves or user-specified leaf regression models
* Causal effect estimation with binary or continuous treatments
* Additional high-level modeling features:
* Forest-based variance function estimation (heteroskedasticity)
* Additive (univariate or multivariate) group random effects
* Multi-chain sampling and support for parallelism
* "Warm-start" initialization of MCMC forest samplers via the Grow-From-Root (GFR) algorithm
* Automated preprocessing / handling of categorical variables
* Low-level interface:
* Ability to combine a forest sampler with other (additive) model terms, without using C++
* Combine and sample an arbitrary number of forests or random effects terms
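As a quick illustration of the probit outcome support announced above — a minimal sketch, assuming the interface shown in the `R/bart.R` diff below (the simulated data and variable names here are illustrative only):

```r
library(stochtree)

# Simulate a binary outcome whose success probability follows a probit model
n <- 500
p <- 5
X <- matrix(runif(n * p), ncol = p)
y <- rbinom(n, 1, pnorm(2 * X[, 1] - 1))

# Fit BART with a probit link by setting the new mean-forest parameter
bart_fit <- bart(
    X_train = X, y_train = y,
    mean_forest_params = list(probit_outcome_model = TRUE)
)
```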
8 changes: 7 additions & 1 deletion NEWS.md
@@ -1,7 +1,13 @@
# stochtree 0.1.2

## New Features

* Support for binary outcomes in BART and BCF with a probit link ([#164](https://github.com/StochasticTree/stochtree/pull/164))

## Bug Fixes

* Fixed indexing bug in cleanup of grow-from-root (GFR) samples in BART and BCF models
* Avoid using covariate preprocessor in `computeForestLeafIndices` function when a `ForestSamples` object is provided (rather than a `bartmodel` or `bcfmodel` object)

# stochtree 0.1.1

167 changes: 136 additions & 31 deletions R/bart.R
@@ -58,6 +58,7 @@
#' - `sigma2_leaf_scale` Scale parameter in the `IG(sigma2_leaf_shape, sigma2_leaf_scale)` leaf node parameter variance model. Calibrated internally as `0.5/num_trees` if not set here.
#' - `keep_vars` Vector of variable names or column indices denoting variables that should be included in the forest. Default: `NULL`.
#' - `drop_vars` Vector of variable names or column indices denoting variables that should be excluded from the forest. Default: `NULL`. If both `drop_vars` and `keep_vars` are set, `drop_vars` will be ignored.
#' - `probit_outcome_model` Whether or not the outcome should be modeled as explicitly binary via a probit link. If `TRUE`, `y` must only contain the values `0` and `1`. Default: `FALSE`.
#'
#' @param variance_forest_params (Optional) A list of variance forest model parameters, each of which has a default value processed internally, so this argument list is optional.
#'
@@ -125,7 +126,8 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
min_samples_leaf = 5, max_depth = 10,
sample_sigma2_leaf = TRUE, sigma2_leaf_init = NULL,
sigma2_leaf_shape = 3, sigma2_leaf_scale = NULL,
keep_vars = NULL, drop_vars = NULL,
probit_outcome_model = FALSE
)
mean_forest_params_updated <- preprocessParams(
mean_forest_params_default, mean_forest_params
@@ -173,6 +175,7 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
b_leaf <- mean_forest_params_updated$sigma2_leaf_scale
keep_vars_mean <- mean_forest_params_updated$keep_vars
drop_vars_mean <- mean_forest_params_updated$drop_vars
probit_outcome_model <- mean_forest_params_updated$probit_outcome_model

# 3. Variance forest parameters
num_trees_variance <- variance_forest_params_updated$num_trees
@@ -462,50 +465,118 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train

# Determine whether a test set is provided
has_test = !is.null(X_test)

# Preliminary runtime checks for probit link
if (!include_mean_forest) {
probit_outcome_model <- FALSE
}
if (probit_outcome_model) {
if (!(length(unique(y_train)) == 2)) {
stop("You specified a probit outcome model, but supplied an outcome with more than 2 unique values")
}
unique_outcomes <- sort(unique(y_train))
if (!(all(unique_outcomes == c(0,1)))) {
stop("You specified a probit outcome model, but supplied an outcome with 2 unique values other than 0 and 1")
}
if (include_variance_forest) {
stop("We do not support heteroskedasticity with a probit link")
}
if (sample_sigma_global) {
warning("Global error variance will not be sampled with a probit link as it is fixed at 1")
sample_sigma_global <- FALSE
}
}
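# Under the probit link, the model is y = 1(z > 0) with latent z = f(X) + e,
# e ~ N(0, 1); fixing the error variance at 1 is what makes the
# heteroskedasticity and global-variance options above inapplicable.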

# Handle standardization, prior calibration, and initialization of forest
# differently for binary and continuous outcomes
if (probit_outcome_model) {
    # Compute a probit-scale offset and fix scale to 1
    y_bar_train <- qnorm(mean(y_train))
    y_std_train <- 1

    # Set a pseudo outcome by subtracting mean(y_train) from y_train
    resid_train <- y_train - mean(y_train)

    # Set initial values of root nodes to 0.0 (in probit scale)
    init_val_mean <- 0.0

    # Calibrate priors for sigma^2 and tau
    # Set sigma2_init to 1, ignoring any default provided
    sigma2_init <- 1.0
    # Skip variance_forest_init, since variance forests are not supported with a probit link
    b_leaf <- 1/(num_trees_mean)
    if (has_basis) {
        if (ncol(leaf_basis_train) > 1) {
            if (is.null(sigma_leaf_init)) sigma_leaf_init <- diag(2/(num_trees_mean), ncol(leaf_basis_train))
            if (!is.matrix(sigma_leaf_init)) {
                current_leaf_scale <- as.matrix(diag(sigma_leaf_init, ncol(leaf_basis_train)))
            } else {
                current_leaf_scale <- sigma_leaf_init
            }
        } else {
            if (is.null(sigma_leaf_init)) sigma_leaf_init <- as.matrix(2/(num_trees_mean))
            if (!is.matrix(sigma_leaf_init)) {
                current_leaf_scale <- as.matrix(diag(sigma_leaf_init, 1))
            } else {
                current_leaf_scale <- sigma_leaf_init
            }
        }
    } else {
        if (is.null(sigma_leaf_init)) sigma_leaf_init <- as.matrix(2/(num_trees_mean))
        if (!is.matrix(sigma_leaf_init)) {
            current_leaf_scale <- as.matrix(diag(sigma_leaf_init, 1))
        } else {
            current_leaf_scale <- sigma_leaf_init
        }
    }
    current_sigma2 <- sigma2_init
} else {
    # Only standardize if user requested
    if (standardize) {
        y_bar_train <- mean(y_train)
        y_std_train <- sd(y_train)
    } else {
        y_bar_train <- 0
        y_std_train <- 1
    }

    # Compute standardized outcome
    resid_train <- (y_train-y_bar_train)/y_std_train

    # Compute initial value of root nodes in mean forest
    init_val_mean <- mean(resid_train)

    # Calibrate priors for sigma^2 and tau
    if (is.null(sigma2_init)) sigma2_init <- 1.0*var(resid_train)
    if (is.null(variance_forest_init)) variance_forest_init <- 1.0*var(resid_train)
    if (is.null(b_leaf)) b_leaf <- var(resid_train)/(2*num_trees_mean)
    if (has_basis) {
        if (ncol(leaf_basis_train) > 1) {
            if (is.null(sigma_leaf_init)) sigma_leaf_init <- diag(2*var(resid_train)/(num_trees_mean), ncol(leaf_basis_train))
            if (!is.matrix(sigma_leaf_init)) {
                current_leaf_scale <- as.matrix(diag(sigma_leaf_init, ncol(leaf_basis_train)))
            } else {
                current_leaf_scale <- sigma_leaf_init
            }
        } else {
            if (is.null(sigma_leaf_init)) sigma_leaf_init <- as.matrix(2*var(resid_train)/(num_trees_mean))
            if (!is.matrix(sigma_leaf_init)) {
                current_leaf_scale <- as.matrix(diag(sigma_leaf_init, 1))
            } else {
                current_leaf_scale <- sigma_leaf_init
            }
        }
    } else {
        if (is.null(sigma_leaf_init)) sigma_leaf_init <- as.matrix(2*var(resid_train)/(num_trees_mean))
        if (!is.matrix(sigma_leaf_init)) {
            current_leaf_scale <- as.matrix(diag(sigma_leaf_init, 1))
        } else {
            current_leaf_scale <- sigma_leaf_init
        }
    }
    current_sigma2 <- sigma2_init
}
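# With a probit link the latent error variance is fixed at 1, so the probit
# branch above calibrates the leaf prior against that unit scale (e.g.
# b_leaf = 1/num_trees_mean) rather than against var(resid_train).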


# Determine leaf model type
if (!has_basis) leaf_model_mean_forest <- 0
else if (ncol(leaf_basis_train) == 1) leaf_model_mean_forest <- 1
@@ -634,7 +705,6 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
# Initialize the leaves of each tree in the variance forest
if (include_variance_forest) {
active_forest_variance$prepare_for_sampler(forest_dataset_train, outcome_train, forest_model_variance, leaf_model_variance_forest, variance_forest_init)
}

# Run GFR (warm start) if specified
@@ -652,6 +722,21 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
}

if (include_mean_forest) {
if (probit_outcome_model) {
# Sample latent probit variable, z | -
forest_pred <- active_forest_mean$predict(forest_dataset_train)
mu0 <- forest_pred[y_train == 0]
mu1 <- forest_pred[y_train == 1]
u0 <- runif(sum(y_train == 0), 0, pnorm(0 - mu0))
u1 <- runif(sum(y_train == 1), pnorm(0 - mu1), 1)
resid_train[y_train==0] <- mu0 + qnorm(u0)
resid_train[y_train==1] <- mu1 + qnorm(u1)

# Update outcome
outcome_train$update_data(resid_train - forest_pred)
}
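# The draws above are truncated-normal samples via the inverse-CDF method:
# for y == 0, z | - follows N(mu, 1) truncated to (-Inf, 0], obtained as
# mu + qnorm(u) with u ~ Uniform(0, pnorm(-mu)); the y == 1 case truncates
# to (0, Inf) analogously (the Albert-Chib 1993 data augmentation).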

# Sample mean forest
forest_model_mean$sample_one_iteration(
forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_mean,
active_forest = active_forest_mean, rng = rng, forest_model_config = forest_model_config_mean,
@@ -791,6 +876,20 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
}

if (include_mean_forest) {
if (probit_outcome_model) {
# Sample latent probit variable, z | -
forest_pred <- active_forest_mean$predict(forest_dataset_train)
mu0 <- forest_pred[y_train == 0]
mu1 <- forest_pred[y_train == 1]
u0 <- runif(sum(y_train == 0), 0, pnorm(0 - mu0))
u1 <- runif(sum(y_train == 1), pnorm(0 - mu1), 1)
resid_train[y_train==0] <- mu0 + qnorm(u0)
resid_train[y_train==1] <- mu1 + qnorm(u1)

# Update outcome
outcome_train$update_data(resid_train - forest_pred)
}
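# (Same truncated-normal latent update as in the GFR warm-start loop above,
# applied within the MCMC sampling loop.)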

forest_model_mean$sample_one_iteration(
forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_mean,
active_forest = active_forest_mean, rng = rng, forest_model_config = forest_model_config_mean,
@@ -915,7 +1014,8 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
"sample_sigma_global" = sample_sigma_global,
"sample_sigma_leaf" = sample_sigma_leaf,
"include_mean_forest" = include_mean_forest,
"include_variance_forest" = include_variance_forest
"include_variance_forest" = include_variance_forest,
"probit_outcome_model" = probit_outcome_model
)
result <- list(
"model_params" = model_params,
@@ -1257,6 +1357,7 @@ saveBARTModelToJson <- function(object){
jsonobj$add_scalar("num_chains", object$model_params$num_chains)
jsonobj$add_scalar("keep_every", object$model_params$keep_every)
jsonobj$add_boolean("requires_basis", object$model_params$requires_basis)
jsonobj$add_boolean("probit_outcome_model", object$model_params$probit_outcome_model)
if (object$model_params$sample_sigma_global) {
jsonobj$add_vector("sigma2_global_samples", object$sigma2_global_samples, "parameters")
}
@@ -1448,6 +1549,8 @@ createBARTModelFromJson <- function(json_object){
model_params[["num_chains"]] <- json_object$get_scalar("num_chains")
model_params[["keep_every"]] <- json_object$get_scalar("keep_every")
model_params[["requires_basis"]] <- json_object$get_boolean("requires_basis")
model_params[["probit_outcome_model"]] <- json_object$get_boolean("probit_outcome_model")

output[["model_params"]] <- model_params

# Unpack sampled parameters
@@ -1650,6 +1753,7 @@ createBARTModelFromCombinedJson <- function(json_object_list){
model_params[["num_covariates"]] <- json_object_default$get_scalar("num_covariates")
model_params[["num_basis"]] <- json_object_default$get_scalar("num_basis")
model_params[["requires_basis"]] <- json_object_default$get_boolean("requires_basis")
model_params[["probit_outcome_model"]] <- json_object_default$get_boolean("probit_outcome_model")
model_params[["num_chains"]] <- json_object_default$get_scalar("num_chains")
model_params[["keep_every"]] <- json_object_default$get_scalar("keep_every")

@@ -1805,6 +1909,7 @@ createBARTModelFromCombinedJsonString <- function(json_string_list){
model_params[["num_chains"]] <- json_object_default$get_scalar("num_chains")
model_params[["keep_every"]] <- json_object_default$get_scalar("keep_every")
model_params[["requires_basis"]] <- json_object_default$get_boolean("requires_basis")
model_params[["probit_outcome_model"]] <- json_object_default$get_boolean("probit_outcome_model")

# Combine values that are sample-specific
for (i in 1:length(json_object_list)) {
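A hedged sketch of the serialization round trip, using only the functions shown in this diff (`bart_fit` refers to the illustrative probit fit sketched after the CHANGELOG above):

```r
# Serialize a fitted model; the new "probit_outcome_model" boolean is stored
json_obj <- saveBARTModelToJson(bart_fit)

# Rebuild the model from JSON and confirm the flag survives the round trip
bart_fit_restored <- createBARTModelFromJson(json_obj)
bart_fit_restored$model_params$probit_outcome_model  # TRUE for a probit fit
```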