Skip to content

Commit 744ca40

Browse files
authored
Merge pull request #135 from StochasticTree/model-config-refactor
Refactor model parameters into "config" objects to future-proof low-level interface
2 parents 0a9352f + 1e73f92 commit 744ca40

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

44 files changed

+2121
-352
lines changed

.github/workflows/r-test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ jobs:
3939

4040
- name: Create a CRAN-ready version of the R package
4141
run: |
42-
Rscript cran-bootstrap.R 0 0
42+
Rscript cran-bootstrap.R 0 0 1
4343
4444
- uses: r-lib/actions/check-r-package@v2
4545
with:

DESCRIPTION

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
Package: stochtree
2-
Title: Stochastic tree ensembles (XBART and BART) for supervised learning and causal inference
2+
Title: Stochastic tree Ensembles (XBART and BART) for Supervised Learning and Causal Inference
33
Version: 0.1.0
44
Authors@R:
55
c(
@@ -10,7 +10,11 @@ Authors@R:
1010
person("Jingyu", "He", role = "aut"),
1111
person("stochtree contributors", role = c("cph"))
1212
)
13-
Description: Stochastic tree ensembles (XBART and BART) for supervised learning and causal inference.
13+
Description: Flexible stochastic tree ensemble software. Robust implementations of
14+
Bayesian Additive Regression Trees (Chipman, George, McCulloch (2010) <doi:10.1214/09-AOAS285>)
15+
for supervised learning and (Bayesian Causal Forests (BCF) Hahn, Murray, Carvalho (2020) <doi:10.1214/19-BA1195>)
16+
for causal inference. Enables model serialization and parallel sampling
17+
and provides a low-level interface for custom stochastic forest samplers.
1418
License: MIT + file LICENSE
1519
Encoding: UTF-8
1620
Roxygen: list(markdown = TRUE)

LICENSE

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
YEAR: 2024
2-
COPYRIGHT HOLDER: stochtree authors
1+
YEAR: 2025
2+
COPYRIGHT HOLDER: stochtree contributors

LICENSE.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# MIT License
22

3-
Copyright (c) 2024 stochtree authors
3+
Copyright (c) 2023-2025 stochtree authors
44

55
Permission is hereby granted, free of charge, to any person obtaining a copy
66
of this software and associated documentation files (the "Software"), to deal

NAMESPACE

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,9 @@ export(createCppRNG)
2828
export(createForest)
2929
export(createForestDataset)
3030
export(createForestModel)
31+
export(createForestModelConfig)
3132
export(createForestSamples)
33+
export(createGlobalModelConfig)
3234
export(createOutcome)
3335
export(createPreprocessorFromJson)
3436
export(createPreprocessorFromJsonString)

NEWS.md

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,16 @@
11
# stochtree 0.1.0
22

3-
* Initial CRAN submission.
3+
* Initial release on CRAN.
4+
* Support for sampling stochastic tree ensembles using two algorithms: MCMC and Grow-From-Root (GFR)
5+
* High-level model types supported:
6+
* Supervised learning with constant leaves or user-specified leaf regression models
7+
* Causal effect estimation with binary, continuous, or multivariate treatments
8+
* Additional high-level modeling features:
9+
* Forest-based variance function estimation (heteroskedasticity)
10+
* Additive (univariate or multivariate) group random effects
11+
* Multi-chain sampling and support for parallelism
12+
* "Warm-start" initialization of MCMC forest samplers via the Grow-From-Root (GFR) algorithm
13+
* Automated preprocessing / handling of categorical variables
14+
* Low-level interface:
15+
* Ability to combine a forest sampler with other (additive) model terms, without using C++
16+
* Combine and sample an arbitrary number of forests or random effects terms

R/bart.R

Lines changed: 67 additions & 25 deletions
Large diffs are not rendered by default.

R/bcf.R

Lines changed: 135 additions & 53 deletions
Large diffs are not rendered by default.

R/config.R

Lines changed: 395 additions & 0 deletions
Large diffs are not rendered by default.

R/cpp11.R

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -556,12 +556,12 @@ compute_leaf_indices_cpp <- function(forest_container, covariates, forest_nums)
556556
.Call(`_stochtree_compute_leaf_indices_cpp`, forest_container, covariates, forest_nums)
557557
}
558558

