Skip to content

Commit d746d92

Browse files
committed
Updated ensemble predict method to skip over a dataset basis for prediction if not used in a model
1 parent 008a5dd commit d746d92

File tree

2 files changed

+6
-7
lines changed

2 files changed

+6
-7
lines changed

R/bcf.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -770,8 +770,8 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU
770770
"outcome_mean" = y_bar_train,
771771
"outcome_scale" = y_std_train,
772772
"num_covariates" = ncol(X_train),
773-
"num_prognostic_covariates" = ncol(X_train_mu),
774-
"num_treatment_covariates" = ncol(X_train_tau),
773+
"num_prognostic_covariates" = sum(variable_weights_mu > 0),
774+
"num_treatment_covariates" = sum(variable_weights_tau > 0),
775775
"treatment_dim" = ncol(Z_train),
776776
"propensity_covariate" = propensity_covariate,
777777
"binary_treatment" = binary_treatment,

include/stochtree/ensemble.h

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -81,12 +81,11 @@ class TreeEnsemble {
8181

8282
inline void PredictInplace(ForestDataset& dataset, std::vector<double> &output,
8383
int tree_begin, int tree_end, data_size_t offset = 0) {
84-
if (dataset.HasBasis()) {
85-
CHECK(!is_leaf_constant_);
86-
PredictInplace(dataset.GetCovariates(), dataset.GetBasis(), output, tree_begin, tree_end, offset);
87-
} else {
88-
CHECK(is_leaf_constant_);
84+
if (is_leaf_constant_) {
8985
PredictInplace(dataset.GetCovariates(), output, tree_begin, tree_end, offset);
86+
} else {
87+
CHECK(dataset.HasBasis());
88+
PredictInplace(dataset.GetCovariates(), dataset.GetBasis(), output, tree_begin, tree_end, offset);
9089
}
9190
}
9291

0 commit comments

Comments
 (0)