Skip to content

Commit ff1bfe0

Browse files
authored
Merge pull request #157 from StochasticTree/python-update-0.1.1
Prepare next PyPI release
2 parents 848f61c + daa9243 commit ff1bfe0

File tree

11 files changed

+1750
-182
lines changed

11 files changed

+1750
-182
lines changed

Doxyfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ PROJECT_NAME = "StochTree"
4848
# could be handy for archiving the generated documentation or if some version
4949
# control system is used.
5050

51-
PROJECT_NUMBER = 0.0.1
51+
PROJECT_NUMBER = 0.1.1
5252

5353
# Using the PROJECT_BRIEF tag one can provide an optional one line description
5454
# for a project that appears at the top of each page and should give viewer a

R/bart.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
206206
if (previous_bart_model$model_params$include_mean_forest) {
207207
previous_forest_samples_mean <- previous_bart_model$mean_forests
208208
} else previous_forest_samples_mean <- NULL
209-
if (previous_bart_model$model_params$include_mean_forest) {
209+
if (previous_bart_model$model_params$include_variance_forest) {
210210
previous_forest_samples_variance <- previous_bart_model$variance_forests
211211
} else previous_forest_samples_variance <- NULL
212212
if (previous_bart_model$model_params$sample_sigma_global) {
@@ -1853,7 +1853,7 @@ createBARTModelFromCombinedJsonString <- function(json_string_list){
18531853
}
18541854

18551855
# Unpack covariate preprocessor
1856-
preprocessor_metadata_string <- json_object$get_string("preprocessor_metadata")
1856+
preprocessor_metadata_string <- json_object_default$get_string("preprocessor_metadata")
18571857
output[["train_set_metadata"]] <- createPreprocessorFromJsonString(
18581858
preprocessor_metadata_string
18591859
)

demo/debug/multi_chain.py

Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
# Multi Chain Demo Script
2+
3+
# Load necessary libraries
4+
import matplotlib.pyplot as plt
5+
import numpy as np
6+
import pandas as pd
7+
import seaborn as sns
8+
from sklearn.model_selection import train_test_split
9+
10+
from stochtree import BARTModel
11+
12+
# Simulate the covariates and the leaf basis
# Seeded generator so the simulated data are reproducible across runs
random_seed = 1234
rng = np.random.default_rng(random_seed)

# Problem dimensions: observations, covariate columns, basis columns
n, p_X, p_W = 500, 10, 1

# Draw the covariates first and the basis second, both uniform on [0, 1)
X = rng.uniform(low=0, high=1, size=(n, p_X))
W = rng.uniform(low=0, high=1, size=(n, p_W))
23+
24+
25+
# Define the outcome mean function
26+
def outcome_mean(X, W):
    """Conditional mean of the outcome: a step function of X[:, 0] times W[:, 0].

    The first covariate selects a slope on the first basis column:
    -7.5 on [0, 0.25), -2.5 on [0.25, 0.5), 2.5 on [0.5, 0.75), and 7.5 for
    every other value (including anything outside [0, 0.75)).

    Parameters
    ----------
    X : 2-D array whose first column drives the slope.
    W : 2-D array whose first column is scaled by the slope.

    Returns
    -------
    1-D array with one mean value per row.
    """
    x0 = X[:, 0]
    w0 = W[:, 0]
    # Bands are mutually exclusive, so "first match wins" equals nested where()
    bands = [
        (x0 >= 0.0) & (x0 < 0.25),
        (x0 >= 0.25) & (x0 < 0.5),
        (x0 >= 0.5) & (x0 < 0.75),
    ]
    band_means = [-7.5 * w0, -2.5 * w0, 2.5 * w0]
    return np.select(bands, band_means, default=7.5 * w0)
36+
37+
38+
# Generate outcome: conditional mean plus unit-variance Gaussian noise
f_XW = outcome_mean(X, W)
epsilon = rng.normal(loc=0, scale=1, size=n)
y = f_XW + epsilon

# Test-train split (half of the rows held out)
sample_inds = np.arange(n)
train_inds, test_inds = train_test_split(
    sample_inds, test_size=0.5, random_state=random_seed
)
X_train, X_test = X[train_inds, :], X[test_inds, :]
basis_train, basis_test = W[train_inds, :], W[test_inds, :]
y_train, y_test = y[train_inds], y[test_inds]

# Settings and data shared by every fit below
general_model_params = {"random_seed": -1}
mean_forest_model_params = {"num_trees": 20}
num_warmstart = 10
num_mcmc = 10
shared_sample_args = {
    "X_train": X_train,
    "y_train": y_train,
    "leaf_basis_train": basis_train,
    "X_test": X_test,
    "leaf_basis_test": basis_test,
    "general_params": general_model_params,
    "mean_forest_params": mean_forest_model_params,
}

# Run the GFR algorithm for a small number of iterations and serialize the
# result so later fits can warm-start from it
bart_model = BARTModel()
bart_model.sample(num_gfr=num_warmstart, num_mcmc=0, **shared_sample_args)
bart_model_json = bart_model.to_json()

# Run several BART MCMC samples from the last GFR forest
bart_model_2 = BARTModel()
bart_model_2.sample(
    num_gfr=0,
    num_mcmc=num_mcmc,
    previous_model_json=bart_model_json,
    previous_model_warmstart_sample_num=num_warmstart - 1,
    **shared_sample_args,
)

# Run several BART MCMC samples from the second-to-last GFR forest
bart_model_3 = BARTModel()
bart_model_3.sample(
    num_gfr=0,
    num_mcmc=num_mcmc,
    previous_model_json=bart_model_json,
    previous_model_warmstart_sample_num=num_warmstart - 2,
    **shared_sample_args,
)

# Run several BART MCMC samples from root (no previous model supplied)
bart_model_4 = BARTModel()
bart_model_4.sample(num_gfr=0, num_mcmc=num_mcmc, **shared_sample_args)

# Inspect the model outputs: predict on the test set, then average each
# row over axis 1 to get one prediction per test observation
y_hat_mcmc_2, y_hat_mcmc_3, y_hat_mcmc_4 = (
    model.predict(X_test, basis_test)
    for model in (bart_model_2, bart_model_3, bart_model_4)
)
y_avg_mcmc_2, y_avg_mcmc_3, y_avg_mcmc_4 = (
    np.squeeze(draws).mean(axis=1, keepdims=True)
    for draws in (y_hat_mcmc_2, y_hat_mcmc_3, y_hat_mcmc_4)
)
y_df = pd.DataFrame(
    np.hstack((y_avg_mcmc_2, y_avg_mcmc_3, y_avg_mcmc_4, y_test[:, np.newaxis])),
    columns=["First Chain", "Second Chain", "Third Chain", "Outcome"],
)

# Scatterplots against a dashed 45-degree reference line, in order:
#   1. first warm-start chain vs. root chain (equal number of MCMC draws)
#   2. first warm-start chain vs. outcome
#   3. root chain vs. outcome
for x_col, y_col in (
    ("First Chain", "Third Chain"),
    ("First Chain", "Outcome"),
    ("Third Chain", "Outcome"),
):
    sns.scatterplot(data=y_df, x=x_col, y=y_col)
    plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3, 3)))
    plt.show()

