You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: R/bart.R
+23-8Lines changed: 23 additions & 8 deletions
Original file line number
Diff line number
Diff line change
@@ -33,9 +33,10 @@
33
33
#' @param nu Shape parameter in the `IG(nu, nu*lambda)` global error variance model. Default: 3.
34
34
#' @param lambda Component of the scale parameter in the `IG(nu, nu*lambda)` global error variance prior. If not specified, this is calibrated as in Sparapani et al (2021).
35
35
#' @param a_leaf Shape parameter in the `IG(a_leaf, b_leaf)` leaf node parameter variance model. Default: 3.
36
-
#' @param b_leaf Scale parameter in the `IG(a_leaf, b_leaf)` leaf node parameter variance model. Calibrated internally as 0.5/num_trees if not set here.
36
+
#' @param b_leaf Scale parameter in the `IG(a_leaf, b_leaf)` leaf node parameter variance model. Calibrated internally as `0.5/num_trees` if not set here.
37
37
#' @param q Quantile used to calibrated `lambda` as in Sparapani et al (2021). Default: 0.9.
38
38
#' @param sigma2_init Starting value of global variance parameter. Calibrated internally as in Sparapani et al (2021) if not set here.
39
+
#' @param variable_weights Numeric weights reflecting the relative probability of splitting on each variable. Does not need to sum to 1 but cannot be negative. Defaults to `rep(1/ncol(X_train), ncol(X_train))` if not set here.
39
40
#' @param num_trees Number of trees in the ensemble. Default: 200.
40
41
#' @param num_gfr Number of "warm-start" iterations run using the grow-from-root algorithm (He and Hahn, 2021). Default: 5.
41
42
#' @param num_burnin Number of "burn-in" iterations of the MCMC sampler. Default: 0.
Copy file name to clipboardExpand all lines: include/stochtree/leaf_model.h
+6-3Lines changed: 6 additions & 3 deletions
Original file line number
Diff line number
Diff line change
@@ -71,7 +71,8 @@ class GaussianConstantLeafModel {
71
71
std::tuple<double, double, data_size_t, data_size_t> EvaluateExistingSplit(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, double global_variance, int tree_num, int split_node_id, int left_node_id, int right_node_id);
72
72
voidEvaluateAllPossibleSplits(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, TreePrior& tree_prior, double global_variance, int tree_num, int split_node_id,
@@ -136,7 +137,8 @@ class GaussianUnivariateRegressionLeafModel {
136
137
std::tuple<double, double, data_size_t, data_size_t> EvaluateExistingSplit(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, double global_variance, int tree_num, int split_node_id, int left_node_id, int right_node_id);
137
138
voidEvaluateAllPossibleSplits(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, TreePrior& tree_prior, double global_variance, int tree_num, int split_node_id,
@@ -203,7 +205,8 @@ class GaussianMultivariateRegressionLeafModel {
203
205
std::tuple<double, double, data_size_t, data_size_t> EvaluateExistingSplit(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, double global_variance, int tree_num, int split_node_id, int left_node_id, int right_node_id);
204
206
voidEvaluateAllPossibleSplits(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, TreePrior& tree_prior, double global_variance, int tree_num, int split_node_id,
0 commit comments