559-
sample_gfr_one_iteration_cpp <- function(data, residual, forest_samples, active_forest, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest, pre_initialized) {
560-
invisible(.Call(`_stochtree_sample_gfr_one_iteration_cpp`, data, residual, forest_samples, active_forest, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest, pre_initialized))
559+
sample_gfr_one_iteration_cpp <- function(data, residual, forest_samples, active_forest, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest) {
560+
invisible(.Call(`_stochtree_sample_gfr_one_iteration_cpp`, data, residual, forest_samples, active_forest, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest))
561561
}
562562

563-
sample_mcmc_one_iteration_cpp <- function(data, residual, forest_samples, active_forest, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest, pre_initialized) {
564-
invisible(.Call(`_stochtree_sample_mcmc_one_iteration_cpp`, data, residual, forest_samples, active_forest, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest, pre_initialized))
563+
sample_mcmc_one_iteration_cpp <- function(data, residual, forest_samples, active_forest, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest) {
564+
invisible(.Call(`_stochtree_sample_mcmc_one_iteration_cpp`, data, residual, forest_samples, active_forest, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest))
565565
}
566566

567567
sample_sigma2_one_iteration_cpp <- function(residual, dataset, rng, a, b) {

R/forest.R

Lines changed: 44 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -558,6 +558,9 @@ Forest <- R6::R6Class(
558558
#' @field forest_ptr External pointer to a C++ TreeEnsemble class
559559
forest_ptr = NULL,
560560

561+
#' @field internal_forest_is_empty Whether the forest has not yet been "initialized" such that its `predict` function can be called.
562+
internal_forest_is_empty = TRUE,
563+
561564
#' @description
562565
#' Create a new Forest object.
563566
#' @param num_trees Number of trees in the forest
@@ -567,6 +570,7 @@ Forest <- R6::R6Class(
567570
#' @return A new `Forest` object.
568571
initialize = function(num_trees, leaf_dimension=1, is_leaf_constant=F, is_exponentiated=F) {
569572
self$forest_ptr <- active_forest_cpp(num_trees, leaf_dimension, is_leaf_constant, is_exponentiated)
573+
self$internal_forest_is_empty <- TRUE
570574
},
571575

572576
#' @description
@@ -610,6 +614,7 @@ Forest <- R6::R6Class(
610614
#' @param leaf_value Constant leaf value(s) to be fixed for each tree in the ensemble indexed by `forest_num`. Can be either a single number or a vector, depending on the forest's leaf dimension.
611615
set_root_leaves = function(leaf_value) {
612616
stopifnot(!is.null(self$forest_ptr))
617+
stopifnot(self$internal_forest_is_empty)
613618

614619
# Set leaf values
615620
if (length(leaf_value) == 1) {
@@ -621,6 +626,8 @@ Forest <- R6::R6Class(
621626
} else {
622627
stop("leaf_value must be a numeric value or vector of length >= 1")
623628
}
629+
630+
self$internal_forest_is_empty = FALSE
624631
},
625632

626633
#' @description
@@ -636,12 +643,15 @@ Forest <- R6::R6Class(
636643
stopifnot(!is.null(outcome$data_ptr))
637644
stopifnot(!is.null(forest_model$tracker_ptr))
638645
stopifnot(!is.null(self$forest_ptr))
646+
stopifnot(self$internal_forest_is_empty)
639647

640648
# Initialize the model
641649
initialize_forest_model_active_forest_cpp(
642650
dataset$data_ptr, outcome$data_ptr, self$forest_ptr,
643651
forest_model$tracker_ptr, leaf_value, leaf_model_int
644652
)
653+
654+
self$internal_forest_is_empty = FALSE
645655
},
646656

647657
#' @description
@@ -745,6 +755,23 @@ Forest <- R6::R6Class(
745755
#' @return Average maximum depth
746756
average_max_depth = function() {
747757
return(ensemble_average_max_depth_active_forest_cpp(self$forest_ptr))
758+
},
759+
760+
#' @description
761+
#' When a forest object is created, it is "empty" in the sense that none
762+
#' of its component trees have leaves with values. There are two ways to
763+
#' "initialize" a Forest object. First, the `set_root_leaves()` method
764+
#' simply initializes every tree in the forest to a single node carrying
765+
#' the same (user-specified) leaf value. Second, the `prepare_for_sampler()`
766+
#' method initializes every tree in the forest to a single node with the
767+
#' same value and also propagates this information through to a ForestModel
768+
#' object, which must be synchronized with a Forest during a forest
769+
#' sampler loop.
770+
#' @return `TRUE` if a Forest has not yet been initialized with a constant
771+
#' root value, `FALSE` otherwise if the forest has already been
772+
#' initialized / grown.
773+
is_empty = function() {
774+
return(self$internal_forest_is_empty)
748775
}
749776
)
750777
)
@@ -818,6 +845,7 @@ createForest <- function(num_trees, leaf_dimension=1, is_leaf_constant=F, is_exp
818845
resetActiveForest <- function(active_forest, forest_samples=NULL, forest_num=NULL) {
819846
if (is.null(forest_samples)) {
820847
root_reset_active_forest_cpp(active_forest$forest_ptr)
848+
active_forest$internal_forest_is_empty = TRUE
821849
} else {
822850
if (is.null(forest_num)) {
823851
stop("`forest_num` must be specified if `forest_samples` is provided")
@@ -860,14 +888,25 @@ resetActiveForest <- function(active_forest, forest_samples=NULL, forest_num=NUL
860888
#' y <- -5 + 10*(X[,1] > 0.5) + rnorm(n)
861889
#' outcome <- createOutcome(y)
862890
#' rng <- createCppRNG(1234)
863-
#' forest_model <- createForestModel(forest_dataset, feature_types, num_trees, n, alpha, beta, min_samples_leaf, max_depth)
891+
#' global_model_config <- createGlobalModelConfig(global_error_variance=sigma2)
892+
#' forest_model_config <- createForestModelConfig(feature_types=feature_types,
893+
#' num_trees=num_trees, num_observations=n,
894+
#' num_features=p, alpha=alpha, beta=beta,
895+
#' min_samples_leaf=min_samples_leaf,
896+
#' max_depth=max_depth,
897+
#' variable_weights=variable_weights,
898+
#' cutpoint_grid_size=cutpoint_grid_size,
899+
#' leaf_model_type=leaf_model,
900+
#' leaf_model_scale=leaf_scale)
901+
#' forest_model <- createForestModel(forest_dataset, forest_model_config, global_model_config)
864902
#' active_forest <- createForest(num_trees, leaf_dimension, is_leaf_constant, is_exponentiated)
865-
#' forest_samples <- createForestSamples(num_trees, leaf_dimension, is_leaf_constant, is_exponentiated)
903+
#' forest_samples <- createForestSamples(num_trees, leaf_dimension,
904+
#' is_leaf_constant, is_exponentiated)
905+
#' active_forest$prepare_for_sampler(forest_dataset, outcome, forest_model, 0, 0.)
866906
#' forest_model$sample_one_iteration(
867907
#' forest_dataset, outcome, forest_samples, active_forest,
868-
#' rng, feature_types, leaf_model, leaf_scale, variable_weights,
869-
#' a_forest, b_forest, sigma2, cutpoint_grid_size, keep_forest = TRUE,
870-
#' gfr = FALSE, pre_initialized = TRUE
908+
#' rng, forest_model_config, global_model_config,
909+
#' keep_forest = TRUE, gfr = FALSE
871910
#' )
872911
#' resetActiveForest(active_forest, forest_samples, 0)
873912
#' resetForestModel(forest_model, active_forest, forest_dataset, outcome, TRUE)

R/kernel.R

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,8 @@
4848
#' computeForestLeafIndices(bart_model, X, "mean", c(1,3,9))
4949
computeForestLeafIndices <- function(model_object, covariates, forest_type=NULL, forest_inds=NULL) {
5050
# Extract relevant forest container
51-
object_name <- class(model_object)[1]
52-
stopifnot(object_name %in% c("bartmodel", "bcfmodel", "ForestSamples"))
53-
model_type <- ifelse(object_name=="bartmodel", "bart", ifelse(object_name=="bcfmodel", "bcf", "forest_samples"))
51+
stopifnot(any(c(inherits(model_object, "bartmodel"), inherits(model_object, "bcfmodel"), inherits(model_object, "ForestSamples"))))
52+
model_type <- ifelse(inherits(model_object, "bartmodel"), "bart", ifelse(inherits(model_object, "bcfmodel"), "bcf", "forest_samples"))
5453
if (model_type == "bart") {
5554
stopifnot(forest_type %in% c("mean", "variance"))
5655
if (forest_type=="mean") {
@@ -143,8 +142,8 @@ computeForestLeafIndices <- function(model_object, covariates, forest_type=NULL,
143142
#' computeForestLeafVariances(bart_model, "mean", c(1,3,5))
144143
computeForestLeafVariances <- function(model_object, forest_type, forest_inds=NULL) {
145144
# Extract relevant forest container
146-
stopifnot(class(model_object) %in% c("bartmodel", "bcfmodel"))
147-
model_type <- ifelse(class(model_object)=="bartmodel", "bart", "bcf")
145+
stopifnot(any(c(inherits(model_object, "bartmodel"), inherits(model_object, "bcfmodel"))))
146+
model_type <- ifelse(inherits(model_object, "bartmodel"), "bart", "bcf")
148147
if (model_type == "bart") {
149148
stopifnot(forest_type %in% c("mean", "variance"))
150149
if (forest_type=="mean") {
@@ -234,9 +233,8 @@ computeForestLeafVariances <- function(model_object, forest_type, forest_inds=NU
234233
#' computeForestMaxLeafIndex(bart_model, X, "mean", c(1,3,9))
235234
computeForestMaxLeafIndex <- function(model_object, covariates, forest_type=NULL, forest_inds=NULL) {
236235
# Extract relevant forest container
237-
object_name <- class(model_object)[1]
238-
stopifnot(object_name %in% c("bartmodel", "bcfmodel", "ForestSamples"))
239-
model_type <- ifelse(object_name=="bartmodel", "bart", ifelse(object_name=="bcfmodel", "bcf", "forest_samples"))
236+
stopifnot(any(c(inherits(model_object, "bartmodel"), inherits(model_object, "bcfmodel"), inherits(model_object, "ForestSamples"))))
237+
model_type <- ifelse(inherits(model_object, "bartmodel"), "bart", ifelse(inherits(model_object, "bcfmodel"), "bcf", "forest_samples"))
240238
if (model_type == "bart") {
241239
stopifnot(forest_type %in% c("mean", "variance"))
242240
if (forest_type=="mean") {

0 commit comments

Comments
 (0)