# Compute RMSEs of each chain's averaged prediction against the test outcome
rmse_1, rmse_2, rmse_3 = (
    np.sqrt(np.mean(np.square(np.squeeze(avg) - y_test)))
    for avg in (y_avg_mcmc_2, y_avg_mcmc_3, y_avg_mcmc_4)
)
print(
    "Chain 1 rmse: {:0.3f}; Chain 2 rmse: {:0.3f}; Chain 3 rmse: {:0.3f}".format(
        rmse_1, rmse_2, rmse_3
    )
)

demo/debug/parallel_multi_chain.py

Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
# Parallel Multi Chain Demo Script
2+
3+
# Load necessary libraries
4+
from multiprocessing import Pool, cpu_count
5+
6+
import matplotlib.pyplot as plt
7+
import numpy as np
8+
import pandas as pd
9+
import seaborn as sns
10+
from sklearn.model_selection import train_test_split
11+
12+
from stochtree import BARTModel
13+
14+
15+
def fit_bart(
    model_string,
    X_train,
    y_train,
    basis_train,
    X_test,
    basis_test,
    num_mcmc,
    gen_param_list,
    mean_list,
    i,
):
    """Fit one MCMC-only BART chain that continues from a serialized model.

    Written as a module-level function taking only positional arguments so
    it can be dispatched through ``Pool.starmap``.

    Parameters
    ----------
    model_string : JSON string of a previously sampled model, passed as
        ``previous_model_json``.
    X_train, y_train, basis_train : training covariates, outcome and leaf basis.
    X_test, basis_test : test covariates and leaf basis.
    num_mcmc : number of MCMC iterations to run (no GFR iterations are run).
    gen_param_list : dict forwarded as ``general_params``.
    mean_list : dict forwarded as ``mean_forest_params``.
    i : index of the sample in ``model_string`` to warm-start from.

    Returns
    -------
    tuple of (JSON string of the fitted model, the model's ``y_hat_test``).
    """
    sampler_args = {
        "X_train": X_train,
        "y_train": y_train,
        "leaf_basis_train": basis_train,
        "X_test": X_test,
        "leaf_basis_test": basis_test,
        "num_gfr": 0,
        "num_mcmc": num_mcmc,
        "previous_model_json": model_string,
        "previous_model_warmstart_sample_num": i,
        "general_params": gen_param_list,
        "mean_forest_params": mean_list,
    }
    chain = BARTModel()
    chain.sample(**sampler_args)
    serialized_chain = chain.to_json()
    return (serialized_chain, chain.y_hat_test)
