Using subsets of features in the prognostic and treatment forests in BCF #52

Merged: 3 commits, Jun 20, 2024
31 changes: 23 additions & 8 deletions R/bart.R
@@ -33,9 +33,10 @@
#' @param nu Shape parameter in the `IG(nu, nu*lambda)` global error variance model. Default: 3.
#' @param lambda Component of the scale parameter in the `IG(nu, nu*lambda)` global error variance prior. If not specified, this is calibrated as in Sparapani et al (2021).
#' @param a_leaf Shape parameter in the `IG(a_leaf, b_leaf)` leaf node parameter variance model. Default: 3.
- #' @param b_leaf Scale parameter in the `IG(a_leaf, b_leaf)` leaf node parameter variance model. Calibrated internally as 0.5/num_trees if not set here.
+ #' @param b_leaf Scale parameter in the `IG(a_leaf, b_leaf)` leaf node parameter variance model. Calibrated internally as `0.5/num_trees` if not set here.
#' @param q Quantile used to calibrate `lambda` as in Sparapani et al (2021). Default: 0.9.
#' @param sigma2_init Starting value of global variance parameter. Calibrated internally as in Sparapani et al (2021) if not set here.
+ #' @param variable_weights Numeric weights reflecting the relative probability of splitting on each variable. Does not need to sum to 1 but cannot be negative. Defaults to `rep(1/ncol(X_train), ncol(X_train))` if not set here.
#' @param num_trees Number of trees in the ensemble. Default: 200.
#' @param num_gfr Number of "warm-start" iterations run using the grow-from-root algorithm (He and Hahn, 2021). Default: 5.
#' @param num_burnin Number of "burn-in" iterations of the MCMC sampler. Default: 0.
@@ -80,10 +80,19 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL,
cutpoint_grid_size = 100, tau_init = NULL, alpha = 0.95,
beta = 2.0, min_samples_leaf = 5, leaf_model = 0,
nu = 3, lambda = NULL, a_leaf = 3, b_leaf = NULL,
- q = 0.9, sigma2_init = NULL, num_trees = 200, num_gfr = 5,
- num_burnin = 0, num_mcmc = 100, sample_sigma = T,
- sample_tau = T, random_seed = -1, keep_burnin = F,
- keep_gfr = F, verbose = F){
+ q = 0.9, sigma2_init = NULL, variable_weights = NULL,
+ num_trees = 200, num_gfr = 5, num_burnin = 0,
+ num_mcmc = 100, sample_sigma = T, sample_tau = T,
+ random_seed = -1, keep_burnin = F, keep_gfr = F,
+ verbose = F){
+ # Variable weight preprocessing (and initialization if necessary)
+ if (is.null(variable_weights)) {
+ variable_weights = rep(1/ncol(X_train), ncol(X_train))
+ }
+ if (any(variable_weights < 0)) {
+ stop("variable_weights cannot have any negative weights")
+ }

# Preprocess covariates
if ((!is.data.frame(X_train)) && (!is.matrix(X_train))) {
stop("X_train must be a matrix or dataframe")
@@ -93,12 +93,20 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL,
stop("X_test must be a matrix or dataframe")
}
}
+ if (ncol(X_train) != length(variable_weights)) {
+ stop("length(variable_weights) must equal ncol(X_train)")
+ }
train_cov_preprocess_list <- preprocessTrainData(X_train)
X_train_metadata <- train_cov_preprocess_list$metadata
X_train <- train_cov_preprocess_list$data
+ original_var_indices <- X_train_metadata$original_var_indices
feature_types <- X_train_metadata$feature_types
if (!is.null(X_test)) X_test <- preprocessPredictionData(X_test, X_train_metadata)

+ # Update variable weights
+ variable_weights_adj <- 1/sapply(original_var_indices, function(x) sum(original_var_indices == x))
+ variable_weights <- variable_weights[original_var_indices]*variable_weights_adj

# Convert all input data to matrices if not already converted
if ((is.null(dim(W_train))) && (!is.null(W_train))) {
W_train <- as.matrix(W_train)
@@ -295,9 +313,6 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL,
if (sample_sigma) global_var_samples <- rep(0, num_samples)
if (sample_tau) leaf_scale_samples <- rep(0, num_samples)

- # Variable selection weights
- variable_weights <- rep(1/ncol(X_train), ncol(X_train))

# Run GFR (warm start) if specified
if (num_gfr > 0){
gfr_indices = 1:num_gfr
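Taken together, the bart.R changes let users steer split-variable selection directly. Below is a minimal usage sketch of the new argument, assuming the R package loads as `stochtree` and using simulated data purely for illustration:

```r
library(stochtree)

# Simulated data: 10 covariates, only the first 3 drive the outcome
n <- 500
p <- 10
X <- matrix(runif(n * p), ncol = p)
y <- 2 * X[, 1] + X[, 2]^2 - X[, 3] + rnorm(n, sd = 0.1)

# Upweight the first three covariates; weights need not sum to 1,
# but bart() stops if any weight is negative
w <- c(rep(5, 3), rep(1, p - 3))

fit <- bart(X_train = X, y_train = y, variable_weights = w,
            num_gfr = 5, num_mcmc = 100)
```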
219 changes: 147 additions & 72 deletions R/bcf.R

Large diffs are not rendered by default.
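The R/bcf.R diff carries the bulk of the PR (147 additions, 72 deletions) but is not rendered here. Per the PR title, it threads the same weighting machinery through both BCF forests, so the prognostic and treatment ensembles can split on different feature subsets. The sketch below, continuing the simulated data from the `bart()` example above, is hypothetical: the argument names `variable_weights_mu` and `variable_weights_tau` are assumed for illustration, not confirmed by the unrendered diff.

```r
# Hypothetical interface: the argument names below are assumed, not taken
# from the unrendered R/bcf.R diff. A zero weight excludes a feature from
# that forest entirely.
Z <- rbinom(n, 1, 0.5)             # binary treatment indicator
w_mu  <- rep(1, p)                 # prognostic forest: all features
w_tau <- c(1, 1, rep(0, p - 2))    # treatment forest: features 1-2 only

bcf_fit <- bcf(X_train = X, Z_train = Z, y_train = y,
               variable_weights_mu = w_mu,
               variable_weights_tau = w_tau)
```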

17 changes: 15 additions & 2 deletions R/utils.R
@@ -2,6 +2,7 @@
#' types. Matrices will be passed through assuming all columns are numeric.
#'
#' @param input_data Covariates, provided as either a dataframe or a matrix
+ #' @param variable_weights Numeric weights reflecting the relative probability of splitting on each variable
#'
#' @return List with preprocessed (unmodified) data and details on the number of each type
#' of variable, unique categories associated with categorical variables, and the
@@ -63,6 +64,7 @@ preprocessPredictionData <- function(input_data, metadata) {
#' Returns a list including a matrix of preprocessed covariate values and associated tracking.
#'
#' @param input_matrix Covariate matrix.
+ #' @param variable_weights Numeric weights reflecting the relative probability of splitting on each variable
#'
#' @return List with preprocessed (unmodified) data and details on the number of each type
#' of variable, unique categories associated with categorical variables, and the
@@ -97,7 +99,8 @@ preprocessTrainMatrix <- function(input_matrix) {
num_ordered_cat_vars = num_ordered_cat_vars,
num_unordered_cat_vars = num_unordered_cat_vars,
num_numeric_vars = num_numeric_vars,
- numeric_vars = numeric_vars
+ numeric_vars = numeric_vars,
+ original_var_indices = 1:num_numeric_vars
)
output <- list(
data = X,
@@ -139,6 +142,7 @@ preprocessPredictionMatrix <- function(input_matrix, metadata) {
#'
#' @param input_df Dataframe of covariates. Users must pre-process any
#' categorical variables as factors (ordered for ordered categorical).
+ #' @param variable_weights Numeric weights reflecting the relative probability of splitting on each variable
#'
#' @return List with preprocessed data and details on the number of each type
#' of variable, unique categories associated with categorical variables, and the
@@ -164,6 +168,7 @@ preprocessTrainDataFrame <- function(input_df) {
ordered_mask <- sapply(input_df, is.ordered)
ordered_cat_matches <- factor_mask & ordered_mask
ordered_cat_vars <- df_vars[ordered_cat_matches]
+ ordered_cat_var_inds <- unname(which(ordered_cat_matches))
num_ordered_cat_vars <- length(ordered_cat_vars)
if (num_ordered_cat_vars > 0) ordered_cat_df <- input_df[,ordered_cat_vars,drop=F]

@@ -173,12 +178,14 @@
character_mask <- sapply(input_df, is.character)
unordered_cat_matches <- (factor_mask & (!ordered_mask)) | character_mask
unordered_cat_vars <- df_vars[unordered_cat_matches]
+ unordered_cat_var_inds <- unname(which(unordered_cat_matches))
num_unordered_cat_vars <- length(unordered_cat_vars)
if (num_unordered_cat_vars > 0) unordered_cat_df <- input_df[,unordered_cat_vars,drop=F]

# Numeric variables
numeric_matches <- (!ordered_cat_matches) & (!unordered_cat_matches)
numeric_vars <- df_vars[numeric_matches]
+ numeric_var_inds <- unname(which(numeric_matches))
num_numeric_vars <- length(numeric_vars)
if (num_numeric_vars > 0) numeric_df <- input_df[,numeric_vars,drop=F]

@@ -187,6 +194,7 @@
unordered_unique_levels <- list()
ordered_unique_levels <- list()
feature_types <- integer(0)
+ original_var_indices <- integer(0)

# First, extract the numeric covariates
if (num_numeric_vars > 0) {
@@ -197,6 +205,7 @@
}
X <- cbind(X, unname(Xnum))
feature_types <- c(feature_types, rep(0, ncol(Xnum)))
+ original_var_indices <- c(original_var_indices, numeric_var_inds)
}

# Next, run some simple preprocessing on the ordered categorical covariates
Expand All @@ -210,6 +219,7 @@ preprocessTrainDataFrame <- function(input_df) {
}
X <- cbind(X, unname(Xordcat))
feature_types <- c(feature_types, rep(1, ncol(Xordcat)))
+ original_var_indices <- c(original_var_indices, ordered_cat_var_inds)
}

# Finally, one-hot encode the unordered categorical covariates
Expand All @@ -220,6 +230,8 @@ preprocessTrainDataFrame <- function(input_df) {
encode_list <- oneHotInitializeAndEncode(unordered_cat_df[,i])
unordered_unique_levels[[var_name]] <- encode_list$unique_levels
one_hot_mats[[var_name]] <- encode_list$Xtilde
+ one_hot_var <- rep(unordered_cat_var_inds[i], ncol(encode_list$Xtilde))
+ original_var_indices <- c(original_var_indices, one_hot_var)
}
Xcat <- do.call(cbind, one_hot_mats)
X <- cbind(X, unname(Xcat))
@@ -231,7 +243,8 @@
feature_types = feature_types,
num_ordered_cat_vars = num_ordered_cat_vars,
num_unordered_cat_vars = num_unordered_cat_vars,
- num_numeric_vars = num_numeric_vars
+ num_numeric_vars = num_numeric_vars,
+ original_var_indices = original_var_indices
)
if (num_ordered_cat_vars > 0) {
metadata[["ordered_cat_vars"]] = ordered_cat_vars
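The `original_var_indices` bookkeeping added above is what lets `bart()` map user-supplied weights onto the expanded (one-hot-encoded) design matrix. A standalone worked example of the adjustment, mirroring the two lines added to bart.R:

```r
# Suppose the input data frame has 3 columns: one numeric, one ordered
# factor, and one 3-level unordered factor (one-hot encoded to 3 columns).
# Preprocessing then reports, for each column of the encoded matrix, the
# index of the original column it came from:
original_var_indices <- c(1, 2, 3, 3, 3)

# One user-supplied weight per original column
variable_weights <- c(0.5, 0.25, 0.25)

# Each one-hot column receives an equal share of its source variable's weight
variable_weights_adj <- 1 / sapply(original_var_indices,
                                   function(x) sum(original_var_indices == x))
variable_weights[original_var_indices] * variable_weights_adj
#> 0.5000 0.2500 0.0833 0.0833 0.0833   (one weight per encoded column)
```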
2 changes: 1 addition & 1 deletion README.md
@@ -93,7 +93,7 @@ pip install matplotlib seaborn jupyterlab
The package can be installed in R via

```
- remotes::install_github("StochasticTree/stochtree-cpp")
+ remotes::install_github("StochasticTree/stochtree-cpp", ref="r-dev")
```

# C++ Core
9 changes: 4 additions & 5 deletions include/stochtree/ensemble.h
@@ -81,12 +81,11 @@ class TreeEnsemble {

inline void PredictInplace(ForestDataset& dataset, std::vector<double> &output,
int tree_begin, int tree_end, data_size_t offset = 0) {
- if (dataset.HasBasis()) {
- CHECK(!is_leaf_constant_);
- PredictInplace(dataset.GetCovariates(), dataset.GetBasis(), output, tree_begin, tree_end, offset);
- } else {
- CHECK(is_leaf_constant_);
+ if (is_leaf_constant_) {
PredictInplace(dataset.GetCovariates(), output, tree_begin, tree_end, offset);
+ } else {
+ CHECK(dataset.HasBasis());
+ PredictInplace(dataset.GetCovariates(), dataset.GetBasis(), output, tree_begin, tree_end, offset);
}
}

9 changes: 6 additions & 3 deletions include/stochtree/leaf_model.h
@@ -71,7 +71,8 @@ class GaussianConstantLeafModel {
std::tuple<double, double, data_size_t, data_size_t> EvaluateExistingSplit(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, double global_variance, int tree_num, int split_node_id, int left_node_id, int right_node_id);
void EvaluateAllPossibleSplits(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, TreePrior& tree_prior, double global_variance, int tree_num, int split_node_id,
std::vector<double>& log_cutpoint_evaluations, std::vector<int>& cutpoint_features, std::vector<double>& cutpoint_values, std::vector<FeatureType>& cutpoint_feature_types,
- data_size_t& valid_cutpoint_count, CutpointGridContainer& cutpoint_grid_container, data_size_t node_begin, data_size_t node_end, std::vector<FeatureType>& feature_types);
+ data_size_t& valid_cutpoint_count, CutpointGridContainer& cutpoint_grid_container, data_size_t node_begin, data_size_t node_end, std::vector<double>& variable_weights,
+ std::vector<FeatureType>& feature_types);
double SplitLogMarginalLikelihood(GaussianConstantSuffStat& left_stat, GaussianConstantSuffStat& right_stat, double global_variance);
double NoSplitLogMarginalLikelihood(GaussianConstantSuffStat& suff_stat, double global_variance);
double PosteriorParameterMean(GaussianConstantSuffStat& suff_stat, double global_variance);
@@ -136,7 +137,8 @@ class GaussianUnivariateRegressionLeafModel {
std::tuple<double, double, data_size_t, data_size_t> EvaluateExistingSplit(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, double global_variance, int tree_num, int split_node_id, int left_node_id, int right_node_id);
void EvaluateAllPossibleSplits(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, TreePrior& tree_prior, double global_variance, int tree_num, int split_node_id,
std::vector<double>& log_cutpoint_evaluations, std::vector<int>& cutpoint_features, std::vector<double>& cutpoint_values, std::vector<FeatureType>& cutpoint_feature_types,
- data_size_t& valid_cutpoint_count, CutpointGridContainer& cutpoint_grid_container, data_size_t node_begin, data_size_t node_end, std::vector<FeatureType>& feature_types);
+ data_size_t& valid_cutpoint_count, CutpointGridContainer& cutpoint_grid_container, data_size_t node_begin, data_size_t node_end, std::vector<double>& variable_weights,
+ std::vector<FeatureType>& feature_types);
double SplitLogMarginalLikelihood(GaussianUnivariateRegressionSuffStat& left_stat, GaussianUnivariateRegressionSuffStat& right_stat, double global_variance);
double NoSplitLogMarginalLikelihood(GaussianUnivariateRegressionSuffStat& suff_stat, double global_variance);
double PosteriorParameterMean(GaussianUnivariateRegressionSuffStat& suff_stat, double global_variance);
@@ -203,7 +205,8 @@ class GaussianMultivariateRegressionLeafModel {
std::tuple<double, double, data_size_t, data_size_t> EvaluateExistingSplit(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, double global_variance, int tree_num, int split_node_id, int left_node_id, int right_node_id);
void EvaluateAllPossibleSplits(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, TreePrior& tree_prior, double global_variance, int tree_num, int split_node_id,
std::vector<double>& log_cutpoint_evaluations, std::vector<int>& cutpoint_features, std::vector<double>& cutpoint_values, std::vector<FeatureType>& cutpoint_feature_types,
- data_size_t& valid_cutpoint_count, CutpointGridContainer& cutpoint_grid_container, data_size_t node_begin, data_size_t node_end, std::vector<FeatureType>& feature_types);
+ data_size_t& valid_cutpoint_count, CutpointGridContainer& cutpoint_grid_container, data_size_t node_begin, data_size_t node_end, std::vector<double>& variable_weights,
+ std::vector<FeatureType>& feature_types);
double SplitLogMarginalLikelihood(GaussianMultivariateRegressionSuffStat& left_stat, GaussianMultivariateRegressionSuffStat& right_stat, double global_variance);
double NoSplitLogMarginalLikelihood(GaussianMultivariateRegressionSuffStat& suff_stat, double global_variance);
Eigen::VectorXd PosteriorParameterMean(GaussianMultivariateRegressionSuffStat& suff_stat, double global_variance);
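On the C++ side, all three leaf models now receive `variable_weights` when enumerating candidate splits. How the weights enter the cutpoint evaluation is not shown in this diff, but the underlying idea is that features are proposed for splitting in proportion to their weight. A conceptual R sketch of that idea (an illustration of the principle, not the library's internal code):

```r
# Conceptual only: draw a split feature with probability proportional to
# its weight; zero-weight features can never be chosen.
sample_split_feature <- function(variable_weights) {
  sample(seq_along(variable_weights), size = 1, prob = variable_weights)
}

set.seed(42)
w <- c(5, 1, 0, 1)
table(replicate(10000, sample_split_feature(w)))
#> feature 1 dominates; feature 3 (weight 0) never appears
```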