@@ -134,9 +134,9 @@ plot(rowMeans(bcf_model_warmstart$tau_hat_test), tau_test,
134
134
xlab = "predicted", ylab = "actual", main = "Treatment effect")
135
135
abline(0,1,col="red",lty=3,lwd=3)
136
136
sigma_observed <- var(y-E_XZ)
137
- plot_bounds <- c(min(c(bcf_model_warmstart$sigma2_samples , sigma_observed)),
138
- max(c(bcf_model_warmstart$sigma2_samples , sigma_observed)))
139
- plot(bcf_model_warmstart$sigma2_samples , ylim = plot_bounds,
137
+ plot_bounds <- c(min(c(bcf_model_warmstart$sigma2_global_samples , sigma_observed)),
138
+ max(c(bcf_model_warmstart$sigma2_global_samples , sigma_observed)))
139
+ plot(bcf_model_warmstart$sigma2_global_samples , ylim = plot_bounds,
140
140
ylab = "sigma^2", xlab = "Sample", main = "Global variance parameter")
141
141
abline(h = sigma_observed, lty=3, lwd = 3, col = "blue")
142
142
```
@@ -184,9 +184,9 @@ plot(rowMeans(bcf_model_root$tau_hat_test), tau_test,
184
184
xlab = "predicted", ylab = "actual", main = "Treatment effect")
185
185
abline(0,1,col="red",lty=3,lwd=3)
186
186
sigma_observed <- var(y-E_XZ)
187
- plot_bounds <- c(min(c(bcf_model_root$sigma2_samples , sigma_observed)),
188
- max(c(bcf_model_root$sigma2_samples , sigma_observed)))
189
- plot(bcf_model_root$sigma2_samples , ylim = plot_bounds,
187
+ plot_bounds <- c(min(c(bcf_model_root$sigma2_global_samples , sigma_observed)),
188
+ max(c(bcf_model_root$sigma2_global_samples , sigma_observed)))
189
+ plot(bcf_model_root$sigma2_global_samples , ylim = plot_bounds,
190
190
ylab = "sigma^2", xlab = "Sample", main = "Global variance parameter")
191
191
abline(h = sigma_observed, lty=3, lwd = 3, col = "blue")
192
192
```
@@ -303,9 +303,9 @@ plot(rowMeans(bcf_model_warmstart$tau_hat_test), tau_test,
303
303
xlab = "predicted", ylab = "actual", main = "Treatment effect")
304
304
abline(0,1,col="red",lty=3,lwd=3)
305
305
sigma_observed <- var(y-E_XZ)
306
- plot_bounds <- c(min(c(bcf_model_warmstart$sigma2_samples , sigma_observed)),
307
- max(c(bcf_model_warmstart$sigma2_samples , sigma_observed)))
308
- plot(bcf_model_warmstart$sigma2_samples , ylim = plot_bounds,
306
+ plot_bounds <- c(min(c(bcf_model_warmstart$sigma2_global_samples , sigma_observed)),
307
+ max(c(bcf_model_warmstart$sigma2_global_samples , sigma_observed)))
308
+ plot(bcf_model_warmstart$sigma2_global_samples , ylim = plot_bounds,
309
309
ylab = "sigma^2", xlab = "Sample", main = "Global variance parameter")
310
310
abline(h = sigma_observed, lty=3, lwd = 3, col = "blue")
311
311
```
@@ -353,9 +353,9 @@ plot(rowMeans(bcf_model_root$tau_hat_test), tau_test,
353
353
xlab = "predicted", ylab = "actual", main = "Treatment effect")
354
354
abline(0,1,col="red",lty=3,lwd=3)
355
355
sigma_observed <- var(y-E_XZ)
356
- plot_bounds <- c(min(c(bcf_model_root$sigma2_samples , sigma_observed)),
357
- max(c(bcf_model_root$sigma2_samples , sigma_observed)))
358
- plot(bcf_model_root$sigma2_samples , ylim = plot_bounds,
356
+ plot_bounds <- c(min(c(bcf_model_root$sigma2_global_samples , sigma_observed)),
357
+ max(c(bcf_model_root$sigma2_global_samples , sigma_observed)))
358
+ plot(bcf_model_root$sigma2_global_samples , ylim = plot_bounds,
359
359
ylab = "sigma^2", xlab = "Sample", main = "Global variance parameter")
360
360
abline(h = sigma_observed, lty=3, lwd = 3, col = "blue")
361
361
```
@@ -472,9 +472,9 @@ plot(rowMeans(bcf_model_warmstart$tau_hat_test), tau_test,
472
472
xlab = "predicted", ylab = "actual", main = "Treatment effect")
473
473
abline(0,1,col="red",lty=3,lwd=3)
474
474
sigma_observed <- var(y-E_XZ)
475
- plot_bounds <- c(min(c(bcf_model_warmstart$sigma2_samples , sigma_observed)),
476
- max(c(bcf_model_warmstart$sigma2_samples , sigma_observed)))
477
- plot(bcf_model_warmstart$sigma2_samples , ylim = plot_bounds,
475
+ plot_bounds <- c(min(c(bcf_model_warmstart$sigma2_global_samples , sigma_observed)),
476
+ max(c(bcf_model_warmstart$sigma2_global_samples , sigma_observed)))
477
+ plot(bcf_model_warmstart$sigma2_global_samples , ylim = plot_bounds,
478
478
ylab = "sigma^2", xlab = "Sample", main = "Global variance parameter")
479
479
abline(h = sigma_observed, lty=3, lwd = 3, col = "blue")
480
480
```
@@ -522,9 +522,9 @@ plot(rowMeans(bcf_model_root$tau_hat_test), tau_test,
522
522
xlab = "predicted", ylab = "actual", main = "Treatment effect")
523
523
abline(0,1,col="red",lty=3,lwd=3)
524
524
sigma_observed <- var(y-E_XZ)
525
- plot_bounds <- c(min(c(bcf_model_root$sigma2_samples , sigma_observed)),
526
- max(c(bcf_model_root$sigma2_samples , sigma_observed)))
527
- plot(bcf_model_root$sigma2_samples , ylim = plot_bounds,
525
+ plot_bounds <- c(min(c(bcf_model_root$sigma2_global_samples , sigma_observed)),
526
+ max(c(bcf_model_root$sigma2_global_samples , sigma_observed)))
527
+ plot(bcf_model_root$sigma2_global_samples , ylim = plot_bounds,
528
528
ylab = "sigma^2", xlab = "Sample", main = "Global variance parameter")
529
529
abline(h = sigma_observed, lty=3, lwd = 3, col = "blue")
530
530
```
@@ -639,9 +639,9 @@ plot(rowMeans(bcf_model_warmstart$tau_hat_test), tau_test,
639
639
xlab = "predicted", ylab = "actual", main = "Treatment effect")
640
640
abline(0,1,col="red",lty=3,lwd=3)
641
641
sigma_observed <- var(y-E_XZ)
642
- plot_bounds <- c(min(c(bcf_model_warmstart$sigma2_samples , sigma_observed)),
643
- max(c(bcf_model_warmstart$sigma2_samples , sigma_observed)))
644
- plot(bcf_model_warmstart$sigma2_samples , ylim = plot_bounds,
642
+ plot_bounds <- c(min(c(bcf_model_warmstart$sigma2_global_samples , sigma_observed)),
643
+ max(c(bcf_model_warmstart$sigma2_global_samples , sigma_observed)))
644
+ plot(bcf_model_warmstart$sigma2_global_samples , ylim = plot_bounds,
645
645
ylab = "sigma^2", xlab = "Sample", main = "Global variance parameter")
646
646
abline(h = sigma_observed, lty=3, lwd = 3, col = "blue")
647
647
```
@@ -689,9 +689,9 @@ plot(rowMeans(bcf_model_root$tau_hat_test), tau_test,
689
689
xlab = "predicted", ylab = "actual", main = "Treatment effect")
690
690
abline(0,1,col="red",lty=3,lwd=3)
691
691
sigma_observed <- var(y-E_XZ)
692
- plot_bounds <- c(min(c(bcf_model_root$sigma2_samples , sigma_observed)),
693
- max(c(bcf_model_root$sigma2_samples , sigma_observed)))
694
- plot(bcf_model_root$sigma2_samples , ylim = plot_bounds,
692
+ plot_bounds <- c(min(c(bcf_model_root$sigma2_global_samples , sigma_observed)),
693
+ max(c(bcf_model_root$sigma2_global_samples , sigma_observed)))
694
+ plot(bcf_model_root$sigma2_global_samples , ylim = plot_bounds,
695
695
ylab = "sigma^2", xlab = "Sample", main = "Global variance parameter")
696
696
abline(h = sigma_observed, lty=3, lwd = 3, col = "blue")
697
697
```
@@ -806,9 +806,9 @@ plot(rowMeans(bcf_model_warmstart$rfx_preds_test), rfx_term_test,
806
806
xlab = "predicted", ylab = "actual", main = "Random effects terms")
807
807
abline(0,1,col="red",lty=3,lwd=3)
808
808
sigma_observed <- var(y-E_XZ-rfx_term)
809
- plot_bounds <- c(min(c(bcf_model_warmstart$sigma2_samples , sigma_observed)),
810
- max(c(bcf_model_warmstart$sigma2_samples , sigma_observed)))
811
- plot(bcf_model_warmstart$sigma2_samples , ylim = plot_bounds,
809
+ plot_bounds <- c(min(c(bcf_model_warmstart$sigma2_global_samples , sigma_observed)),
810
+ max(c(bcf_model_warmstart$sigma2_global_samples , sigma_observed)))
811
+ plot(bcf_model_warmstart$sigma2_global_samples , ylim = plot_bounds,
812
812
ylab = "sigma^2", xlab = "Sample", main = "Global variance parameter")
813
813
abline(h = sigma_observed, lty=3, lwd = 3, col = "blue")
814
814
```
@@ -919,9 +919,9 @@ plot(rowMeans(bcf_model_mcmc$y_hat_test), y_test,
919
919
xlab = "predicted", ylab = "actual", main = "Outcome")
920
920
abline(0,1,col="red",lty=3,lwd=3)
921
921
sigma_observed <- var(y-E_XZ)
922
- plot_bounds <- c(min(c(bcf_model_mcmc$sigma2_samples , sigma_observed)),
923
- max(c(bcf_model_mcmc$sigma2_samples , sigma_observed)))
924
- plot(bcf_model_mcmc$sigma2_samples , ylim = plot_bounds,
922
+ plot_bounds <- c(min(c(bcf_model_mcmc$sigma2_global_samples , sigma_observed)),
923
+ max(c(bcf_model_mcmc$sigma2_global_samples , sigma_observed)))
924
+ plot(bcf_model_mcmc$sigma2_global_samples , ylim = plot_bounds,
925
925
ylab = "sigma^2", xlab = "Sample", main = "Global variance parameter")
926
926
abline(h = sigma_observed, lty=3, lwd = 3, col = "blue")
927
927
```
@@ -981,9 +981,9 @@ plot(rowMeans(bcf_model_mcmc$y_hat_test), y_test,
981
981
xlab = "predicted", ylab = "actual", main = "Outcome")
982
982
abline(0,1,col="red",lty=3,lwd=3)
983
983
sigma_observed <- var(y-E_XZ)
984
- plot_bounds <- c(min(c(bcf_model_mcmc$sigma2_samples , sigma_observed)),
985
- max(c(bcf_model_mcmc$sigma2_samples , sigma_observed)))
986
- plot(bcf_model_mcmc$sigma2_samples , ylim = plot_bounds,
984
+ plot_bounds <- c(min(c(bcf_model_mcmc$sigma2_global_samples , sigma_observed)),
985
+ max(c(bcf_model_mcmc$sigma2_global_samples , sigma_observed)))
986
+ plot(bcf_model_mcmc$sigma2_global_samples , ylim = plot_bounds,
987
987
ylab = "sigma^2", xlab = "Sample", main = "Global variance parameter")
988
988
abline(h = sigma_observed, lty=3, lwd = 3, col = "blue")
989
989
```
@@ -1043,9 +1043,9 @@ plot(rowMeans(bcf_model_warmstart$y_hat_test), y_test,
1043
1043
xlab = "predicted", ylab = "actual", main = "Outcome")
1044
1044
abline(0,1,col="red",lty=3,lwd=3)
1045
1045
sigma_observed <- var(y-E_XZ)
1046
- plot_bounds <- c(min(c(bcf_model_warmstart$sigma2_samples , sigma_observed)),
1047
- max(c(bcf_model_warmstart$sigma2_samples , sigma_observed)))
1048
- plot(bcf_model_warmstart$sigma2_samples , ylim = plot_bounds,
1046
+ plot_bounds <- c(min(c(bcf_model_warmstart$sigma2_global_samples , sigma_observed)),
1047
+ max(c(bcf_model_warmstart$sigma2_global_samples , sigma_observed)))
1048
+ plot(bcf_model_warmstart$sigma2_global_samples , ylim = plot_bounds,
1049
1049
ylab = "sigma^2", xlab = "Sample", main = "Global variance parameter")
1050
1050
abline(h = sigma_observed, lty=3, lwd = 3, col = "blue")
1051
1051
```
@@ -1105,9 +1105,9 @@ plot(rowMeans(bcf_model_warmstart$y_hat_test), y_test,
1105
1105
xlab = "predicted", ylab = "actual", main = "Outcome")
1106
1106
abline(0,1,col="red",lty=3,lwd=3)
1107
1107
sigma_observed <- var(y-E_XZ)
1108
- plot_bounds <- c(min(c(bcf_model_warmstart$sigma2_samples , sigma_observed)),
1109
- max(c(bcf_model_warmstart$sigma2_samples , sigma_observed)))
1110
- plot(bcf_model_warmstart$sigma2_samples , ylim = plot_bounds,
1108
+ plot_bounds <- c(min(c(bcf_model_warmstart$sigma2_global_samples , sigma_observed)),
1109
+ max(c(bcf_model_warmstart$sigma2_global_samples , sigma_observed)))
1110
+ plot(bcf_model_warmstart$sigma2_global_samples , ylim = plot_bounds,
1111
1111
ylab = "sigma^2", xlab = "Sample", main = "Global variance parameter")
1112
1112
abline(h = sigma_observed, lty=3, lwd = 3, col = "blue")
1113
1113
```
@@ -1133,6 +1133,193 @@ test_outcome_mean <- rowMeans(bcf_model_warmstart$y_hat_test)
1133
1133
sqrt(mean((y_test - test_outcome_mean)^2))
1134
1134
```
1135
1135
1136
+ ## Demo 7: Probit Outcome Model, Heterogeneous Treatment Effect
1137
+
1138
+ We consider a modified version of a data generating process from @hahn2020bayesian :
1139
+
1140
+ \begin{equation* }
1141
+ \begin{aligned}
1142
+ y &= \mathbb{1}\left(w > 0\right)\\
1143
+ w &= \mu(X) + \tau(X) Z + \epsilon\\
1144
+ \epsilon &\sim N\left(0,1\right)\\
1145
+ \mu(X) &= 1 + g(X) + 6 \lvert X_3 - 1 \rvert\\
1146
+ \tau(X) &= 1 + 2 X_2 X_4\\
1147
+ g(X) &= \mathbb{I}(X_5=1) \times 2 - \mathbb{I}(X_5=2) \times 1 - \mathbb{I}(X_5=3) \times 4\\
1148
+ s_ {\mu} &= \sqrt{\mathbb{V}(\mu(X))}\\
1149
+ \pi(X) &= 0.8 \phi\left(\frac{3\mu(X)}{s_ {\mu}}\right) - \frac{X_1}{2} + \frac{2U+1}{20}\\
1150
+ X_1,X_2,X_3 &\sim N\left(0,1\right)\\
1151
+ X_4 &\sim \text{Bernoulli}(1/2)\\
1152
+ X_5 &\sim \text{Categorical}(1/3,1/3,1/3)\\
1153
+ U &\sim \text{Uniform}\left(0,1\right)\\
1154
+ Z &\sim \text{Bernoulli}\left(\pi(X)\right)
1155
+ \end{aligned}
1156
+ \end{equation* }
1157
+
1158
+ ### Simulation
1159
+
1160
+ We draw from the DGP defined above
1161
+
1162
+ ``` {r}
1163
+ n <- 2000
1164
+ x1 <- rnorm(n)
1165
+ x2 <- rnorm(n)
1166
+ x3 <- rnorm(n)
1167
+ x4 <- as.numeric(rbinom(n,1,0.5))
1168
+ x5 <- as.numeric(sample(1:3,n,replace=TRUE))
1169
+ X <- cbind(x1,x2,x3,x4,x5)
1170
+ p <- ncol(X)
1171
+ mu_x <- mu1(X)
1172
+ tau_x <- tau2(X)
1173
+ pi_x <- 0.8*pnorm((3*mu_x/sd(mu_x)) - 0.5*X[,1]) + 0.05 + runif(n)/10
1174
+ Z <- rbinom(n,1,pi_x)
1175
+ E_XZ <- mu_x + Z*tau_x
1176
+ w <- E_XZ + rnorm(n, 0, 1)
1177
+ y <- 1*(w > 0)
1178
+ delta_x <- pnorm(mu_x + tau_x) - pnorm(mu_x)
1179
+ X <- as.data.frame(X)
1180
+ X$x4 <- factor(X$x4, ordered = TRUE)
1181
+ X$x5 <- factor(X$x5, ordered = TRUE)
1182
+
1183
+ # Split data into test and train sets
1184
+ test_set_pct <- 0.2
1185
+ n_test <- round(test_set_pct*n)
1186
+ n_train <- n - n_test
1187
+ test_inds <- sort(sample(1:n, n_test, replace = FALSE))
1188
+ train_inds <- (1:n)[!((1:n) %in% test_inds)]
1189
+ X_test <- X[test_inds,]
1190
+ X_train <- X[train_inds,]
1191
+ pi_test <- pi_x[test_inds]
1192
+ pi_train <- pi_x[train_inds]
1193
+ Z_test <- Z[test_inds]
1194
+ Z_train <- Z[train_inds]
1195
+ w_test <- w[test_inds]
1196
+ w_train <- w[train_inds]
1197
+ y_test <- y[test_inds]
1198
+ y_train <- y[train_inds]
1199
+ mu_test <- mu_x[test_inds]
1200
+ mu_train <- mu_x[train_inds]
1201
+ tau_test <- tau_x[test_inds]
1202
+ tau_train <- tau_x[train_inds]
1203
+ delta_x_test <- delta_x[test_inds]
1204
+ delta_x_train <- delta_x[train_inds]
1205
+ ```
1206
+
1207
+ ### Sampling and Analysis
1208
+
1209
+ #### Warmstart
1210
+
1211
+ We first simulate from an ensemble model of $y \mid X$ using "warm-start"
1212
+ initialization samples (@krantsevich2023stochastic ). This is the default in
1213
+ ` stochtree ` .
1214
+
1215
+ ``` {r}
1216
+ num_gfr <- 10
1217
+ num_burnin <- 0
1218
+ num_mcmc <- 100
1219
+ num_samples <- num_gfr + num_burnin + num_mcmc
1220
+ general_params <- list(keep_every = 5,
1221
+ probit_outcome_model = T,
1222
+ sample_sigma2_global = F)
1223
+ prognostic_forest_params <- list(sample_sigma2_leaf = F)
1224
+ treatment_effect_forest_params <- list(sample_sigma2_leaf = F)
1225
+ bcf_model_warmstart <- bcf(
1226
+ X_train = X_train, Z_train = Z_train, y_train = y_train, propensity_train = pi_train,
1227
+ X_test = X_test, Z_test = Z_test, propensity_test = pi_test,
1228
+ num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc,
1229
+ general_params = general_params, prognostic_forest_params = prognostic_forest_params,
1230
+ treatment_effect_forest_params = treatment_effect_forest_params
1231
+ )
1232
+ ```
1233
+
1234
+ Inspect the BART samples that were initialized with an XBART warm-start
1235
+
1236
+ ``` {r}
1237
+ mu_hat_test <- rowMeans(bcf_model_warmstart$mu_hat_test)
1238
+ plot(mu_hat_test, mu_test, xlab = "predicted",
1239
+ ylab = "actual", main = "Prognostic function")
1240
+ abline(0,1,col="red",lty=3,lwd=3)
1241
+ tau_hat_test <- rowMeans(bcf_model_warmstart$tau_hat_test)
1242
+ plot(tau_hat_test, tau_test, xlab = "predicted",
1243
+ ylab = "actual", main = "Treatment effect")
1244
+ abline(0,1,col="red",lty=3,lwd=3)
1245
+ delta_x_hat_test <- pnorm(mu_hat_test+tau_hat_test) - pnorm(mu_hat_test)
1246
+ plot(delta_x_hat_test, delta_x_test,
1247
+ xlab = "predicted", ylab = "actual", main = "Distributional treatment\neffect")
1248
+ abline(0,1,col="red",lty=3,lwd=3)
1249
+ ```
1250
+
1251
+ Examine test set interval coverage
1252
+
1253
+ ``` {r}
1254
+ test_lb <- apply(
1255
+ pnorm(bcf_model_warmstart$mu_hat_test + bcf_model_warmstart$tau_hat_test) -
1256
+ pnorm(bcf_model_warmstart$mu_hat_test), 1, quantile, 0.025)
1257
+ test_ub <- apply(
1258
+ pnorm(bcf_model_warmstart$mu_hat_test + bcf_model_warmstart$tau_hat_test) -
1259
+ pnorm(bcf_model_warmstart$mu_hat_test), 1, quantile, 0.975)
1260
+ cover <- (
1261
+ (test_lb <= delta_x_test) &
1262
+ (test_ub >= delta_x_test)
1263
+ )
1264
+ mean(cover)
1265
+ ```
1266
+
1267
+ #### BART MCMC without Warmstart
1268
+
1269
+ Next, we simulate from this ensemble model without any warm-start initialization.
1270
+
1271
+ ``` {r}
1272
+ num_gfr <- 0
1273
+ num_burnin <- 2000
1274
+ num_mcmc <- 100
1275
+ num_samples <- num_gfr + num_burnin + num_mcmc
1276
+ general_params <- list(keep_every = 5,
1277
+ probit_outcome_model = T,
1278
+ sample_sigma2_global = F)
1279
+ prognostic_forest_params <- list(sample_sigma2_leaf = F)
1280
+ treatment_effect_forest_params <- list(sample_sigma2_leaf = F)
1281
+ bcf_model_root <- bcf(
1282
+ X_train = X_train, Z_train = Z_train, y_train = y_train, propensity_train = pi_train,
1283
+ X_test = X_test, Z_test = Z_test, propensity_test = pi_test,
1284
+ num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc,
1285
+ general_params = general_params, prognostic_forest_params = prognostic_forest_params,
1286
+ treatment_effect_forest_params = treatment_effect_forest_params
1287
+ )
1288
+ ```
1289
+
1290
+ Inspect the BART samples that were initialized with an XBART warm-start
1291
+
1292
+ ``` {r}
1293
+ mu_hat_test <- rowMeans(bcf_model_root$mu_hat_test)
1294
+ plot(mu_hat_test, mu_test, xlab = "predicted",
1295
+ ylab = "actual", main = "Prognostic function")
1296
+ abline(0,1,col="red",lty=3,lwd=3)
1297
+ tau_hat_test <- rowMeans(bcf_model_root$tau_hat_test)
1298
+ plot(tau_hat_test, tau_test, xlab = "predicted",
1299
+ ylab = "actual", main = "Treatment effect")
1300
+ abline(0,1,col="red",lty=3,lwd=3)
1301
+ delta_x_hat_test <- pnorm(mu_hat_test+tau_hat_test) - pnorm(mu_hat_test)
1302
+ plot(delta_x_hat_test, delta_x_test,
1303
+ xlab = "predicted", ylab = "actual", main = "Distributional treatment\neffect")
1304
+ abline(0,1,col="red",lty=3,lwd=3)
1305
+ ```
1306
+
1307
+ Examine test set interval coverage
1308
+
1309
+ ``` {r}
1310
+ test_lb <- apply(
1311
+ pnorm(bcf_model_root$mu_hat_test + bcf_model_root$tau_hat_test) -
1312
+ pnorm(bcf_model_root$mu_hat_test), 1, quantile, 0.025)
1313
+ test_ub <- apply(
1314
+ pnorm(bcf_model_root$mu_hat_test + bcf_model_root$tau_hat_test) -
1315
+ pnorm(bcf_model_root$mu_hat_test), 1, quantile, 0.975)
1316
+ cover <- (
1317
+ (test_lb <= delta_x_test) &
1318
+ (test_ub >= delta_x_test)
1319
+ )
1320
+ mean(cover)
1321
+ ```
1322
+
1136
1323
# Continuous Treatment
1137
1324
1138
1325
## Demo 1: Nonlinear Outcome Model, Heterogeneous Treatment Effect
@@ -1230,9 +1417,9 @@ plot(rowMeans(bcf_model_warmstart$tau_hat_test), tau_test,
1230
1417
xlab = "predicted", ylab = "actual", main = "Treatment effect")
1231
1418
abline(0,1,col="red",lty=3,lwd=3)
1232
1419
sigma_observed <- var(y-E_XZ)
1233
- plot_bounds <- c(min(c(bcf_model_warmstart$sigma2_samples , sigma_observed)),
1234
- max(c(bcf_model_warmstart$sigma2_samples , sigma_observed)))
1235
- plot(bcf_model_warmstart$sigma2_samples , ylim = plot_bounds,
1420
+ plot_bounds <- c(min(c(bcf_model_warmstart$sigma2_global_samples , sigma_observed)),
1421
+ max(c(bcf_model_warmstart$sigma2_global_samples , sigma_observed)))
1422
+ plot(bcf_model_warmstart$sigma2_global_samples , ylim = plot_bounds,
1236
1423
ylab = "sigma^2", xlab = "Sample", main = "Global variance parameter")
1237
1424
abline(h = sigma_observed, lty=3, lwd = 3, col = "blue")
1238
1425
```
@@ -1280,9 +1467,9 @@ plot(rowMeans(bcf_model_root$tau_hat_test), tau_test,
1280
1467
xlab = "predicted", ylab = "actual", main = "Treatment effect")
1281
1468
abline(0,1,col="red",lty=3,lwd=3)
1282
1469
sigma_observed <- var(y-E_XZ)
1283
- plot_bounds <- c(min(c(bcf_model_root$sigma2_samples , sigma_observed)),
1284
- max(c(bcf_model_root$sigma2_samples , sigma_observed)))
1285
- plot(bcf_model_root$sigma2_samples , ylim = plot_bounds,
1470
+ plot_bounds <- c(min(c(bcf_model_root$sigma2_global_samples , sigma_observed)),
1471
+ max(c(bcf_model_root$sigma2_global_samples , sigma_observed)))
1472
+ plot(bcf_model_root$sigma2_global_samples , ylim = plot_bounds,
1286
1473
ylab = "sigma^2", xlab = "Sample", main = "Global variance parameter")
1287
1474
abline(h = sigma_observed, lty=3, lwd = 3, col = "blue")
1288
1475
```
0 commit comments