
Commit 6dae5cb

Merge pull request #164 from StochasticTree/probit-bart-bcf
Adding Probit link to BART and BCF
2 parents: b3376bf + 58e6d22

17 files changed: +1529 -188 lines

CHANGELOG.md

Lines changed: 38 additions & 0 deletions
```diff
@@ -0,0 +1,38 @@
+# Changelog
+
+# stochtree 0.1.2
+
+## New Features
+
+* Support for binary outcomes in BART and BCF with a probit link ([#164](https://github.com/StochasticTree/stochtree/pull/164))
+
+## Bug Fixes
+
+* Fixed indexing bug in cleanup of grow-from-root (GFR) samples in BART and BCF models
+* Avoid using covariate preprocessor in `computeForestLeafIndices` R function when a `ForestSamples` object is provided (instead of a `bartmodel` or `bcfmodel` object)
+
+# stochtree 0.1.1
+
+## Bug Fixes
+
+* Fixed initialization bug in several R package code examples for random effects models
+
+# stochtree 0.1.0
+
+Initial "alpha" release
+
+## New Features
+
+* Support for sampling stochastic tree ensembles using two algorithms: MCMC and Grow-From-Root (GFR)
+* High-level model types supported:
+    * Supervised learning with constant leaves or user-specified leaf regression models
+    * Causal effect estimation with binary or continuous treatments
+* Additional high-level modeling features:
+    * Forest-based variance function estimation (heteroskedasticity)
+    * Additive (univariate or multivariate) group random effects
+    * Multi-chain sampling and support for parallelism
+    * "Warm-start" initialization of MCMC forest samplers via the Grow-From-Root (GFR) algorithm
+    * Automated preprocessing / handling of categorical variables
+* Low-level interface:
+    * Ability to combine a forest sampler with other (additive) model terms, without using C++
+    * Combine and sample an arbitrary number of forests or random effects terms
```
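For orientation, a minimal usage sketch of the feature this PR adds (illustrative only, not part of the diff; data and variable names are invented): the new `probit_outcome_model` flag is passed through the `mean_forest_params` list of `bart()`, as shown in the `R/bart.R` changes below, and the outcome must be coded 0/1.

```r
# Illustrative sketch: simulate a binary outcome and fit BART with a probit link
library(stochtree)

n <- 500
p <- 5
X <- matrix(runif(n * p), ncol = p)
y <- rbinom(n, 1, pnorm(2 * X[, 1] - 1))  # outcome must only contain 0 and 1

bart_fit <- bart(
    X_train = X, y_train = y,
    mean_forest_params = list(probit_outcome_model = TRUE)
)
```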

NEWS.md

Lines changed: 7 additions & 1 deletion
```diff
@@ -1,7 +1,13 @@
 # stochtree 0.1.2
 
+## New Features
+
+* Support for binary outcomes in BART and BCF with a probit link ([#164](https://github.com/StochasticTree/stochtree/pull/164))
+
+## Bug Fixes
+
 * Fixed indexing bug in cleanup of grow-from-root (GFR) samples in BART and BCF models
-* Avoid using covariate preprocessor in `computeForestLeafIndices` function when a `ForestSamples` object is provided
+* Avoid using covariate preprocessor in `computeForestLeafIndices` function when a `ForestSamples` object is provided (rather than a `bartmodel` or `bcfmodel` object)
 
 # stochtree 0.1.1
 
```

R/bart.R

Lines changed: 136 additions & 31 deletions
```diff
@@ -58,6 +58,7 @@
 #' - `sigma2_leaf_scale` Scale parameter in the `IG(sigma2_leaf_shape, sigma2_leaf_scale)` leaf node parameter variance model. Calibrated internally as `0.5/num_trees` if not set here.
 #' - `keep_vars` Vector of variable names or column indices denoting variables that should be included in the forest. Default: `NULL`.
 #' - `drop_vars` Vector of variable names or column indices denoting variables that should be excluded from the forest. Default: `NULL`. If both `drop_vars` and `keep_vars` are set, `drop_vars` will be ignored.
+#' - `probit_outcome_model` Whether or not the outcome should be modeled as explicitly binary via a probit link. If `TRUE`, `y` must only contain the values `0` and `1`. Default: `FALSE`.
 #'
 #' @param variance_forest_params (Optional) A list of variance forest model parameters, each of which has a default value processed internally, so this argument list is optional.
 #'
```
```diff
@@ -125,7 +126,8 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
         min_samples_leaf = 5, max_depth = 10,
         sample_sigma2_leaf = TRUE, sigma2_leaf_init = NULL,
         sigma2_leaf_shape = 3, sigma2_leaf_scale = NULL,
-        keep_vars = NULL, drop_vars = NULL
+        keep_vars = NULL, drop_vars = NULL,
+        probit_outcome_model = FALSE
     )
     mean_forest_params_updated <- preprocessParams(
         mean_forest_params_default, mean_forest_params
```
```diff
@@ -173,6 +175,7 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
     b_leaf <- mean_forest_params_updated$sigma2_leaf_scale
     keep_vars_mean <- mean_forest_params_updated$keep_vars
     drop_vars_mean <- mean_forest_params_updated$drop_vars
+    probit_outcome_model <- mean_forest_params_updated$probit_outcome_model
 
     # 3. Variance forest parameters
     num_trees_variance <- variance_forest_params_updated$num_trees
```
```diff
@@ -462,50 +465,118 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
 
     # Determine whether a test set is provided
     has_test = !is.null(X_test)
+
+    # Preliminary runtime checks for probit link
+    if (!include_mean_forest) {
+        probit_outcome_model <- FALSE
+    }
+    if (probit_outcome_model) {
+        if (!(length(unique(y_train)) == 2)) {
+            stop("You specified a probit outcome model, but supplied an outcome with more than 2 unique values")
+        }
+        unique_outcomes <- sort(unique(y_train))
+        if (!(all(unique_outcomes == c(0,1)))) {
+            stop("You specified a probit outcome model, but supplied an outcome with 2 unique values other than 0 and 1")
+        }
+        if (include_variance_forest) {
+            stop("We do not support heteroskedasticity with a probit link")
+        }
+        if (sample_sigma_global) {
+            warning("Global error variance will not be sampled with a probit link as it is fixed at 1")
+            sample_sigma_global <- F
+        }
+    }
 
-    # Standardize outcome separately for test and train
-    if (standardize) {
-        y_bar_train <- mean(y_train)
-        y_std_train <- sd(y_train)
-    } else {
-        y_bar_train <- 0
+    # Handle standardization, prior calibration, and initialization of forest
+    # differently for binary and continuous outcomes
+    if (probit_outcome_model) {
+        # Compute a probit-scale offset and fix scale to 1
+        y_bar_train <- qnorm(mean(y_train))
         y_std_train <- 1
-    }
-    resid_train <- (y_train-y_bar_train)/y_std_train
-
-    # Compute initial value of root nodes in mean forest
-    init_val_mean <- mean(resid_train)
 
-    # Calibrate priors for sigma^2 and tau
-    if (is.null(sigma2_init)) sigma2_init <- 1.0*var(resid_train)
-    if (is.null(variance_forest_init)) variance_forest_init <- 1.0*var(resid_train)
-    if (is.null(b_leaf)) b_leaf <- var(resid_train)/(2*num_trees_mean)
-    if (has_basis) {
-        if (ncol(leaf_basis_train) > 1) {
-            if (is.null(sigma_leaf_init)) sigma_leaf_init <- diag(var(resid_train)/(num_trees_mean), ncol(leaf_basis_train))
-            if (!is.matrix(sigma_leaf_init)) {
-                current_leaf_scale <- as.matrix(diag(sigma_leaf_init, ncol(leaf_basis_train)))
+        # Set a pseudo outcome by subtracting mean(y_train) from y_train
+        resid_train <- y_train - mean(y_train)
+
+        # Set initial values of root nodes to 0.0 (in probit scale)
+        init_val_mean <- 0.0
+
+        # Calibrate priors for sigma^2 and tau
+        # Set sigma2_init to 1, ignoring default provided
+        sigma2_init <- 1.0
+        # Skip variance_forest_init, since variance forests are not supported with probit link
+        b_leaf <- 1/(num_trees_mean)
+        if (has_basis) {
+            if (ncol(leaf_basis_train) > 1) {
+                if (is.null(sigma_leaf_init)) sigma_leaf_init <- diag(2/(num_trees_mean), ncol(leaf_basis_train))
+                if (!is.matrix(sigma_leaf_init)) {
+                    current_leaf_scale <- as.matrix(diag(sigma_leaf_init, ncol(leaf_basis_train)))
+                } else {
+                    current_leaf_scale <- sigma_leaf_init
+                }
             } else {
-                current_leaf_scale <- sigma_leaf_init
+                if (is.null(sigma_leaf_init)) sigma_leaf_init <- as.matrix(2/(num_trees_mean))
+                if (!is.matrix(sigma_leaf_init)) {
+                    current_leaf_scale <- as.matrix(diag(sigma_leaf_init, 1))
+                } else {
+                    current_leaf_scale <- sigma_leaf_init
+                }
             }
         } else {
-            if (is.null(sigma_leaf_init)) sigma_leaf_init <- as.matrix(var(resid_train)/(num_trees_mean))
+            if (is.null(sigma_leaf_init)) sigma_leaf_init <- as.matrix(2/(num_trees_mean))
             if (!is.matrix(sigma_leaf_init)) {
                 current_leaf_scale <- as.matrix(diag(sigma_leaf_init, 1))
             } else {
                 current_leaf_scale <- sigma_leaf_init
             }
         }
+        current_sigma2 <- sigma2_init
     } else {
-        if (is.null(sigma_leaf_init)) sigma_leaf_init <- as.matrix(var(resid_train)/(num_trees_mean))
-        if (!is.matrix(sigma_leaf_init)) {
-            current_leaf_scale <- as.matrix(diag(sigma_leaf_init, 1))
+        # Only standardize if user requested
+        if (standardize) {
+            y_bar_train <- mean(y_train)
+            y_std_train <- sd(y_train)
+        } else {
+            y_bar_train <- 0
+            y_std_train <- 1
+        }
+
+        # Compute standardized outcome
+        resid_train <- (y_train-y_bar_train)/y_std_train
+
+        # Compute initial value of root nodes in mean forest
+        init_val_mean <- mean(resid_train)
+
+        # Calibrate priors for sigma^2 and tau
+        if (is.null(sigma2_init)) sigma2_init <- 1.0*var(resid_train)
+        if (is.null(variance_forest_init)) variance_forest_init <- 1.0*var(resid_train)
+        if (is.null(b_leaf)) b_leaf <- var(resid_train)/(2*num_trees_mean)
+        if (has_basis) {
+            if (ncol(leaf_basis_train) > 1) {
+                if (is.null(sigma_leaf_init)) sigma_leaf_init <- diag(2*var(resid_train)/(num_trees_mean), ncol(leaf_basis_train))
+                if (!is.matrix(sigma_leaf_init)) {
+                    current_leaf_scale <- as.matrix(diag(sigma_leaf_init, ncol(leaf_basis_train)))
+                } else {
+                    current_leaf_scale <- sigma_leaf_init
+                }
+            } else {
+                if (is.null(sigma_leaf_init)) sigma_leaf_init <- as.matrix(2*var(resid_train)/(num_trees_mean))
+                if (!is.matrix(sigma_leaf_init)) {
+                    current_leaf_scale <- as.matrix(diag(sigma_leaf_init, 1))
+                } else {
+                    current_leaf_scale <- sigma_leaf_init
+                }
+            }
         } else {
-            current_leaf_scale <- sigma_leaf_init
+            if (is.null(sigma_leaf_init)) sigma_leaf_init <- as.matrix(2*var(resid_train)/(num_trees_mean))
+            if (!is.matrix(sigma_leaf_init)) {
+                current_leaf_scale <- as.matrix(diag(sigma_leaf_init, 1))
+            } else {
+                current_leaf_scale <- sigma_leaf_init
+            }
         }
+        current_sigma2 <- sigma2_init
     }
-    current_sigma2 <- sigma2_init
-
+
     # Determine leaf model type
     if (!has_basis) leaf_model_mean_forest <- 0
     else if (ncol(leaf_basis_train) == 1) leaf_model_mean_forest <- 1
```
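An aside on the probit-branch calibration above (a standalone sketch, not part of the commit): with every root node initialized at zero, the model's implied success probability is `pnorm(y_bar_train)`, so choosing the offset `qnorm(mean(y_train))` reproduces the empirical base rate exactly.

```r
# Illustrative check of the probit offset logic, on toy data
set.seed(2024)
y_train <- rbinom(200, 1, 0.7)
y_bar_train <- qnorm(mean(y_train))            # probit-scale offset
all.equal(pnorm(y_bar_train), mean(y_train))   # TRUE: offset matches the base rate
```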
```diff
@@ -634,7 +705,6 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
     # Initialize the leaves of each tree in the variance forest
     if (include_variance_forest) {
         active_forest_variance$prepare_for_sampler(forest_dataset_train, outcome_train, forest_model_variance, leaf_model_variance_forest, variance_forest_init)
-
     }
 
     # Run GFR (warm start) if specified
```
```diff
@@ -652,6 +722,21 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
         }
 
         if (include_mean_forest) {
+            if (probit_outcome_model) {
+                # Sample latent probit variable, z | -
+                forest_pred <- active_forest_mean$predict(forest_dataset_train)
+                mu0 <- forest_pred[y_train == 0]
+                mu1 <- forest_pred[y_train == 1]
+                u0 <- runif(sum(y_train == 0), 0, pnorm(0 - mu0))
+                u1 <- runif(sum(y_train == 1), pnorm(0 - mu1), 1)
+                resid_train[y_train==0] <- mu0 + qnorm(u0)
+                resid_train[y_train==1] <- mu1 + qnorm(u1)
+
+                # Update outcome
+                outcome_train$update_data(resid_train - forest_pred)
+            }
+
+            # Sample mean forest
             forest_model_mean$sample_one_iteration(
                 forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_mean,
                 active_forest = active_forest_mean, rng = rng, forest_model_config = forest_model_config_mean,
```
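The `u0`/`u1` lines above are inverse-CDF draws from truncated normals, the data-augmentation step of an Albert and Chib style probit sampler: given `y = 0`, `z ~ N(mu, 1)` truncated to `(-Inf, 0)`; given `y = 1`, truncated to `(0, Inf)`. A standalone sketch (not from the commit) that checks the `y = 1` branch against brute-force rejection sampling:

```r
# Illustrative sketch: inverse-CDF draw from N(mu, 1) truncated to (0, Inf),
# mirroring the u1/qnorm construction in the sampler above
set.seed(42)
mu <- 0.3
n <- 10000
u1 <- runif(n, pnorm(0 - mu), 1)   # uniform on (Phi(-mu), 1)
z1 <- mu + qnorm(u1)               # maps back to the truncated normal
stopifnot(all(z1 > 0))             # every draw lands in the positive tail

# Compare against rejection sampling from the same truncated normal
z_raw <- rnorm(10 * n, mean = mu)
z_rej <- z_raw[z_raw > 0]
c(mean(z1), mean(z_rej))           # means should agree closely
```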
```diff
@@ -791,6 +876,20 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
         }
 
         if (include_mean_forest) {
+            if (probit_outcome_model) {
+                # Sample latent probit variable, z | -
+                forest_pred <- active_forest_mean$predict(forest_dataset_train)
+                mu0 <- forest_pred[y_train == 0]
+                mu1 <- forest_pred[y_train == 1]
+                u0 <- runif(sum(y_train == 0), 0, pnorm(0 - mu0))
+                u1 <- runif(sum(y_train == 1), pnorm(0 - mu1), 1)
+                resid_train[y_train==0] <- mu0 + qnorm(u0)
+                resid_train[y_train==1] <- mu1 + qnorm(u1)
+
+                # Update outcome
+                outcome_train$update_data(resid_train - forest_pred)
+            }
+
             forest_model_mean$sample_one_iteration(
                 forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_mean,
                 active_forest = active_forest_mean, rng = rng, forest_model_config = forest_model_config_mean,
```
```diff
@@ -915,7 +1014,8 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
         "sample_sigma_global" = sample_sigma_global,
         "sample_sigma_leaf" = sample_sigma_leaf,
         "include_mean_forest" = include_mean_forest,
-        "include_variance_forest" = include_variance_forest
+        "include_variance_forest" = include_variance_forest,
+        "probit_outcome_model" = probit_outcome_model
     )
     result <- list(
         "model_params" = model_params,
```
```diff
@@ -1257,6 +1357,7 @@ saveBARTModelToJson <- function(object){
     jsonobj$add_scalar("num_chains", object$model_params$num_chains)
     jsonobj$add_scalar("keep_every", object$model_params$keep_every)
     jsonobj$add_boolean("requires_basis", object$model_params$requires_basis)
+    jsonobj$add_boolean("probit_outcome_model", object$model_params$probit_outcome_model)
     if (object$model_params$sample_sigma_global) {
         jsonobj$add_vector("sigma2_global_samples", object$sigma2_global_samples, "parameters")
     }
```
```diff
@@ -1448,6 +1549,8 @@ createBARTModelFromJson <- function(json_object){
     model_params[["num_chains"]] <- json_object$get_scalar("num_chains")
     model_params[["keep_every"]] <- json_object$get_scalar("keep_every")
     model_params[["requires_basis"]] <- json_object$get_boolean("requires_basis")
+    model_params[["probit_outcome_model"]] <- json_object$get_boolean("probit_outcome_model")
+
     output[["model_params"]] <- model_params
 
     # Unpack sampled parameters
```
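A hypothetical round-trip sketch of the serialization changes (assuming `bart_fit` is a `bartmodel` returned by an earlier `bart()` call, and that return types are as suggested by the function names in the diff): the flag written by `saveBARTModelToJson` above is restored here, so a reloaded model remembers that it was fit with a probit link.

```r
# Illustrative only: serialize, reload, and inspect the new flag
bart_json <- saveBARTModelToJson(bart_fit)
bart_reloaded <- createBARTModelFromJson(bart_json)
bart_reloaded$model_params$probit_outcome_model  # TRUE for a probit fit
```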
```diff
@@ -1650,6 +1753,7 @@ createBARTModelFromCombinedJson <- function(json_object_list){
     model_params[["num_covariates"]] <- json_object_default$get_scalar("num_covariates")
     model_params[["num_basis"]] <- json_object_default$get_scalar("num_basis")
     model_params[["requires_basis"]] <- json_object_default$get_boolean("requires_basis")
+    model_params[["probit_outcome_model"]] <- json_object_default$get_boolean("probit_outcome_model")
     model_params[["num_chains"]] <- json_object_default$get_scalar("num_chains")
     model_params[["keep_every"]] <- json_object_default$get_scalar("keep_every")
 
```
```diff
@@ -1805,6 +1909,7 @@ createBARTModelFromCombinedJsonString <- function(json_string_list){
     model_params[["num_chains"]] <- json_object_default$get_scalar("num_chains")
     model_params[["keep_every"]] <- json_object_default$get_scalar("keep_every")
     model_params[["requires_basis"]] <- json_object_default$get_boolean("requires_basis")
+    model_params[["probit_outcome_model"]] <- json_object_default$get_boolean("probit_outcome_model")
 
     # Combine values that are sample-specific
     for (i in 1:length(json_object_list)) {
```
