Refactor model parameters into "config" objects to future-proof low-level interface #135

Merged · 21 commits · Feb 5, 2025
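For orientation, a minimal sketch of the calling pattern this PR introduces, assembled from the updated roxygen example in the `R/forest.R` diff below. The data objects (`forest_dataset`, `outcome`, `rng`, `active_forest`, `forest_samples`) and the scalar parameters are placeholders, not a complete runnable script.

```r
# Bundle global and forest-level parameters into config objects
global_model_config <- createGlobalModelConfig(global_error_variance = sigma2)
forest_model_config <- createForestModelConfig(
    feature_types = feature_types, num_trees = num_trees, num_observations = n,
    num_features = p, alpha = alpha, beta = beta,
    min_samples_leaf = min_samples_leaf, max_depth = max_depth,
    variable_weights = variable_weights, cutpoint_grid_size = cutpoint_grid_size,
    leaf_model_type = leaf_model, leaf_model_scale = leaf_scale
)

# The sampler is constructed from, and driven by, the config objects
forest_model <- createForestModel(forest_dataset, forest_model_config, global_model_config)
forest_model$sample_one_iteration(
    forest_dataset, outcome, forest_samples, active_forest,
    rng, forest_model_config, global_model_config,
    keep_forest = TRUE, gfr = FALSE
)
```

Parameters that previously traveled to `sample_one_iteration()` as a long positional list (`feature_types`, `leaf_model`, `leaf_scale`, `variable_weights`, `a_forest`, `b_forest`, `sigma2`, `cutpoint_grid_size`, `pre_initialized`) are now carried by the two config objects, so new sampler options can be added without breaking the low-level interface.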
2 changes: 1 addition & 1 deletion .github/workflows/r-test.yml
@@ -39,7 +39,7 @@ jobs:

- name: Create a CRAN-ready version of the R package
run: |
Rscript cran-bootstrap.R 0 0
Rscript cran-bootstrap.R 0 0 1

- uses: r-lib/actions/check-r-package@v2
with:
8 changes: 6 additions & 2 deletions DESCRIPTION
@@ -1,5 +1,5 @@
Package: stochtree
Title: Stochastic tree ensembles (XBART and BART) for supervised learning and causal inference
Title: Stochastic Tree Ensembles (XBART and BART) for Supervised Learning and Causal Inference
Version: 0.1.0
Authors@R:
c(
@@ -10,7 +10,11 @@ Authors@R:
person("Jingyu", "He", role = "aut"),
person("stochtree contributors", role = c("cph"))
)
Description: Stochastic tree ensembles (XBART and BART) for supervised learning and causal inference.
Description: Flexible stochastic tree ensemble software. Robust implementations of
Bayesian Additive Regression Trees (Chipman, George, McCulloch (2010) <doi:10.1214/09-AOAS285>)
for supervised learning and (Bayesian Causal Forests (BCF) Hahn, Murray, Carvalho (2020) <doi:10.1214/19-BA1195>)
for causal inference. Enables model serialization and parallel sampling
and provides a low-level interface for custom stochastic forest samplers.
License: MIT + file LICENSE
Encoding: UTF-8
Roxygen: list(markdown = TRUE)
4 changes: 2 additions & 2 deletions LICENSE
@@ -1,2 +1,2 @@
YEAR: 2024
COPYRIGHT HOLDER: stochtree authors
YEAR: 2025
COPYRIGHT HOLDER: stochtree contributors
2 changes: 1 addition & 1 deletion LICENSE.md
@@ -1,6 +1,6 @@
# MIT License

Copyright (c) 2024 stochtree authors
Copyright (c) 2023-2025 stochtree authors

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
2 changes: 2 additions & 0 deletions NAMESPACE
@@ -28,7 +28,9 @@ export(createCppRNG)
export(createForest)
export(createForestDataset)
export(createForestModel)
export(createForestModelConfig)
export(createForestSamples)
export(createGlobalModelConfig)
export(createOutcome)
export(createPreprocessorFromJson)
export(createPreprocessorFromJsonString)
15 changes: 14 additions & 1 deletion NEWS.md
@@ -1,3 +1,16 @@
# stochtree 0.1.0

* Initial CRAN submission.
* Initial release on CRAN.
* Support for sampling stochastic tree ensembles using two algorithms: MCMC and Grow-From-Root (GFR)
* High-level model types supported:
* Supervised learning with constant leaves or user-specified leaf regression models
* Causal effect estimation with binary, continuous, or multivariate treatments
* Additional high-level modeling features:
* Forest-based variance function estimation (heteroskedasticity)
* Additive (univariate or multivariate) group random effects
* Multi-chain sampling and support for parallelism
* "Warm-start" initialization of MCMC forest samplers via the Grow-From-Root (GFR) algorithm
* Automated preprocessing / handling of categorical variables
* Low-level interface:
* Ability to combine a forest sampler with other (additive) model terms, without using C++
* Combine and sample an arbitrary number of forests or random effects terms
92 changes: 67 additions & 25 deletions R/bart.R

Large diffs are not rendered by default.

188 changes: 135 additions & 53 deletions R/bcf.R

Large diffs are not rendered by default.

395 changes: 395 additions & 0 deletions R/config.R

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions R/cpp11.R
@@ -556,12 +556,12 @@ compute_leaf_indices_cpp <- function(forest_container, covariates, forest_nums)
.Call(`_stochtree_compute_leaf_indices_cpp`, forest_container, covariates, forest_nums)
}

sample_gfr_one_iteration_cpp <- function(data, residual, forest_samples, active_forest, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest, pre_initialized) {
invisible(.Call(`_stochtree_sample_gfr_one_iteration_cpp`, data, residual, forest_samples, active_forest, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest, pre_initialized))
sample_gfr_one_iteration_cpp <- function(data, residual, forest_samples, active_forest, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest) {
invisible(.Call(`_stochtree_sample_gfr_one_iteration_cpp`, data, residual, forest_samples, active_forest, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest))
}

sample_mcmc_one_iteration_cpp <- function(data, residual, forest_samples, active_forest, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest, pre_initialized) {
invisible(.Call(`_stochtree_sample_mcmc_one_iteration_cpp`, data, residual, forest_samples, active_forest, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest, pre_initialized))
sample_mcmc_one_iteration_cpp <- function(data, residual, forest_samples, active_forest, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest) {
invisible(.Call(`_stochtree_sample_mcmc_one_iteration_cpp`, data, residual, forest_samples, active_forest, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest))
}

sample_sigma2_one_iteration_cpp <- function(residual, dataset, rng, a, b) {
49 changes: 44 additions & 5 deletions R/forest.R
@@ -558,6 +558,9 @@ Forest <- R6::R6Class(
#' @field forest_ptr External pointer to a C++ TreeEnsemble class
forest_ptr = NULL,

#' @field internal_forest_is_empty Whether the forest has not yet been "initialized" such that its `predict` function can be called.
internal_forest_is_empty = TRUE,

#' @description
#' Create a new Forest object.
#' @param num_trees Number of trees in the forest
@@ -567,6 +570,7 @@ Forest <- R6::R6Class(
#' @return A new `Forest` object.
initialize = function(num_trees, leaf_dimension=1, is_leaf_constant=F, is_exponentiated=F) {
self$forest_ptr <- active_forest_cpp(num_trees, leaf_dimension, is_leaf_constant, is_exponentiated)
self$internal_forest_is_empty <- TRUE
},

#' @description
@@ -610,6 +614,7 @@ Forest <- R6::R6Class(
#' @param leaf_value Constant leaf value(s) to be fixed for each tree in the ensemble indexed by `forest_num`. Can be either a single number or a vector, depending on the forest's leaf dimension.
set_root_leaves = function(leaf_value) {
stopifnot(!is.null(self$forest_ptr))
stopifnot(self$internal_forest_is_empty)

# Set leaf values
if (length(leaf_value) == 1) {
@@ -621,6 +626,8 @@ Forest <- R6::R6Class(
} else {
stop("leaf_value must be a numeric value or vector of length >= 1")
}

self$internal_forest_is_empty = FALSE
},

#' @description
@@ -636,12 +643,15 @@ Forest <- R6::R6Class(
stopifnot(!is.null(outcome$data_ptr))
stopifnot(!is.null(forest_model$tracker_ptr))
stopifnot(!is.null(self$forest_ptr))
stopifnot(self$internal_forest_is_empty)

# Initialize the model
initialize_forest_model_active_forest_cpp(
dataset$data_ptr, outcome$data_ptr, self$forest_ptr,
forest_model$tracker_ptr, leaf_value, leaf_model_int
)

self$internal_forest_is_empty = FALSE
},

#' @description
@@ -745,6 +755,23 @@ Forest <- R6::R6Class(
#' @return Average maximum depth
average_max_depth = function() {
return(ensemble_average_max_depth_active_forest_cpp(self$forest_ptr))
},

#' @description
#' When a forest object is created, it is "empty" in the sense that none
#' of its component trees have leaves with values. There are two ways to
#' "initialize" a Forest object. First, the `set_root_leaves()` method
#' simply initializes every tree in the forest to a single node carrying
#' the same (user-specified) leaf value. Second, the `prepare_for_sampler()`
#' method initializes every tree in the forest to a single node with the
#' same value and also propagates this information through to a ForestModel
#' object, which must be synchronized with a Forest during a forest
#' sampler loop.
#' @return `TRUE` if a Forest has not yet been initialized with a constant
#' root value, `FALSE` otherwise if the forest has already been
#' initialized / grown.
is_empty = function() {
return(self$internal_forest_is_empty)
}
)
)
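A short sketch of the new `is_empty()` flag's lifecycle, using only calls that appear in this diff (assumes the `stochtree` package is installed from this branch and attached):

```r
library(stochtree)

f <- createForest(num_trees = 50, leaf_dimension = 1, is_leaf_constant = TRUE)
f$is_empty()           # TRUE: trees exist, but no leaf carries a value yet
f$set_root_leaves(0.)  # every tree becomes a single root node with value 0
f$is_empty()           # FALSE
# A second set_root_leaves() call would now fail the
# stopifnot(self$internal_forest_is_empty) guard added above.

resetActiveForest(f)   # root-reset without a forest container
f$is_empty()           # TRUE again
```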
@@ -818,6 +845,7 @@ createForest <- function(num_trees, leaf_dimension=1, is_leaf_constant=F, is_exp
resetActiveForest <- function(active_forest, forest_samples=NULL, forest_num=NULL) {
if (is.null(forest_samples)) {
root_reset_active_forest_cpp(active_forest$forest_ptr)
active_forest$internal_forest_is_empty = TRUE
} else {
if (is.null(forest_num)) {
stop("`forest_num` must be specified if `forest_samples` is provided")
@@ -860,14 +888,25 @@ resetActiveForest <- function(active_forest, forest_samples=NULL, forest_num=NUL
#' y <- -5 + 10*(X[,1] > 0.5) + rnorm(n)
#' outcome <- createOutcome(y)
#' rng <- createCppRNG(1234)
#' forest_model <- createForestModel(forest_dataset, feature_types, num_trees, n, alpha, beta, min_samples_leaf, max_depth)
#' global_model_config <- createGlobalModelConfig(global_error_variance=sigma2)
#' forest_model_config <- createForestModelConfig(feature_types=feature_types,
#' num_trees=num_trees, num_observations=n,
#' num_features=p, alpha=alpha, beta=beta,
#' min_samples_leaf=min_samples_leaf,
#' max_depth=max_depth,
#' variable_weights=variable_weights,
#' cutpoint_grid_size=cutpoint_grid_size,
#' leaf_model_type=leaf_model,
#' leaf_model_scale=leaf_scale)
#' forest_model <- createForestModel(forest_dataset, forest_model_config, global_model_config)
#' active_forest <- createForest(num_trees, leaf_dimension, is_leaf_constant, is_exponentiated)
#' forest_samples <- createForestSamples(num_trees, leaf_dimension, is_leaf_constant, is_exponentiated)
#' forest_samples <- createForestSamples(num_trees, leaf_dimension,
#' is_leaf_constant, is_exponentiated)
#' active_forest$prepare_for_sampler(forest_dataset, outcome, forest_model, 0, 0.)
#' forest_model$sample_one_iteration(
#' forest_dataset, outcome, forest_samples, active_forest,
#' rng, feature_types, leaf_model, leaf_scale, variable_weights,
#' a_forest, b_forest, sigma2, cutpoint_grid_size, keep_forest = TRUE,
#' gfr = FALSE, pre_initialized = TRUE
#' rng, forest_model_config, global_model_config,
#' keep_forest = TRUE, gfr = FALSE
#' )
#' resetActiveForest(active_forest, forest_samples, 0)
#' resetForestModel(forest_model, active_forest, forest_dataset, outcome, TRUE)
14 changes: 6 additions & 8 deletions R/kernel.R
@@ -48,9 +48,8 @@
#' computeForestLeafIndices(bart_model, X, "mean", c(1,3,9))
computeForestLeafIndices <- function(model_object, covariates, forest_type=NULL, forest_inds=NULL) {
# Extract relevant forest container
object_name <- class(model_object)[1]
stopifnot(object_name %in% c("bartmodel", "bcfmodel", "ForestSamples"))
model_type <- ifelse(object_name=="bartmodel", "bart", ifelse(object_name=="bcfmodel", "bcf", "forest_samples"))
stopifnot(any(c(inherits(model_object, "bartmodel"), inherits(model_object, "bcfmodel"), inherits(model_object, "ForestSamples"))))
model_type <- ifelse(inherits(model_object, "bartmodel"), "bart", ifelse(inherits(model_object, "bcfmodel"), "bcf", "forest_samples"))
if (model_type == "bart") {
stopifnot(forest_type %in% c("mean", "variance"))
if (forest_type=="mean") {
@@ -143,8 +142,8 @@ computeForestLeafIndices <- function(model_object, covariates, forest_type=NULL,
#' computeForestLeafVariances(bart_model, "mean", c(1,3,5))
computeForestLeafVariances <- function(model_object, forest_type, forest_inds=NULL) {
# Extract relevant forest container
stopifnot(class(model_object) %in% c("bartmodel", "bcfmodel"))
model_type <- ifelse(class(model_object)=="bartmodel", "bart", "bcf")
stopifnot(any(c(inherits(model_object, "bartmodel"), inherits(model_object, "bcfmodel"))))
model_type <- ifelse(inherits(model_object, "bartmodel"), "bart", "bcf")
if (model_type == "bart") {
stopifnot(forest_type %in% c("mean", "variance"))
if (forest_type=="mean") {
@@ -234,9 +233,8 @@ computeForestLeafVariances <- function(model_object, forest_type, forest_inds=NU
#' computeForestMaxLeafIndex(bart_model, X, "mean", c(1,3,9))
computeForestMaxLeafIndex <- function(model_object, covariates, forest_type=NULL, forest_inds=NULL) {
# Extract relevant forest container
object_name <- class(model_object)[1]
stopifnot(object_name %in% c("bartmodel", "bcfmodel", "ForestSamples"))
model_type <- ifelse(object_name=="bartmodel", "bart", ifelse(object_name=="bcfmodel", "bcf", "forest_samples"))
stopifnot(any(c(inherits(model_object, "bartmodel"), inherits(model_object, "bcfmodel"), inherits(model_object, "ForestSamples"))))
model_type <- ifelse(inherits(model_object, "bartmodel"), "bart", ifelse(inherits(model_object, "bcfmodel"), "bcf", "forest_samples"))
if (model_type == "bart") {
stopifnot(forest_type %in% c("mean", "variance"))
if (forest_type=="mean") {
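The `R/kernel.R` changes above replace direct `class()` comparisons with `inherits()`, which makes the type checks robust to objects that carry more than one S3 class. A small base-R illustration (the class names here are hypothetical, chosen only to mirror the checks above):

```r
# An object with two S3 classes attached
obj <- structure(list(), class = c("bartmodel", "stochtree_model"))

class(obj) == "bartmodel"   # element-wise comparison: TRUE FALSE (length 2)
inherits(obj, "bartmodel")  # single TRUE whenever "bartmodel" appears anywhere
```

`inherits()` always returns a single logical, whereas comparing `class()` directly yields a vector as soon as extra classes are attached, which changes how the surrounding `stopifnot()` and `ifelse()` calls behave.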