Skip to content

Commit 4b827b4

Browse files
committed
Updated BCF function and added vignette demonstrating how to use feature subsets
1 parent d746d92 commit 4b827b4

File tree

2 files changed

+273
-7
lines changed

2 files changed

+273
-7
lines changed

R/bcf.R

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -189,8 +189,7 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU
189189
}
190190
variable_subset_tau <- keep_vars_tau
191191
}
192-
}
193-
if ((is.null(keep_vars_tau)) && (!is.null(drop_vars_tau))) {
192+
} else if ((is.null(keep_vars_tau)) && (!is.null(drop_vars_tau))) {
194193
if (is.character(drop_vars_tau)) {
195194
if (!all(drop_vars_tau %in% names(X_train))) {
196195
stop("drop_vars_tau includes some variable names that are not in X_train")
@@ -304,7 +303,7 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU
304303
variable_weights_tau <- variable_weights_mu <- variable_weights
305304
variable_weights_mu[!(original_var_indices %in% variable_subset_mu)] <- 0
306305
variable_weights_tau[!(original_var_indices %in% variable_subset_tau)] <- 0
307-
306+
308307
# Fill in rfx basis as a vector of 1s (random intercept) if a basis not provided
309308
has_basis_rfx <- F
310309
num_basis_rfx <- 0
@@ -375,14 +374,14 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU
375374
feature_types <- as.integer(c(feature_types,0))
376375
X_train <- cbind(X_train, pi_train)
377376
if (propensity_covariate == "mu") {
378-
variable_weights_mu <- c(variable_weights_mu, 1./num_cov_orig)
377+
variable_weights_mu <- c(variable_weights_mu, rep(1./num_cov_orig, ncol(pi_train)))
379378
variable_weights_tau <- c(variable_weights_tau, 0)
380379
} else if (propensity_covariate == "tau") {
381380
variable_weights_mu <- c(variable_weights_mu, 0)
382-
variable_weights_tau <- c(variable_weights_tau, 1./num_cov_orig)
381+
variable_weights_tau <- c(variable_weights_tau, rep(1./num_cov_orig, ncol(pi_train)))
383382
} else if (propensity_covariate == "both") {
384-
variable_weights_mu <- c(variable_weights_mu, 1./num_cov_orig)
385-
variable_weights_tau <- c(variable_weights_tau, 1./num_cov_orig)
383+
variable_weights_mu <- c(variable_weights_mu, rep(1./num_cov_orig, ncol(pi_train)))
384+
variable_weights_tau <- c(variable_weights_tau, rep(1./num_cov_orig, ncol(pi_train)))
386385
}
387386
if (has_test) X_test <- cbind(X_test, pi_test)
388387
}

vignettes/CausalInference.Rmd

Lines changed: 267 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -792,6 +792,273 @@ mean(cover)
792792

793793
It is clear that causal inference is much more difficult in the presence of **both** strong covariate-dependent prognostic effects and strong group-level random effects. In this sense, proper prior calibration for all three of the $\mu$, $\tau$ and random effects models is crucial.
794794

795+
## Demo 6: Nonlinear Outcome Model, Heterogeneous Treatment Effect using Different Features in the Prognostic and Treatment Forests
796+
797+
Here, we consider the case in which we might prefer to use only a subset of covariates in the treatment effect forest. Why might we want to do that?
798+
Well, in many cases it is plausible that some covariates (for example age, income, etc...) influence the outcome of interest in a causal problem, but do not **moderate** the treatment effect. In this case, we'd need to include these variables in the prognostic forest for deconfounding but we don't necessarily need to include them in the treatment effect forest.
799+
800+
### Simulation
801+
802+
We draw from a modified "demo 1" DGP
803+
804+
```{r}
805+
mu <- function(x) {1+g(x)+x[,1]*x[,3]-x[,2]+3*x[,3]}
806+
tau <- function(x) {1+0.5*abs(x[,1])-0.25*sin(2*x[,1])}
807+
n <- 500
808+
snr <- 2
809+
x1 <- rnorm(n)
810+
x2 <- rnorm(n)
811+
x3 <- rnorm(n)
812+
x4 <- as.numeric(rbinom(n,1,0.5))
813+
x5 <- as.numeric(sample(1:3,n,replace=TRUE))
814+
X <- cbind(x1,x2,x3,x4,x5)
815+
p <- ncol(X)
816+
mu_x <- mu(X)
817+
tau_x <- tau(X)
818+
pi_x <- 0.8*pnorm((3*mu_x/sd(mu_x)) - 0.5*X[,1]) + 0.05 + runif(n)/10
819+
Z <- rbinom(n,1,pi_x)
820+
E_XZ <- mu_x + Z*tau_x
821+
y <- E_XZ + rnorm(n, 0, 1)*(sd(E_XZ)/snr)
822+
X <- as.data.frame(X)
823+
X$x4 <- factor(X$x4, ordered = TRUE)
824+
X$x5 <- factor(X$x5, ordered = TRUE)
825+
826+
# Split data into test and train sets
827+
test_set_pct <- 0.2
828+
n_test <- round(test_set_pct*n)
829+
n_train <- n - n_test
830+
test_inds <- sort(sample(1:n, n_test, replace = FALSE))
831+
train_inds <- (1:n)[!((1:n) %in% test_inds)]
832+
X_test <- X[test_inds,]
833+
X_train <- X[train_inds,]
834+
pi_test <- pi_x[test_inds]
835+
pi_train <- pi_x[train_inds]
836+
Z_test <- Z[test_inds]
837+
Z_train <- Z[train_inds]
838+
y_test <- y[test_inds]
839+
y_train <- y[train_inds]
840+
mu_test <- mu_x[test_inds]
841+
mu_train <- mu_x[train_inds]
842+
tau_test <- tau_x[test_inds]
843+
tau_train <- tau_x[train_inds]
844+
```
845+
846+
### Sampling and Analysis
847+
848+
#### MCMC, full covariate set in $\tau(X)$
849+
850+
Here we simulate from the model with the original MCMC sampler, using all of the covariates in both the prognostic ($\mu(X)$) and treatment effect ($\tau(X)$) forests.
851+
852+
```{r}
853+
num_gfr <- 0
854+
num_burnin <- 1000
855+
num_mcmc <- 1000
856+
num_samples <- num_gfr + num_burnin + num_mcmc
857+
bcf_model_mcmc <- bcf(
858+
X_train = X_train, Z_train = Z_train, y_train = y_train, pi_train = pi_train,
859+
X_test = X_test, Z_test = Z_test, pi_test = pi_test,
860+
num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc,
861+
sample_sigma_leaf_mu = F, sample_sigma_leaf_tau = F
862+
)
863+
```
864+
865+
Inspect the burned-in samples
866+
867+
```{r}
868+
plot(rowMeans(bcf_model_mcmc$mu_hat_test), mu_test,
869+
xlab = "predicted", ylab = "actual", main = "Prognostic function")
870+
abline(0,1,col="red",lty=3,lwd=3)
871+
plot(rowMeans(bcf_model_mcmc$tau_hat_test), tau_test,
872+
xlab = "predicted", ylab = "actual", main = "Treatment effect")
873+
abline(0,1,col="red",lty=3,lwd=3)
874+
sigma_observed <- var(y-E_XZ)
875+
plot_bounds <- c(min(c(bcf_model_mcmc$sigma2_samples, sigma_observed)),
876+
max(c(bcf_model_mcmc$sigma2_samples, sigma_observed)))
877+
plot(bcf_model_mcmc$sigma2_samples, ylim = plot_bounds,
878+
ylab = "sigma^2", xlab = "Sample", main = "Global variance parameter")
879+
abline(h = sigma_observed, lty=3, lwd = 3, col = "blue")
880+
```
881+
882+
Examine test set interval coverage
883+
884+
```{r}
885+
test_lb <- apply(bcf_model_mcmc$tau_hat_test, 1, quantile, 0.025)
886+
test_ub <- apply(bcf_model_mcmc$tau_hat_test, 1, quantile, 0.975)
887+
cover <- (
888+
(test_lb <= tau_x[test_inds]) &
889+
(test_ub >= tau_x[test_inds])
890+
)
891+
mean(cover)
892+
```
893+
894+
And test set RMSE
895+
896+
```{r}
897+
test_mean <- rowMeans(bcf_model_mcmc$tau_hat_test)
898+
sqrt(mean((test_mean - tau_test)^2))
899+
```
900+
901+
#### MCMC, covariate subset in $\tau(X)$
902+
903+
Here we simulate from the model with the original MCMC sampler, using only covariate $X_1$ in the treatment effect forest.
904+
905+
```{r}
906+
num_gfr <- 0
907+
num_burnin <- 1000
908+
num_mcmc <- 1000
909+
num_samples <- num_gfr + num_burnin + num_mcmc
910+
bcf_model_mcmc <- bcf(
911+
X_train = X_train, Z_train = Z_train, y_train = y_train, pi_train = pi_train,
912+
X_test = X_test, Z_test = Z_test, pi_test = pi_test,
913+
num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc,
914+
sample_sigma_leaf_mu = F, sample_sigma_leaf_tau = F,
915+
keep_vars_tau = c("x1")
916+
)
917+
```
918+
919+
Inspect the BART samples
920+
921+
```{r}
922+
plot(rowMeans(bcf_model_mcmc$mu_hat_test), mu_test,
923+
xlab = "predicted", ylab = "actual", main = "Prognostic function")
924+
abline(0,1,col="red",lty=3,lwd=3)
925+
plot(rowMeans(bcf_model_mcmc$tau_hat_test), tau_test,
926+
xlab = "predicted", ylab = "actual", main = "Treatment effect")
927+
abline(0,1,col="red",lty=3,lwd=3)
928+
sigma_observed <- var(y-E_XZ)
929+
plot_bounds <- c(min(c(bcf_model_mcmc$sigma2_samples, sigma_observed)),
930+
max(c(bcf_model_mcmc$sigma2_samples, sigma_observed)))
931+
plot(bcf_model_mcmc$sigma2_samples, ylim = plot_bounds,
932+
ylab = "sigma^2", xlab = "Sample", main = "Global variance parameter")
933+
abline(h = sigma_observed, lty=3, lwd = 3, col = "blue")
934+
```
935+
936+
Examine test set interval coverage
937+
938+
```{r}
939+
test_lb <- apply(bcf_model_mcmc$tau_hat_test, 1, quantile, 0.025)
940+
test_ub <- apply(bcf_model_mcmc$tau_hat_test, 1, quantile, 0.975)
941+
cover <- (
942+
(test_lb <= tau_x[test_inds]) &
943+
(test_ub >= tau_x[test_inds])
944+
)
945+
mean(cover)
946+
```
947+
948+
And test set RMSE
949+
950+
```{r}
951+
test_mean <- rowMeans(bcf_model_mcmc$tau_hat_test)
952+
sqrt(mean((test_mean - tau_test)^2))
953+
```
954+
955+
#### Warmstart, full covariate set in $\tau(X)$
956+
957+
Here we simulate from the model with the warm-start sampler, using all of the covariates in both the prognostic ($\mu(X)$) and treatment effect ($\tau(X)$) forests.
958+
959+
```{r}
960+
num_gfr <- 10
961+
num_burnin <- 0
962+
num_mcmc <- 1000
963+
num_samples <- num_gfr + num_burnin + num_mcmc
964+
bcf_model_warmstart <- bcf(
965+
X_train = X_train, Z_train = Z_train, y_train = y_train, pi_train = pi_train,
966+
X_test = X_test, Z_test = Z_test, pi_test = pi_test,
967+
num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc,
968+
sample_sigma_leaf_mu = F, sample_sigma_leaf_tau = F
969+
)
970+
```
971+
972+
Inspect the BART samples that were initialized with an XBART warm-start
973+
974+
```{r}
975+
plot(rowMeans(bcf_model_warmstart$mu_hat_test), mu_test,
976+
xlab = "predicted", ylab = "actual", main = "Prognostic function")
977+
abline(0,1,col="red",lty=3,lwd=3)
978+
plot(rowMeans(bcf_model_warmstart$tau_hat_test), tau_test,
979+
xlab = "predicted", ylab = "actual", main = "Treatment effect")
980+
abline(0,1,col="red",lty=3,lwd=3)
981+
sigma_observed <- var(y-E_XZ)
982+
plot_bounds <- c(min(c(bcf_model_warmstart$sigma2_samples, sigma_observed)),
983+
max(c(bcf_model_warmstart$sigma2_samples, sigma_observed)))
984+
plot(bcf_model_warmstart$sigma2_samples, ylim = plot_bounds,
985+
ylab = "sigma^2", xlab = "Sample", main = "Global variance parameter")
986+
abline(h = sigma_observed, lty=3, lwd = 3, col = "blue")
987+
```
988+
989+
Examine test set interval coverage
990+
991+
```{r}
992+
test_lb <- apply(bcf_model_warmstart$tau_hat_test, 1, quantile, 0.025)
993+
test_ub <- apply(bcf_model_warmstart$tau_hat_test, 1, quantile, 0.975)
994+
cover <- (
995+
(test_lb <= tau_x[test_inds]) &
996+
(test_ub >= tau_x[test_inds])
997+
)
998+
mean(cover)
999+
```
1000+
1001+
And test set RMSE
1002+
1003+
```{r}
1004+
test_mean <- apply(bcf_model_warmstart$tau_hat_test, 1, mean)
1005+
sqrt(mean((tau_x[test_inds] - test_mean)^2))
1006+
```
1007+
1008+
#### Warmstart, covariate subset in $\tau(X)$
1009+
1010+
Here we simulate from the model with the warm-start sampler, using only covariate $X_1$ in the treatment effect forest.
1011+
1012+
```{r}
1013+
num_gfr <- 10
1014+
num_burnin <- 0
1015+
num_mcmc <- 1000
1016+
num_samples <- num_gfr + num_burnin + num_mcmc
1017+
bcf_model_warmstart <- bcf(
1018+
X_train = X_train, Z_train = Z_train, y_train = y_train, pi_train = pi_train,
1019+
X_test = X_test, Z_test = Z_test, pi_test = pi_test,
1020+
num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc,
1021+
sample_sigma_leaf_mu = F, sample_sigma_leaf_tau = F,
1022+
keep_vars_tau = c("x1")
1023+
)
1024+
```
1025+
1026+
Inspect the BART samples that were initialized with an XBART warm-start
1027+
1028+
```{r}
1029+
plot(rowMeans(bcf_model_warmstart$mu_hat_test), mu_test,
1030+
xlab = "predicted", ylab = "actual", main = "Prognostic function")
1031+
abline(0,1,col="red",lty=3,lwd=3)
1032+
plot(rowMeans(bcf_model_warmstart$tau_hat_test), tau_test,
1033+
xlab = "predicted", ylab = "actual", main = "Treatment effect")
1034+
abline(0,1,col="red",lty=3,lwd=3)
1035+
sigma_observed <- var(y-E_XZ)
1036+
plot_bounds <- c(min(c(bcf_model_warmstart$sigma2_samples, sigma_observed)),
1037+
max(c(bcf_model_warmstart$sigma2_samples, sigma_observed)))
1038+
plot(bcf_model_warmstart$sigma2_samples, ylim = plot_bounds,
1039+
ylab = "sigma^2", xlab = "Sample", main = "Global variance parameter")
1040+
abline(h = sigma_observed, lty=3, lwd = 3, col = "blue")
1041+
```
1042+
1043+
Examine test set interval coverage
1044+
1045+
```{r}
1046+
test_lb <- apply(bcf_model_warmstart$tau_hat_test, 1, quantile, 0.025)
1047+
test_ub <- apply(bcf_model_warmstart$tau_hat_test, 1, quantile, 0.975)
1048+
cover <- (
1049+
(test_lb <= tau_x[test_inds]) &
1050+
(test_ub >= tau_x[test_inds])
1051+
)
1052+
mean(cover)
1053+
```
1054+
1055+
And test set RMSE
1056+
1057+
```{r}
1058+
test_mean <- apply(bcf_model_warmstart$tau_hat_test, 1, mean)
1059+
sqrt(mean((tau_x[test_inds] - test_mean)^2))
1060+
```
1061+
7951062
# Continuous Treatment
7961063

7971064
## Demo 1: Nonlinear Outcome Model, Heterogeneous Treatment Effect

0 commit comments

Comments
 (0)