42+
43+
44+
def bart_warmstart_parallel(
    X_train,
    y_train,
    basis_train,
    X_test,
    basis_test,
    num_chains=4,
    num_warmstart=10,
    num_mcmc=100,
    num_trees=100,
):
    """Run a short GFR fit, then warm-start several MCMC chains in parallel.

    A single model is sampled with ``num_warmstart`` GFR iterations and
    serialized. ``num_chains`` worker processes then each call ``fit_bart``
    to run ``num_mcmc`` MCMC iterations, with chain ``k`` starting from
    sample index ``k`` of the serialized model. The chains are merged into
    one model and their test-set predictions are joined along axis 1.

    Parameters
    ----------
    X_train, y_train, basis_train : training covariates, outcome and leaf basis.
    X_test, basis_test : test covariates and leaf basis.
    num_chains : number of parallel MCMC chains (default 4).
    num_warmstart : number of GFR iterations in the initial fit (default 10).
    num_mcmc : number of MCMC iterations per chain (default 100).
    num_trees : number of trees in the mean forest (default 100).

    Returns
    -------
    tuple of (combined ``BARTModel``, array of the chains' ``y_hat_test``
    concatenated along axis 1 in chain order).

    Raises
    ------
    ValueError
        If ``num_chains`` is not between 1 and ``num_warmstart``; each chain
        needs its own GFR sample to start from.
    """
    if not 1 <= num_chains <= num_warmstart:
        raise ValueError(
            "num_chains must be between 1 and num_warmstart "
            "(got num_chains={}, num_warmstart={})".format(num_chains, num_warmstart)
        )

    # Run the GFR algorithm for a small number of iterations
    general_model_params = {"random_seed": -1}
    mean_forest_model_params = {"num_trees": num_trees}
    bart_model = BARTModel()
    bart_model.sample(
        X_train=X_train,
        y_train=y_train,
        leaf_basis_train=basis_train,
        X_test=X_test,
        leaf_basis_test=basis_test,
        num_gfr=num_warmstart,
        num_mcmc=0,
        general_params=general_model_params,
        mean_forest_params=mean_forest_model_params,
    )
    bart_model_json = bart_model.to_json()

    # Warm-start each chain from a different GFR forest; the tuple layout
    # matches the positional signature of fit_bart
    process_tasks = [
        (
            bart_model_json,
            X_train,
            y_train,
            basis_train,
            X_test,
            basis_test,
            num_mcmc,
            general_model_params,
            mean_forest_model_params,
            i,
        )
        for i in range(num_chains)
    ]

    # Never start more workers than there are chains to run
    num_processes = min(cpu_count(), num_chains)
    with Pool(processes=num_processes) as pool:
        results = pool.starmap(fit_bart, process_tasks)

    # Extract separate outputs as separate sequences
    bart_model_json_list, bart_model_pred_list = zip(*results)

    # Merge the chains into one model and join their predictions in a single
    # concatenate call instead of recopying the accumulated array per chain
    combined_bart_model = BARTModel()
    combined_bart_model.from_json_string_list(bart_model_json_list)
    combined_bart_preds = np.concatenate(bart_model_pred_list, axis=1)

    return (combined_bart_model, combined_bart_preds)
