Skip to content

Commit d0100f2

Browse files
committed
Implemented and tested higher-level "collapse" functionality for forest containers in R
1 parent c4b3f62 commit d0100f2

File tree

3 files changed

+205
-4
lines changed

3 files changed

+205
-4
lines changed

R/forest.R

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,42 @@ ForestSamples <- R6::R6Class(
2323
},
2424

2525
#' @description
26-
#' Collapse specified forests into a single forest
26+
#' Collapse forests in this container by a pre-specified batch size.
27+
#' For example, if we have a container of twenty 10-tree forests, and we
28+
#' specify a `batch_size` of 5, then this method will yield four 50-tree
29+
#' forests. "Excess" forests remaining after the size of a forest container
30+
#' is divided by `batch_size` will be pruned from the beginning of the
31+
#' container (i.e. earlier sampled forests will be deleted). This method
32+
#' has no effect if `batch_size` is larger than the number of forests
33+
#' in a container, or if `batch_size` is 1 or less.
34+
#' @param batch_size Number of forests to be collapsed into a single forest
35+
collapse = function(batch_size) {
  # Collapse the container in place: every run of `batch_size` consecutive
  # forests is merged into one forest and rescaled by 1 / batch_size, and
  # any leftover forests at the start of the container are deleted.
  #
  # batch_size: a single whole number; values of 1 or less, or larger than
  #   the container, leave the container untouched.

  # Fail early with a clear message instead of producing fractional or
  # vector-valued indices further down.
  stopifnot(
    is.numeric(batch_size),
    length(batch_size) == 1,
    !is.na(batch_size),
    batch_size == round(batch_size)
  )

  container_size <- self$num_samples()
  if ((batch_size <= container_size) && (batch_size > 1)) {
    # 1-based positions, newest forest first
    reverse_container_inds <- seq(container_size, 1, -1)

    # Forests left over once the container is cut into full batches. This
    # must be the remainder with respect to `batch_size`; taking it with
    # respect to the number of batches only coincides for some sizes and
    # otherwise leaves a short (possibly single-forest) trailing batch.
    num_excess <- container_size %% batch_size

    # Shift positions so the `num_excess` earliest forests get a negative
    # batch index and every non-negative index labels exactly `batch_size`
    # consecutive forests.
    batch_inds <- (reverse_container_inds - num_excess - 1) %/% batch_size

    # `batch_inds` is descending, so batches are handled newest first;
    # presumably this keeps the positions of earlier batches valid while
    # samples are deleted (assumes `delete_sample` compacts the container).
    for (batch_ind in unique(batch_inds[batch_inds >= 0])) {
      # 0-based positions of the forests in this batch, ascending
      merge_forest_inds <- sort(
        reverse_container_inds[batch_inds == batch_ind] - 1
      )
      num_merge_forests <- length(merge_forest_inds)
      self$combine_forests(merge_forest_inds)

      # Keep only the first forest of the batch, dropping the others from
      # the highest position down. `num_merge_forests` equals `batch_size`
      # (>= 2) here, so the descending sequence is well defined.
      for (i in num_merge_forests:2) {
        self$delete_sample(merge_forest_inds[i])
      }

      # Rescale the retained forest by the batch size
      forest_scale_factor <- 1.0 / num_merge_forests
      self$multiply_forest(merge_forest_inds[1], forest_scale_factor)
    }

    # Prune the excess forests at the beginning of the container, again
    # from the highest position down.
    if (num_excess > 0) {
      delete_forest_inds <- sort(reverse_container_inds[batch_inds < 0] - 1)
      for (i in rev(seq_along(delete_forest_inds))) {
        self$delete_sample(delete_forest_inds[i])
      }
    }
  }
},
59+
60+
#' @description
61+
#' Merge specified forests into a single forest
2762
#' @param forest_inds Indices of forests to be combined (0-indexed)
2863
combine_forests = function(forest_inds) {
2964
stopifnot(max(forest_inds) < self$num_samples())

man/ForestSamples.Rd

Lines changed: 28 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

test/R/testthat/test-forest-container.R

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,3 +108,144 @@ test_that("Univariate constant forest container", {
108108
# Assertion
109109
expect_equal(pred, pred_expected_new)
110110
})
111+
112+
test_that("Collapse forests", {
  skip_on_cran()

  # Generate simulated data: a step function of the first covariate plus
  # Gaussian noise
  n <- 100
  p <- 5
  X <- matrix(runif(n * p), ncol = p)
  f_XW <- (
    ((0 <= X[, 1]) & (0.25 > X[, 1])) * (-7.5) +
      ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (-2.5) +
      ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (2.5) +
      ((0.75 <= X[, 1]) & (1 > X[, 1])) * (7.5)
  )
  noise_sd <- 1
  y <- f_XW + rnorm(n, 0, noise_sd)

  # Train / test split
  test_set_pct <- 0.2
  n_test <- round(test_set_pct * n)
  test_inds <- sort(sample(seq_len(n), n_test, replace = FALSE))
  train_inds <- setdiff(seq_len(n), test_inds)
  X_test <- X[test_inds, ]
  X_train <- X[train_inds, ]
  y_train <- y[train_inds]

  # Create forest dataset
  forest_dataset_test <- createForestDataset(covariates = X_test)

  # Fit BART with `num_mcmc` retained forests, collapse the mean forest
  # container in batches of `batch_size`, and check that the collapsed
  # predictions equal the per-batch averages of the original predictions.
  expect_collapse_matches <- function(num_mcmc, batch_size) {
    general_param_list <- list(num_chains = 1, keep_every = 1)
    bart_model <- bart(
      X_train = X_train, y_train = y_train, X_test = X_test,
      num_gfr = 0, num_burnin = 0, num_mcmc = num_mcmc,
      general_params = general_param_list
    )

    # Extract the mean forest container
    mean_forest_container <- bart_model$mean_forests

    # Predict from the original container
    pred_orig <- mean_forest_container$predict(forest_dataset_test)

    # Collapse the container and predict from the modified container
    mean_forest_container$collapse(batch_size)
    pred_new <- mean_forest_container$predict(forest_dataset_test)

    if (batch_size > num_mcmc) {
      # Collapsing is documented as a no-op when the batch does not fit
      pred_expected <- pred_orig
    } else {
      # The earliest `num_mcmc %% batch_size` samples are expected to be
      # pruned (batch index <= 0); the rest fall into consecutive batches
      # numbered from 1.
      num_excess <- num_mcmc %% batch_size
      batch_inds <- (seq_len(num_mcmc) - num_excess - 1) %/% batch_size + 1
      num_batches <- max(batch_inds)
      pred_expected <- matrix(
        NA_real_,
        nrow = nrow(pred_orig), ncol = num_batches
      )
      for (i in seq_len(num_batches)) {
        in_batch <- batch_inds == i
        pred_expected[, i] <-
          rowSums(pred_orig[, in_batch, drop = FALSE]) / sum(in_batch)
      }
    }

    # Assertion
    expect_equal(pred_expected, pred_new)
  }

  # 50 forests: divides evenly into ten batches of 5
  expect_collapse_matches(num_mcmc = 50, batch_size = 5)

  # 52 forests: ten batches of 5 with the two earliest forests pruned
  expect_collapse_matches(num_mcmc = 52, batch_size = 5)

  # 5 forests: a single batch containing the whole container
  expect_collapse_matches(num_mcmc = 5, batch_size = 5)

  # 4 forests: batch size exceeds the container, so nothing changes
  expect_collapse_matches(num_mcmc = 4, batch_size = 5)
})

0 commit comments

Comments
 (0)