Skip to content

Commit a22ca1c

Browse files
committed
Added same collapse functionality and tests to python interface
1 parent d0100f2 commit a22ca1c

File tree

2 files changed

+216
-2
lines changed

2 files changed

+216
-2
lines changed

stochtree/forest.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,42 @@ def set_root_leaves(
159159
else:
160160
self.forest_container_cpp.SetRootValue(forest_num, leaf_value)
161161

162+
def collapse(self, batch_size: int) -> None:
    """
    Collapse forests in this container by a pre-specified batch size.

    Consecutive groups of `batch_size` forests are merged into one forest
    and rescaled by `1 / batch_size`, so each retained forest is the average
    of the forests in its group. For example, if we have a container of
    twenty 10-tree forests, and we specify a `batch_size` of 5, then this
    method will yield four 50-tree forests.

    "Excess" forests remaining after the size of the container is divided
    by `batch_size` (i.e. `num_samples() % batch_size` forests) are pruned
    from the beginning of the container, so the earliest sampled forests
    are the ones deleted. For example, 11 forests with a `batch_size` of 4
    yields 2 forests and discards the first 3.

    This method has no effect if `batch_size` is larger than the number of
    forests in the container, or if `batch_size` is not greater than 1.

    Parameters
    ----------
    batch_size : int
        Number of forests to be collapsed into a single forest
    """
    container_size = self.num_samples()
    if batch_size <= 1 or batch_size > container_size:
        return

    # Forests that do not fill a complete batch; these sit at the front
    # of the container and are removed after the merges.
    num_excess_forests = container_size % batch_size
    forest_scale_factor = 1.0 / batch_size

    # Walk the batches from the back of the container to the front so that
    # deleting merged-away forests never shifts the indices of batches
    # that have not been processed yet.
    last_batch_start = container_size - batch_size
    for batch_start in range(last_batch_start, num_excess_forests - 1, -batch_size):
        merge_forest_inds = np.arange(batch_start, batch_start + batch_size, dtype=int)
        # The combined forest is kept at the lowest index of the batch
        self.combine_forests(merge_forest_inds)
        # Remove the now-redundant forests, highest index first
        for forest_ind in range(batch_start + batch_size - 1, batch_start, -1):
            self.delete_sample(forest_ind)
        # Convert the summed forest into an average
        self.multiply_forest(batch_start, forest_scale_factor)

    # Prune the excess forests at the beginning, highest index first
    for forest_ind in range(num_excess_forests - 1, -1, -1):
        self.delete_sample(forest_ind)
197+
162198
def combine_forests(
163199
self, forest_inds: np.array
164200
) -> None:

test/python/test_forest_container.py

Lines changed: 180 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
import numpy as np
22

3-
from stochtree import Dataset, ForestContainer
3+
from stochtree import Dataset, ForestContainer, BARTModel
4+
from sklearn.model_selection import train_test_split
45

56

67
class TestPredict:
7-
def test_constant_leaf_prediction(self):
8+
def test_constant_leaf_forest_container(self):
89
# Create dataset
910
X = np.array(
1011
[[1.5, 8.7, 1.2],
@@ -75,3 +76,180 @@ def test_constant_leaf_prediction(self):
7576

7677
# Assertion
7778
np.testing.assert_almost_equal(pred, pred_expected_new)
79+
80+
def test_collapse_forest_container(self):
    """
    Check that collapsing a forest container averages predictions by batch.

    A BART model is sampled with several different numbers of MCMC draws,
    and its mean forest container is collapsed with a batch size of 5.
    Predictions from the collapsed container must match the per-batch
    averages of the original predictions, where any draws that do not fill
    a complete batch are dropped from the beginning. The sample counts
    cover an exact multiple of the batch size (50), a remainder (52),
    a single batch (5) and fewer draws than one batch (4), in which case
    the container must be left untouched.
    """
    # RNG
    rng = np.random.default_rng()

    # Generate covariates and basis
    n = 100
    p_X = 10
    X = rng.uniform(0, 1, (n, p_X))

    # Define the outcome mean function
    def outcome_mean(X):
        return np.where(
            (X[:, 0] >= 0.0) & (X[:, 0] < 0.25),
            -7.5,
            np.where(
                (X[:, 0] >= 0.25) & (X[:, 0] < 0.5),
                -2.5,
                np.where((X[:, 0] >= 0.5) & (X[:, 0] < 0.75), 2.5, 7.5),
            ),
        )

    # Generate outcome
    epsilon = rng.normal(0, 1, n)
    y = outcome_mean(X) + epsilon

    # Test-train split
    sample_inds = np.arange(n)
    train_inds, test_inds = train_test_split(sample_inds, test_size=0.5)
    X_train = X[train_inds, :]
    X_test = X[test_inds, :]
    y_train = y[train_inds]
    n_test = X_test.shape[0]

    # Create forest dataset
    forest_dataset_test = Dataset()
    forest_dataset_test.add_covariates(X_test)

    # Every scenario collapses the container in batches of 5
    batch_size = 5
    for num_mcmc in (50, 52, 5, 4):
        # Run BART with `num_mcmc` MCMC draws
        bart_model = BARTModel()
        bart_model.sample(
            X_train=X_train,
            y_train=y_train,
            X_test=X_test,
            num_gfr=0,
            num_burnin=0,
            num_mcmc=num_mcmc,
        )

        # Extract the mean forest container
        mean_forest_container = bart_model.forest_container_mean

        # Predict from the original container
        pred_orig = mean_forest_container.predict(forest_dataset_test)

        # Collapse the container and predict from the modified container
        mean_forest_container.collapse(batch_size)
        pred_new = mean_forest_container.predict(forest_dataset_test)

        if batch_size > num_mcmc:
            # Too few forests to form one batch: nothing should change
            pred_orig_collapsed = pred_orig
        else:
            # Drop the excess draws at the beginning, then average each
            # run of `batch_size` consecutive prediction columns
            num_batches = num_mcmc // batch_size
            num_excess = num_mcmc % batch_size
            pred_kept = pred_orig[:, num_excess:]
            pred_orig_collapsed = (
                pred_kept.reshape(n_test, num_batches, batch_size).sum(axis=2)
                / batch_size
            )

        # Assertion
        np.testing.assert_almost_equal(
            pred_orig_collapsed,
            pred_new,
            err_msg=f"collapse mismatch for num_mcmc={num_mcmc}",
        )

0 commit comments

Comments
 (0)