@@ -108,3 +108,144 @@ test_that("Univariate constant forest container", {
108
108
# Assertion
109
109
expect_equal(pred , pred_expected_new )
110
110
})
111
+
112
test_that("Collapse forests", {
  skip_on_cran()

  # Generate simulated data: the outcome is a four-level step function of the
  # first covariate plus Gaussian noise.
  n <- 100
  p <- 5
  X <- matrix(runif(n * p), ncol = p)
  f_XW <- (
    ((0 <= X[, 1]) & (0.25 > X[, 1])) * (-7.5) +
      ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (-2.5) +
      ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (2.5) +
      ((0.75 <= X[, 1]) & (1 > X[, 1])) * (7.5)
  )
  noise_sd <- 1
  y <- f_XW + rnorm(n, 0, noise_sd)

  # Hold out 20% of the observations as a test set
  test_set_pct <- 0.2
  n_test <- round(test_set_pct * n)
  test_inds <- sort(sample(seq_len(n), n_test, replace = FALSE))
  train_inds <- setdiff(seq_len(n), test_inds)
  X_test <- X[test_inds, ]
  X_train <- X[train_inds, ]
  y_train <- y[train_inds]

  # Create forest dataset
  forest_dataset_test <- createForestDataset(covariates = X_test)

  # Settings shared by every scenario below
  batch_size <- 5
  general_param_list <- list(num_chains = 1, keep_every = 1)

  # Build the predictions this test expects from a collapsed container.
  #
  # `pred` holds one column of predictions per retained forest (`num_mcmc`
  # columns in total). The expectation encoded here is:
  #   * fewer forests than `batch_size`: predictions are unchanged;
  #   * otherwise: the `num_mcmc %% batch_size` earliest columns are dropped
  #     and each consecutive run of `batch_size` columns is averaged.
  expected_collapsed_predictions <- function(pred, num_mcmc, batch_size) {
    if (num_mcmc < batch_size) {
      return(pred)
    }
    num_excess <- num_mcmc %% batch_size
    num_batches <- num_mcmc %/% batch_size
    # Leading excess columns receive batch index 0 and are never visited by
    # the loop below, which is how they are excluded from the expectation.
    batch_inds <- (seq_len(num_mcmc) - num_excess - 1) %/% batch_size + 1
    collapsed <- matrix(NA_real_, nrow = nrow(pred), ncol = num_batches)
    for (i in seq_len(num_batches)) {
      in_batch <- batch_inds == i
      collapsed[, i] <- rowSums(pred[, in_batch, drop = FALSE]) / sum(in_batch)
    }
    collapsed
  }

  # Scenarios, in order: an exact multiple of the batch size (50), a
  # non-multiple with excess draws (52), exactly one batch (5), and fewer
  # draws than one batch (4).
  for (num_mcmc in c(50, 52, 5, 4)) {
    # Run BART for `num_mcmc` iterations
    bart_model <- bart(
      X_train = X_train, y_train = y_train, X_test = X_test,
      num_gfr = 0, num_burnin = 0, num_mcmc = num_mcmc,
      general_params = general_param_list
    )

    # Extract the mean forest container
    mean_forest_container <- bart_model$mean_forests

    # Predict from the original container. This must happen before the
    # collapse() call, which is invoked for its effect on the container.
    pred_orig <- mean_forest_container$predict(forest_dataset_test)

    # Collapse the container in batches of `batch_size`
    mean_forest_container$collapse(batch_size)

    # Predict from the modified container
    pred_new <- mean_forest_container$predict(forest_dataset_test)

    # Check that corresponding (averages of) predictions match
    pred_orig_collapsed <- expected_collapsed_predictions(
      pred_orig, num_mcmc, batch_size
    )

    # Assertion
    expect_equal(
      pred_orig_collapsed, pred_new,
      info = paste0("num_mcmc = ", num_mcmc, ", batch_size = ", batch_size)
    )
  }
})
0 commit comments