Skip to content

Commit 2b52159

Browse files
authored
Merge pull request #52 from StochasticTree/bcf-feature-subsets-mu-tau
Using subsets of features in the prognostic and treatment forests in BCF
2 parents 9aac95c + 4b827b4 commit 2b52159

File tree

11 files changed

+629
-227
lines changed

11 files changed

+629
-227
lines changed

R/bart.R

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,10 @@
3333
#' @param nu Shape parameter in the `IG(nu, nu*lambda)` global error variance model. Default: 3.
3434
#' @param lambda Component of the scale parameter in the `IG(nu, nu*lambda)` global error variance prior. If not specified, this is calibrated as in Sparapani et al (2021).
3535
#' @param a_leaf Shape parameter in the `IG(a_leaf, b_leaf)` leaf node parameter variance model. Default: 3.
36-
#' @param b_leaf Scale parameter in the `IG(a_leaf, b_leaf)` leaf node parameter variance model. Calibrated internally as 0.5/num_trees if not set here.
36+
#' @param b_leaf Scale parameter in the `IG(a_leaf, b_leaf)` leaf node parameter variance model. Calibrated internally as `0.5/num_trees` if not set here.
3737
#' @param q Quantile used to calibrate `lambda` as in Sparapani et al (2021). Default: 0.9.
3838
#' @param sigma2_init Starting value of global variance parameter. Calibrated internally as in Sparapani et al (2021) if not set here.
39+
#' @param variable_weights Numeric weights reflecting the relative probability of splitting on each variable. Does not need to sum to 1 but cannot be negative. Defaults to `rep(1/ncol(X_train), ncol(X_train))` if not set here.
3940
#' @param num_trees Number of trees in the ensemble. Default: 200.
4041
#' @param num_gfr Number of "warm-start" iterations run using the grow-from-root algorithm (He and Hahn, 2021). Default: 5.
4142
#' @param num_burnin Number of "burn-in" iterations of the MCMC sampler. Default: 0.
@@ -80,10 +81,19 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL,
8081
cutpoint_grid_size = 100, tau_init = NULL, alpha = 0.95,
8182
beta = 2.0, min_samples_leaf = 5, leaf_model = 0,
8283
nu = 3, lambda = NULL, a_leaf = 3, b_leaf = NULL,
83-
q = 0.9, sigma2_init = NULL, num_trees = 200, num_gfr = 5,
84-
num_burnin = 0, num_mcmc = 100, sample_sigma = T,
85-
sample_tau = T, random_seed = -1, keep_burnin = F,
86-
keep_gfr = F, verbose = F){
84+
q = 0.9, sigma2_init = NULL, variable_weights = NULL,
85+
num_trees = 200, num_gfr = 5, num_burnin = 0,
86+
num_mcmc = 100, sample_sigma = T, sample_tau = T,
87+
random_seed = -1, keep_burnin = F, keep_gfr = F,
88+
verbose = F){
89+
# Variable weight preprocessing (and initialization if necessary)
90+
if (is.null(variable_weights)) {
91+
variable_weights = rep(1/ncol(X_train), ncol(X_train))
92+
}
93+
if (any(variable_weights < 0)) {
94+
stop("variable_weights cannot have any negative weights")
95+
}
96+
8797
# Preprocess covariates
8898
if ((!is.data.frame(X_train)) && (!is.matrix(X_train))) {
8999
stop("X_train must be a matrix or dataframe")
@@ -93,12 +103,20 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL,
93103
stop("X_test must be a matrix or dataframe")
94104
}
95105
}
106+
if (ncol(X_train) != length(variable_weights)) {
107+
stop("length(variable_weights) must equal ncol(X_train)")
108+
}
96109
train_cov_preprocess_list <- preprocessTrainData(X_train)
97110
X_train_metadata <- train_cov_preprocess_list$metadata
98111
X_train <- train_cov_preprocess_list$data
112+
original_var_indices <- X_train_metadata$original_var_indices
99113
feature_types <- X_train_metadata$feature_types
100114
if (!is.null(X_test)) X_test <- preprocessPredictionData(X_test, X_train_metadata)
101115