97+
98+
99+
if __name__ == "__main__":
    # Demo driver: simulate data, fit warm-started chains in parallel, and
    # check the combined model's predictions graphically.
    # The __main__ guard matters here because the fit uses multiprocessing.

    # RNG
    random_seed = 1234
    rng = np.random.default_rng(random_seed)

    # Generate covariates and basis
    n = 1000
    p_X = 10
    p_W = 1
    X = rng.uniform(0, 1, (n, p_X))
    W = rng.uniform(0, 1, (n, p_W))

    # Define the outcome mean function: the first covariate picks a slope
    # (-7.5, -2.5, 2.5 or 7.5) applied to the first basis column
    def outcome_mean(X, W):
        return np.where(
            (X[:, 0] >= 0.0) & (X[:, 0] < 0.25),
            -7.5 * W[:, 0],
            np.where(
                (X[:, 0] >= 0.25) & (X[:, 0] < 0.5),
                -2.5 * W[:, 0],
                np.where(
                    (X[:, 0] >= 0.5) & (X[:, 0] < 0.75), 2.5 * W[:, 0], 7.5 * W[:, 0]
                ),
            ),
        )

    # Generate outcome
    f_XW = outcome_mean(X, W)
    epsilon = rng.normal(0, 1, n)
    y = f_XW + epsilon

    # Test-train split
    sample_inds = np.arange(n)
    train_inds, test_inds = train_test_split(
        sample_inds, test_size=0.2, random_state=random_seed
    )
    X_train = X[train_inds, :]
    X_test = X[test_inds, :]
    basis_train = W[train_inds, :]
    basis_test = W[test_inds, :]
    y_train = y[train_inds]
    y_test = y[test_inds]

    # Run the parallel BART
    combined_bart, combined_bart_preds = bart_warmstart_parallel(
        X_train, y_train, basis_train, X_test, basis_test
    )

    # Inspect the model outputs: rows are test observations and axis 1 holds
    # the draws, so averaging over axis 1 gives one prediction per row
    y_hat_mcmc = np.squeeze(combined_bart.predict(X_test, basis_test))
    y_avg_mcmc = y_hat_mcmc.mean(axis=1, keepdims=True)
    y_df = pd.DataFrame(
        np.concatenate((y_avg_mcmc, np.expand_dims(y_test, axis=1)), axis=1),
        columns=["Average BART Predictions", "Outcome"],
    )

    # Compare predictions averaged over all chains to the outcome
    sns.scatterplot(data=y_df, x="Average BART Predictions", y="Outcome")
    plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3, 3)))
    plt.show()

    # Compare cached predictions to deserialized predictions for first chain
    chain_index = 0
    # Must equal the per-chain MCMC count used by bart_warmstart_parallel
    num_mcmc = 100
    offset_index = num_mcmc * chain_index
    chain_inds = slice(offset_index, (offset_index + num_mcmc))
    # Chains were joined along axis 1, so one chain's draws are a block of
    # columns; slicing rows would pick test observations instead
    chain_1_preds_original = np.squeeze(combined_bart_preds)[:, chain_inds].mean(
        axis=1, keepdims=True
    )
    chain_1_preds_reloaded = y_hat_mcmc[:, chain_inds].mean(axis=1, keepdims=True)
    chain_df = pd.DataFrame(
        np.concatenate((chain_1_preds_reloaded, chain_1_preds_original), axis=1),
        columns=["New Predictions", "Original Predictions"],
    )
    sns.scatterplot(data=chain_df, x="New Predictions", y="Original Predictions")
    plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3, 3)))
    plt.show()

0 commit comments

Comments
 (0)