# Supervised Learning Demo Script
#
# Fits a probit BART model to simulated binary-outcome data and
# evaluates it with accuracy and ROC curves.

# Load necessary libraries
import numpy as np
import matplotlib.pyplot as plt
from stochtree import BARTModel
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_curve
# Generate sample data
# RNG (unseeded, so each run simulates a fresh dataset)
rng = np.random.default_rng()

# Generate covariates and basis
n = 1000        # number of observations
p_X = 10        # number of covariates
p_basis = 1     # dimension of the leaf basis
X = rng.uniform(low=0, high=1, size=(n, p_X))
basis = rng.uniform(low=0, high=1, size=(n, p_basis))
# Define the outcome mean function
def outcome_mean(X, basis=None):
    """Piecewise-linear conditional mean of the latent outcome.

    The slope applied to the basis depends on which quarter of [0, 1)
    the first covariate falls into: -7.5, -2.5, 2.5, or 7.5.

    Parameters
    ----------
    X : np.ndarray
        Covariate matrix of shape (n, p); column 0 selects the regime.
        When ``basis`` is omitted, column 1 doubles as the basis.
    basis : np.ndarray, optional
        Leaf basis of shape (n, p_basis); column 0 multiplies the slope.

    Returns
    -------
    np.ndarray
        Mean outcome, shape (n,).
    """
    # The original if/else duplicated the whole nested np.where ladder,
    # differing only in which column multiplies the slopes; compute the
    # multiplier once and share a single slope selection instead.
    b = basis[:, 0] if basis is not None else X[:, 1]
    x0 = X[:, 0]
    # Exact original conditions, including the fall-through: any x0
    # outside [0, 0.75) (e.g. negative) gets the default slope 7.5.
    slopes = np.select(
        [
            (x0 >= 0.0) & (x0 < 0.25),
            (x0 >= 0.25) & (x0 < 0.5),
            (x0 >= 0.5) & (x0 < 0.75),
        ],
        [-7.5, -2.5, 2.5],
        default=7.5,
    )
    return slopes * b

# Generate outcome: latent Gaussian variable w, thresholded at zero
# to produce the binary label y
epsilon = rng.normal(loc=0, scale=1, size=n)
w = outcome_mean(X, basis) + epsilon
# w = outcome_mean(X) + epsilon
y = (w > 0).astype(int)

# Test-train split (80% train / 20% test)
sample_inds = np.arange(n)
train_inds, test_inds = train_test_split(sample_inds, test_size=0.2)
X_train, X_test = X[train_inds, :], X[test_inds, :]
basis_train, basis_test = basis[train_inds, :], basis[test_inds, :]
w_train, w_test = w[train_inds], w[test_inds]
y_train, y_test = y[train_inds], y[test_inds]

# Construct parameter lists: probit link for the binary outcome,
# with the global error variance held fixed (not sampled)
general_params = dict(
    probit_outcome_model=True,
    sample_sigma2_global=False,
)

# Run BART: warm-start (grow-from-root) iterations followed by MCMC
# draws, with no burn-in
num_gfr = 10
num_mcmc = 100
bart_model = BARTModel()
bart_model.sample(
    X_train=X_train,
    y_train=y_train,
    leaf_basis_train=basis_train,
    X_test=X_test,
    leaf_basis_test=basis_test,
    num_gfr=num_gfr,
    num_burnin=0,
    num_mcmc=num_mcmc,
    general_params=general_params,
)
# bart_model.sample(X_train=X_train, y_train=y_train, X_test=X_test, num_gfr=num_gfr,
#                   num_burnin=0, num_mcmc=num_mcmc, general_params=general_params)

# Inspect the MCMC (BART) samples: scatter the posterior-mean latent
# prediction against the true latent outcome, with a 45-degree reference
w_hat_test = np.squeeze(bart_model.y_hat_test).mean(axis=1)
plt.scatter(w_hat_test, w_test, color="black")
plt.axline((0, 0), slope=1, color="red", linestyle=(0, (3, 3)))
plt.xlabel("Predicted")
plt.ylabel("Actual")
plt.title("Probit scale latent outcome")
plt.show()

# Compute prediction accuracy: classify by the sign of the latent mean
preds_test = w_hat_test > 0
accuracy = np.mean(y_test == preds_test)
print(f"Test set accuracy: {accuracy:.3f}")

# Present a ROC curve: one blue curve per posterior draw, overlaid with
# the curve of the posterior-mean score (black dashed) and the chance
# diagonal (red dashed)
fpr_list = []
tpr_list = []
threshold_list = []
for i in range(num_mcmc):
    fpr, tpr, thresholds = roc_curve(y_test, bart_model.y_hat_test[:, i], pos_label=1)
    fpr_list.append(fpr)
    tpr_list.append(tpr)
    threshold_list.append(thresholds)
fpr_mean, tpr_mean, thresholds_mean = roc_curve(y_test, w_hat_test, pos_label=1)
for fpr, tpr in zip(fpr_list, tpr_list):
    plt.plot(fpr, tpr, color="blue", linestyle="solid", linewidth=1.25)
plt.plot(fpr_mean, tpr_mean, color="black", linestyle="dashed", linewidth=2.0)
plt.axline((0, 0), slope=1, color="red", linestyle="dashed", linewidth=1.5)
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.xlim(0, 1)
plt.ylim(0, 1)
plt.show()