116+
# Update variable weights
117+
variable_weights_adj <- 1/sapply(original_var_indices, function(x) sum(original_var_indices == x))
118+
variable_weights <- variable_weights[original_var_indices]*variable_weights_adj
119+
102120
# Convert all input data to matrices if not already converted
103121
if ((is.null(dim(W_train))) && (!is.null(W_train))) {
104122
W_train <- as.matrix(W_train)
@@ -295,9 +313,6 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL,
295313
if (sample_sigma) global_var_samples <- rep(0, num_samples)
296314
if (sample_tau) leaf_scale_samples <- rep(0, num_samples)
297315

298-
# Variable selection weights
299-
variable_weights <- rep(1/ncol(X_train), ncol(X_train))
300-
301316
# Run GFR (warm start) if specified
302317
if (num_gfr > 0){
303318
gfr_indices = 1:num_gfr

R/bcf.R

Lines changed: 147 additions & 72 deletions
Large diffs are not rendered by default.

R/utils.R

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#' types. Matrices will be passed through assuming all columns are numeric.
33
#'
44
#' @param input_data Covariates, provided as either a dataframe or a matrix
5+
#' @param variable_weights Numeric weights reflecting the relative probability of splitting on each variable
56
#'
67
#' @return List with preprocessed (unmodified) data and details on the number of each type
78
#' of variable, unique categories associated with categorical variables, and the
@@ -63,6 +64,7 @@ preprocessPredictionData <- function(input_data, metadata) {
6364
#' Returns a list including a matrix of preprocessed covariate values and associated tracking.
6465
#'
6566
#' @param input_matrix Covariate matrix.
67+
#' @param variable_weights Numeric weights reflecting the relative probability of splitting on each variable
6668
#'
6769
#' @return List with preprocessed (unmodified) data and details on the number of each type
6870
#' of variable, unique categories associated with categorical variables, and the
@@ -97,7 +99,8 @@ preprocessTrainMatrix <- function(input_matrix) {
9799
num_ordered_cat_vars = num_ordered_cat_vars,
98100
num_unordered_cat_vars = num_unordered_cat_vars,
99101
num_numeric_vars = num_numeric_vars,
100-
numeric_vars = numeric_vars
102+
numeric_vars = numeric_vars,
103+
original_var_indices = 1:num_numeric_vars
101104
)
102105
output <- list(
103106
data = X,
@@ -139,6 +142,7 @@ preprocessPredictionMatrix <- function(input_matrix, metadata) {
139142
#'
140143
#' @param input_df Dataframe of covariates. Users must pre-process any
141144
#' categorical variables as factors (ordered for ordered categorical).
145+
#' @param variable_weights Numeric weights reflecting the relative probability of splitting on each variable
142146
#'
143147
#' @return List with preprocessed data and details on the number of each type
144148
#' of variable, unique categories associated with categorical variables, and the
@@ -164,6 +168,7 @@ preprocessTrainDataFrame <- function(input_df) {
164168
ordered_mask <- sapply(input_df, is.ordered)
165169
ordered_cat_matches <- factor_mask & ordered_mask
166170
ordered_cat_vars <- df_vars[ordered_cat_matches]
171+
ordered_cat_var_inds <- unname(which(ordered_cat_matches))
167172
num_ordered_cat_vars <- length(ordered_cat_vars)
168173
if (num_ordered_cat_vars > 0) ordered_cat_df <- input_df[,ordered_cat_vars,drop=F]
169174

@@ -173,12 +178,14 @@ preprocessTrainDataFrame <- function(input_df) {
173178
character_mask <- sapply(input_df, is.character)
174179
unordered_cat_matches <- (factor_mask & (!ordered_mask)) | character_mask
175180
unordered_cat_vars <- df_vars[unordered_cat_matches]
181+
unordered_cat_var_inds <- unname(which(unordered_cat_matches))
176182
num_unordered_cat_vars <- length(unordered_cat_vars)
177183
if (num_unordered_cat_vars > 0) unordered_cat_df <- input_df[,unordered_cat_vars,drop=F]
178184

179185
# Numeric variables
180186
numeric_matches <- (!ordered_cat_matches) & (!unordered_cat_matches)
181187
numeric_vars <- df_vars[numeric_matches]
188+
numeric_var_inds <- unname(which(numeric_matches))
182189
num_numeric_vars <- length(numeric_vars)
183190
if (num_numeric_vars > 0) numeric_df <- input_df[,numeric_vars,drop=F]
184191

@@ -187,6 +194,7 @@ preprocessTrainDataFrame <- function(input_df) {
187194
unordered_unique_levels <- list()
188195
ordered_unique_levels <- list()
189196
feature_types <- integer(0)
197+
original_var_indices <- integer(0)
190198

191199
# First, extract the numeric covariates
192200
if (num_numeric_vars > 0) {
@@ -197,6 +205,7 @@ preprocessTrainDataFrame <- function(input_df) {
197205
}
198206
X <- cbind(X, unname(Xnum))
199207
feature_types <- c(feature_types, rep(0, ncol(Xnum)))
208+
original_var_indices <- c(original_var_indices, numeric_var_inds)
200209
}
201210

202211
# Next, run some simple preprocessing on the ordered categorical covariates
@@ -210,6 +219,7 @@ preprocessTrainDataFrame <- function(input_df) {
210219
}
211220
X <- cbind(X, unname(Xordcat))
212221
feature_types <- c(feature_types, rep(1, ncol(Xordcat)))
222+
original_var_indices <- c(original_var_indices, ordered_cat_var_inds)
213223
}
214224

215225
# Finally, one-hot encode the unordered categorical covariates
@@ -220,6 +230,8 @@ preprocessTrainDataFrame <- function(input_df) {
220230
encode_list <- oneHotInitializeAndEncode(unordered_cat_df[,i])
221231
unordered_unique_levels[[var_name]] <- encode_list$unique_levels
222232
one_hot_mats[[var_name]] <- encode_list$Xtilde
233+
one_hot_var <- rep(unordered_cat_var_inds[i], ncol(encode_list$Xtilde))
234+
original_var_indices <- c(original_var_indices, one_hot_var)
223235
}
224236
Xcat <- do.call(cbind, one_hot_mats)
225237
X <- cbind(X, unname(Xcat))
@@ -231,7 +243,8 @@ preprocessTrainDataFrame <- function(input_df) {
231243
feature_types = feature_types,
232244
num_ordered_cat_vars = num_ordered_cat_vars,
233245
num_unordered_cat_vars = num_unordered_cat_vars,
234-
num_numeric_vars = num_numeric_vars
246+
num_numeric_vars = num_numeric_vars,
247+
original_var_indices = original_var_indices
235248
)
236249
if (num_ordered_cat_vars > 0) {
237250
metadata[["ordered_cat_vars"]] = ordered_cat_vars

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ pip install matplotlib seaborn jupyterlab
9393
The package can be installed in R via
9494

9595
```
96-
remotes::install_github("StochasticTree/stochtree-cpp")
96+
remotes::install_github("StochasticTree/stochtree-cpp", ref="r-dev")
9797
```
9898

9999
# C++ Core

include/stochtree/ensemble.h

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -81,12 +81,11 @@ class TreeEnsemble {
8181

8282
inline void PredictInplace(ForestDataset& dataset, std::vector<double> &output,
8383
int tree_begin, int tree_end, data_size_t offset = 0) {
84-
if (dataset.HasBasis()) {
85-
CHECK(!is_leaf_constant_);
86-
PredictInplace(dataset.GetCovariates(), dataset.GetBasis(), output, tree_begin, tree_end, offset);
87-
} else {
88-
CHECK(is_leaf_constant_);
84+
if (is_leaf_constant_) {
8985
PredictInplace(dataset.GetCovariates(), output, tree_begin, tree_end, offset);
86+
} else {
87+
CHECK(dataset.HasBasis());
88+
PredictInplace(dataset.GetCovariates(), dataset.GetBasis(), output, tree_begin, tree_end, offset);
9089
}
9190
}
9291

include/stochtree/leaf_model.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,8 @@ class GaussianConstantLeafModel {
7171
std::tuple<double, double, data_size_t, data_size_t> EvaluateExistingSplit(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, double global_variance, int tree_num, int split_node_id, int left_node_id, int right_node_id);
7272
void EvaluateAllPossibleSplits(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, TreePrior& tree_prior, double global_variance, int tree_num, int split_node_id,
7373
std::vector<double>& log_cutpoint_evaluations, std::vector<int>& cutpoint_features, std::vector<double>& cutpoint_values, std::vector<FeatureType>& cutpoint_feature_types,
74-
data_size_t& valid_cutpoint_count, CutpointGridContainer& cutpoint_grid_container, data_size_t node_begin, data_size_t node_end, std::vector<FeatureType>& feature_types);
74+
data_size_t& valid_cutpoint_count, CutpointGridContainer& cutpoint_grid_container, data_size_t node_begin, data_size_t node_end, std::vector<double>& variable_weights,
75+
std::vector<FeatureType>& feature_types);
7576
double SplitLogMarginalLikelihood(GaussianConstantSuffStat& left_stat, GaussianConstantSuffStat& right_stat, double global_variance);
7677
double NoSplitLogMarginalLikelihood(GaussianConstantSuffStat& suff_stat, double global_variance);
7778
double PosteriorParameterMean(GaussianConstantSuffStat& suff_stat, double global_variance);
@@ -136,7 +137,8 @@ class GaussianUnivariateRegressionLeafModel {
136137
std::tuple<double, double, data_size_t, data_size_t> EvaluateExistingSplit(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, double global_variance, int tree_num, int split_node_id, int left_node_id, int right_node_id);
137138
void EvaluateAllPossibleSplits(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, TreePrior& tree_prior, double global_variance, int tree_num, int split_node_id,
138139
std::vector<double>& log_cutpoint_evaluations, std::vector<int>& cutpoint_features, std::vector<double>& cutpoint_values, std::vector<FeatureType>& cutpoint_feature_types,
139-
data_size_t& valid_cutpoint_count, CutpointGridContainer& cutpoint_grid_container, data_size_t node_begin, data_size_t node_end, std::vector<FeatureType>& feature_types);
140+
data_size_t& valid_cutpoint_count, CutpointGridContainer& cutpoint_grid_container, data_size_t node_begin, data_size_t node_end, std::vector<double>& variable_weights,
141+
std::vector<FeatureType>& feature_types);
140142
double SplitLogMarginalLikelihood(GaussianUnivariateRegressionSuffStat& left_stat, GaussianUnivariateRegressionSuffStat& right_stat, double global_variance);
141143
double NoSplitLogMarginalLikelihood(GaussianUnivariateRegressionSuffStat& suff_stat, double global_variance);
142144
double PosteriorParameterMean(GaussianUnivariateRegressionSuffStat& suff_stat, double global_variance);
@@ -203,7 +205,8 @@ class GaussianMultivariateRegressionLeafModel {
203205
std::tuple<double, double, data_size_t, data_size_t> EvaluateExistingSplit(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, double global_variance, int tree_num, int split_node_id, int left_node_id, int right_node_id);
204206
void EvaluateAllPossibleSplits(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, TreePrior& tree_prior, double global_variance, int tree_num, int split_node_id,
205207
std::vector<double>& log_cutpoint_evaluations, std::vector<int>& cutpoint_features, std::vector<double>& cutpoint_values, std::vector<FeatureType>& cutpoint_feature_types,
206-
data_size_t& valid_cutpoint_count, CutpointGridContainer& cutpoint_grid_container, data_size_t node_begin, data_size_t node_end, std::vector<FeatureType>& feature_types);
208+
data_size_t& valid_cutpoint_count, CutpointGridContainer& cutpoint_grid_container, data_size_t node_begin, data_size_t node_end, std::vector<double>& variable_weights,
209+
std::vector<FeatureType>& feature_types);
207210
double SplitLogMarginalLikelihood(GaussianMultivariateRegressionSuffStat& left_stat, GaussianMultivariateRegressionSuffStat& right_stat, double global_variance);
208211
double NoSplitLogMarginalLikelihood(GaussianMultivariateRegressionSuffStat& suff_stat, double global_variance);
209212
Eigen::VectorXd PosteriorParameterMean(GaussianMultivariateRegressionSuffStat& suff_stat, double global_variance);

0 commit comments

Comments (0)