
Commit 7b3a5ae

Updated Python probit BART interface and demos
1 parent 36944b3 · commit 7b3a5ae


4 files changed: +117 -7 lines

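The substantive interface change in this commit moves the probit_outcome_model flag from mean_forest_params to general_params in BARTModel.sample. A minimal sketch of the updated calling convention, assembled from the demo and notebook diffs below (the toy data here is illustrative only, not part of the commit):

import numpy as np
from stochtree import BARTModel

# Toy data purely for illustration.
rng = np.random.default_rng(0)
X_train, X_test = rng.uniform(0, 1, (80, 5)), rng.uniform(0, 1, (20, 5))
y_train = (X_train[:, 0] > 0.5).astype(float)

# After this commit, probit_outcome_model is passed via general_params
# (it was previously a mean_forest_params entry).
general_params = {"num_chains": 1, "probit_outcome_model": True}
bart_model = BARTModel()
bart_model.sample(
    X_train=X_train, y_train=y_train, X_test=X_test,
    num_gfr=10, num_burnin=0, num_mcmc=100,
    general_params=general_params,
)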

demo/debug/supervised_learning.py

Lines changed: 1 addition & 1 deletion
@@ -44,7 +44,7 @@ def outcome_mean(X, W):
 
 # Test-train split
 sample_inds = np.arange(n)
-train_inds, test_inds = train_test_split(sample_inds, test_size=0.5)
+train_inds, test_inds = train_test_split(sample_inds, test_size=0.2)
 X_train = X[train_inds,:]
 X_test = X[test_inds,:]
 basis_train = W[train_inds,:]
Lines changed: 112 additions & 0 deletions
@@ -0,0 +1,112 @@
+# Supervised Learning Demo Script
+
+# Load necessary libraries
+import numpy as np
+import matplotlib.pyplot as plt
+from stochtree import BARTModel
+from sklearn.model_selection import train_test_split
+from sklearn.metrics import roc_curve
+
+# Generate sample data
+# RNG
+rng = np.random.default_rng()
+
+# Generate covariates and basis
+n = 1000
+p_X = 10
+p_basis = 1
+X = rng.uniform(0, 1, (n, p_X))
+basis = rng.uniform(0, 1, (n, p_basis))
+
+# Define the outcome mean function
+def outcome_mean(X, basis = None):
+    if basis is not None:
+        return np.where(
+            (X[:,0] >= 0.0) & (X[:,0] < 0.25), -7.5 * basis[:,0],
+            np.where(
+                (X[:,0] >= 0.25) & (X[:,0] < 0.5), -2.5 * basis[:,0],
+                np.where(
+                    (X[:,0] >= 0.5) & (X[:,0] < 0.75), 2.5 * basis[:,0],
+                    7.5 * basis[:,0]
+                )
+            )
+        )
+    else:
+        return np.where(
+            (X[:,0] >= 0.0) & (X[:,0] < 0.25), -7.5 * X[:,1],
+            np.where(
+                (X[:,0] >= 0.25) & (X[:,0] < 0.5), -2.5 * X[:,1],
+                np.where(
+                    (X[:,0] >= 0.5) & (X[:,0] < 0.75), 2.5 * X[:,1],
+                    7.5 * X[:,1]
+                )
+            )
+        )
+
+
+# Generate outcome
+epsilon = rng.normal(0, 1, n)
+w = outcome_mean(X, basis) + epsilon
+# w = outcome_mean(X) + epsilon
+y = np.where(w > 0, 1, 0)
+
+# Test-train split
+sample_inds = np.arange(n)
+train_inds, test_inds = train_test_split(sample_inds, test_size=0.2)
+X_train = X[train_inds,:]
+X_test = X[test_inds,:]
+basis_train = basis[train_inds,:]
+basis_test = basis[test_inds,:]
+w_train = w[train_inds]
+w_test = w[test_inds]
+y_train = y[train_inds]
+y_test = y[test_inds]
+
+# Construct parameter lists
+general_params = {
+    'probit_outcome_model': True,
+    'sample_sigma2_global': False
+}
+
+# Run BART
+num_gfr = 10
+num_mcmc = 100
+bart_model = BARTModel()
+bart_model.sample(X_train=X_train, y_train=y_train, leaf_basis_train=basis_train,
+                  X_test=X_test, leaf_basis_test=basis_test, num_gfr=num_gfr,
+                  num_burnin=0, num_mcmc=num_mcmc, general_params=general_params)
+# bart_model.sample(X_train=X_train, y_train=y_train, X_test=X_test, num_gfr=num_gfr,
+#                   num_burnin=0, num_mcmc=num_mcmc, general_params=general_params)
+
+# Inspect the MCMC (BART) samples
+w_hat_test = np.squeeze(bart_model.y_hat_test).mean(axis = 1)
+plt.scatter(w_hat_test, w_test, color="black")
+plt.axline((0, 0), slope=1, color="red", linestyle=(0, (3,3)))
+plt.xlabel("Predicted")
+plt.ylabel("Actual")
+plt.title("Probit scale latent outcome")
+plt.show()
+
+# Compute prediction accuracy
+preds_test = w_hat_test > 0
+print(f"Test set accuracy: {np.mean(y_test == preds_test):.3f}")
+
+# Present a ROC curve
+fpr_list = list()
+tpr_list = list()
+threshold_list = list()
+for i in range(num_mcmc):
+    fpr, tpr, thresholds = roc_curve(y_test, bart_model.y_hat_test[:,i], pos_label=1)
+    fpr_list.append(fpr)
+    tpr_list.append(tpr)
+    threshold_list.append(thresholds)
+fpr_mean, tpr_mean, thresholds_mean = roc_curve(y_test, w_hat_test, pos_label=1)
+for i in range(num_mcmc):
+    plt.plot(fpr_list[i], tpr_list[i], color = 'blue', linestyle='solid', linewidth = 1.25)
+plt.plot(fpr_mean, tpr_mean, color = 'black', linestyle='dashed', linewidth = 2.0)
+plt.axline((0, 0), slope=1, color="red", linestyle='dashed', linewidth=1.5)
+plt.xlabel("False Positive Rate")
+plt.ylabel("True Positive Rate")
+plt.xlim(0, 1)
+plt.ylim(0, 1)
+plt.show()
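The new demo evaluates fit on the latent probit scale, which leaves the ROC analysis unchanged because the normal CDF is monotone. A natural companion step, not part of this commit, is mapping the posterior latent predictions to event probabilities; a minimal sketch, assuming scipy is installed alongside the demo's variables:

from scipy.stats import norm

# Under a probit outcome model, P(y = 1 | x) = Phi(f(x)), so the
# posterior-mean latent prediction maps to a probability through the
# standard normal CDF.
prob_test = norm.cdf(w_hat_test)

# Because Phi is monotone, thresholding prob_test at 0.5 gives the same
# labels as thresholding w_hat_test at 0 (preds_test above), and ROC
# curves computed on either scale coincide.
preds_test_prob = prob_test > 0.5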

demo/notebooks/supervised_learning_classification.ipynb

Lines changed: 2 additions & 4 deletions
@@ -110,16 +110,14 @@
 "num_gfr = 10\n",
 "num_mcmc = 100\n",
 "bart_model = BARTModel()\n",
-"general_params = {\"num_chains\": 1}\n",
-"mean_forest_params = {\"probit_outcome_model\": True}\n",
+"general_params = {\"num_chains\": 1, \"probit_outcome_model\": True}\n",
 "bart_model.sample(\n",
 "    X_train=X_train,\n",
 "    y_train=y_train,\n",
 "    X_test=X_test,\n",
 "    num_gfr=num_gfr,\n",
 "    num_mcmc=num_mcmc,\n",
-"    general_params=general_params,\n",
-"    mean_forest_params=mean_forest_params\n",
+"    general_params=general_params\n",
 ")"

stochtree/bart.py

Lines changed: 2 additions & 2 deletions
@@ -187,6 +187,7 @@ def sample(
             "keep_gfr": False,
             "keep_every": 1,
             "num_chains": 1,
+            "probit_outcome_model": False,
         }
         general_params_updated = _preprocess_params(
             general_params_default, general_params
@@ -205,7 +206,6 @@ def sample(
             "sigma2_leaf_scale": None,
             "keep_vars": None,
             "drop_vars": None,
-            "probit_outcome_model": False,
         }
         mean_forest_params_updated = _preprocess_params(
             mean_forest_params_default, mean_forest_params
@@ -243,6 +243,7 @@ def sample(
         keep_gfr = general_params_updated["keep_gfr"]
         keep_every = general_params_updated["keep_every"]
         num_chains = general_params_updated["num_chains"]
+        self.probit_outcome_model = general_params_updated["probit_outcome_model"]
 
         # 2. Mean forest parameters
         num_trees_mean = mean_forest_params_updated["num_trees"]
@@ -256,7 +257,6 @@ def sample(
         b_leaf = mean_forest_params_updated["sigma2_leaf_scale"]
         keep_vars_mean = mean_forest_params_updated["keep_vars"]
         drop_vars_mean = mean_forest_params_updated["drop_vars"]
-        self.probit_outcome_model = mean_forest_params_updated["probit_outcome_model"]
 
         # 3. Variance forest parameters
         num_trees_variance = variance_forest_params_updated["num_trees"]
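For context, _preprocess_params is the helper that overlays user-supplied parameter dictionaries on the defaults edited in these hunks; its implementation is not shown in this diff. A minimal sketch of the assumed merge behavior (the actual stochtree helper may validate keys or differ in detail):

def _preprocess_params(defaults, overrides):
    # Start from the defaults and overlay any user-supplied values;
    # keys absent from `overrides` keep their defaults, which is why
    # probit_outcome_model stays False unless a caller opts in.
    merged = dict(defaults)
    if overrides is not None:
        merged.update(overrides)
    return merged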
