@@ -792,6 +792,273 @@ mean(cover)
792
792
793
793
It is clear that causal inference is much more difficult in the presence of ** both** strong covariate-dependent prognostic effects and strong group-level random effects. In this sense, proper prior calibration for all three of the $\mu$, $\tau$ and random effects models is crucial.
794
794
795
+ ## Demo 6: Nonlinear Outcome Model, Heterogeneous Treatment Effect using Different Features in the Prognostic and Treatment Forests
796
+
797
+ Here, we consider the case in which we might prefer to use only a subset of covariates in the treatment effect forest. Why might we want to do that?
798
+ Well, in many cases it is plausible that some covariates (for example age, income, etc...) influence the outcome of interest in a causal problem, but do not ** moderate** the treatment effect. In this case, we'd need to include these variables in the prognostic forest for deconfounding but we don't necessarily need to include them in the treatment effect forest.
799
+
800
+ ### Simulation
801
+
802
+ We draw from a modified "demo 1" DGP
803
+
804
+ ``` {r}
805
+ mu <- function(x) {1+g(x)+x[,1]*x[,3]-x[,2]+3*x[,3]}
806
+ tau <- function(x) {1+0.5*abs(x[,1])-0.25*sin(2*x[,1])}
807
+ n <- 500
808
+ snr <- 2
809
+ x1 <- rnorm(n)
810
+ x2 <- rnorm(n)
811
+ x3 <- rnorm(n)
812
+ x4 <- as.numeric(rbinom(n,1,0.5))
813
+ x5 <- as.numeric(sample(1:3,n,replace=TRUE))
814
+ X <- cbind(x1,x2,x3,x4,x5)
815
+ p <- ncol(X)
816
+ mu_x <- mu(X)
817
+ tau_x <- tau(X)
818
+ pi_x <- 0.8*pnorm((3*mu_x/sd(mu_x)) - 0.5*X[,1]) + 0.05 + runif(n)/10
819
+ Z <- rbinom(n,1,pi_x)
820
+ E_XZ <- mu_x + Z*tau_x
821
+ y <- E_XZ + rnorm(n, 0, 1)*(sd(E_XZ)/snr)
822
+ X <- as.data.frame(X)
823
+ X$x4 <- factor(X$x4, ordered = TRUE)
824
+ X$x5 <- factor(X$x5, ordered = TRUE)
825
+
826
+ # Split data into test and train sets
827
+ test_set_pct <- 0.2
828
+ n_test <- round(test_set_pct*n)
829
+ n_train <- n - n_test
830
+ test_inds <- sort(sample(1:n, n_test, replace = FALSE))
831
+ train_inds <- (1:n)[!((1:n) %in% test_inds)]
832
+ X_test <- X[test_inds,]
833
+ X_train <- X[train_inds,]
834
+ pi_test <- pi_x[test_inds]
835
+ pi_train <- pi_x[train_inds]
836
+ Z_test <- Z[test_inds]
837
+ Z_train <- Z[train_inds]
838
+ y_test <- y[test_inds]
839
+ y_train <- y[train_inds]
840
+ mu_test <- mu_x[test_inds]
841
+ mu_train <- mu_x[train_inds]
842
+ tau_test <- tau_x[test_inds]
843
+ tau_train <- tau_x[train_inds]
844
+ ```
845
+
846
+ ### Sampling and Analysis
847
+
848
+ #### MCMC, full covariate set in $\tau(X)$
849
+
850
+ Here we simulate from the model with the original MCMC sampler, using all of the covariates in both the prognostic ($\mu(X)$) and treatment effect ($\tau(X)$) forests.
851
+
852
+ ``` {r}
853
+ num_gfr <- 0
854
+ num_burnin <- 1000
855
+ num_mcmc <- 1000
856
+ num_samples <- num_gfr + num_burnin + num_mcmc
857
+ bcf_model_mcmc <- bcf(
858
+ X_train = X_train, Z_train = Z_train, y_train = y_train, pi_train = pi_train,
859
+ X_test = X_test, Z_test = Z_test, pi_test = pi_test,
860
+ num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc,
861
+ sample_sigma_leaf_mu = F, sample_sigma_leaf_tau = F
862
+ )
863
+ ```
864
+
865
+ Inspect the burned-in samples
866
+
867
+ ``` {r}
868
+ plot(rowMeans(bcf_model_mcmc$mu_hat_test), mu_test,
869
+ xlab = "predicted", ylab = "actual", main = "Prognostic function")
870
+ abline(0,1,col="red",lty=3,lwd=3)
871
+ plot(rowMeans(bcf_model_mcmc$tau_hat_test), tau_test,
872
+ xlab = "predicted", ylab = "actual", main = "Treatment effect")
873
+ abline(0,1,col="red",lty=3,lwd=3)
874
+ sigma_observed <- var(y-E_XZ)
875
+ plot_bounds <- c(min(c(bcf_model_mcmc$sigma2_samples, sigma_observed)),
876
+ max(c(bcf_model_mcmc$sigma2_samples, sigma_observed)))
877
+ plot(bcf_model_mcmc$sigma2_samples, ylim = plot_bounds,
878
+ ylab = "sigma^2", xlab = "Sample", main = "Global variance parameter")
879
+ abline(h = sigma_observed, lty=3, lwd = 3, col = "blue")
880
+ ```
881
+
882
+ Examine test set interval coverage
883
+
884
+ ``` {r}
885
+ test_lb <- apply(bcf_model_mcmc$tau_hat_test, 1, quantile, 0.025)
886
+ test_ub <- apply(bcf_model_mcmc$tau_hat_test, 1, quantile, 0.975)
887
+ cover <- (
888
+ (test_lb <= tau_x[test_inds]) &
889
+ (test_ub >= tau_x[test_inds])
890
+ )
891
+ mean(cover)
892
+ ```
893
+
894
+ And test set RMSE
895
+
896
+ ``` {r}
897
+ test_mean <- rowMeans(bcf_model_mcmc$tau_hat_test)
898
+ sqrt(mean((test_mean - tau_test)^2))
899
+ ```
900
+
901
+ #### MCMC, covariate subset in $\tau(X)$
902
+
903
+ Here we simulate from the model with the original MCMC sampler, using only covariate $X_1$ in the treatment effect forest.
904
+
905
+ ``` {r}
906
+ num_gfr <- 0
907
+ num_burnin <- 1000
908
+ num_mcmc <- 1000
909
+ num_samples <- num_gfr + num_burnin + num_mcmc
910
+ bcf_model_mcmc <- bcf(
911
+ X_train = X_train, Z_train = Z_train, y_train = y_train, pi_train = pi_train,
912
+ X_test = X_test, Z_test = Z_test, pi_test = pi_test,
913
+ num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc,
914
+ sample_sigma_leaf_mu = F, sample_sigma_leaf_tau = F,
915
+ keep_vars_tau = c("x1")
916
+ )
917
+ ```
918
+
919
+ Inspect the BART samples
920
+
921
+ ``` {r}
922
+ plot(rowMeans(bcf_model_mcmc$mu_hat_test), mu_test,
923
+ xlab = "predicted", ylab = "actual", main = "Prognostic function")
924
+ abline(0,1,col="red",lty=3,lwd=3)
925
+ plot(rowMeans(bcf_model_mcmc$tau_hat_test), tau_test,
926
+ xlab = "predicted", ylab = "actual", main = "Treatment effect")
927
+ abline(0,1,col="red",lty=3,lwd=3)
928
+ sigma_observed <- var(y-E_XZ)
929
+ plot_bounds <- c(min(c(bcf_model_mcmc$sigma2_samples, sigma_observed)),
930
+ max(c(bcf_model_mcmc$sigma2_samples, sigma_observed)))
931
+ plot(bcf_model_mcmc$sigma2_samples, ylim = plot_bounds,
932
+ ylab = "sigma^2", xlab = "Sample", main = "Global variance parameter")
933
+ abline(h = sigma_observed, lty=3, lwd = 3, col = "blue")
934
+ ```
935
+
936
+ Examine test set interval coverage
937
+
938
+ ``` {r}
939
+ test_lb <- apply(bcf_model_mcmc$tau_hat_test, 1, quantile, 0.025)
940
+ test_ub <- apply(bcf_model_mcmc$tau_hat_test, 1, quantile, 0.975)
941
+ cover <- (
942
+ (test_lb <= tau_x[test_inds]) &
943
+ (test_ub >= tau_x[test_inds])
944
+ )
945
+ mean(cover)
946
+ ```
947
+
948
+ And test set RMSE
949
+
950
+ ``` {r}
951
+ test_mean <- rowMeans(bcf_model_mcmc$tau_hat_test)
952
+ sqrt(mean((test_mean - tau_test)^2))
953
+ ```
954
+
955
+ #### Warmstart, full covariate set in $\tau(X)$
956
+
957
+ Here we simulate from the model with the warm-start sampler, using all of the covariates in both the prognostic ($\mu(X)$) and treatment effect ($\tau(X)$) forests.
958
+
959
+ ``` {r}
960
+ num_gfr <- 10
961
+ num_burnin <- 0
962
+ num_mcmc <- 1000
963
+ num_samples <- num_gfr + num_burnin + num_mcmc
964
+ bcf_model_warmstart <- bcf(
965
+ X_train = X_train, Z_train = Z_train, y_train = y_train, pi_train = pi_train,
966
+ X_test = X_test, Z_test = Z_test, pi_test = pi_test,
967
+ num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc,
968
+ sample_sigma_leaf_mu = F, sample_sigma_leaf_tau = F
969
+ )
970
+ ```
971
+
972
+ Inspect the BART samples that were initialized with an XBART warm-start
973
+
974
+ ``` {r}
975
+ plot(rowMeans(bcf_model_warmstart$mu_hat_test), mu_test,
976
+ xlab = "predicted", ylab = "actual", main = "Prognostic function")
977
+ abline(0,1,col="red",lty=3,lwd=3)
978
+ plot(rowMeans(bcf_model_warmstart$tau_hat_test), tau_test,
979
+ xlab = "predicted", ylab = "actual", main = "Treatment effect")
980
+ abline(0,1,col="red",lty=3,lwd=3)
981
+ sigma_observed <- var(y-E_XZ)
982
+ plot_bounds <- c(min(c(bcf_model_warmstart$sigma2_samples, sigma_observed)),
983
+ max(c(bcf_model_warmstart$sigma2_samples, sigma_observed)))
984
+ plot(bcf_model_warmstart$sigma2_samples, ylim = plot_bounds,
985
+ ylab = "sigma^2", xlab = "Sample", main = "Global variance parameter")
986
+ abline(h = sigma_observed, lty=3, lwd = 3, col = "blue")
987
+ ```
988
+
989
+ Examine test set interval coverage
990
+
991
+ ``` {r}
992
+ test_lb <- apply(bcf_model_warmstart$tau_hat_test, 1, quantile, 0.025)
993
+ test_ub <- apply(bcf_model_warmstart$tau_hat_test, 1, quantile, 0.975)
994
+ cover <- (
995
+ (test_lb <= tau_x[test_inds]) &
996
+ (test_ub >= tau_x[test_inds])
997
+ )
998
+ mean(cover)
999
+ ```
1000
+
1001
+ And test set RMSE
1002
+
1003
+ ``` {r}
1004
+ test_mean <- apply(bcf_model_warmstart$tau_hat_test, 1, mean)
1005
+ sqrt(mean((tau_x[test_inds] - test_mean)^2))
1006
+ ```
1007
+
1008
+ #### Warmstart, covariate subset in $\tau(X)$
1009
+
1010
+ Here we simulate from the model with the warm-start sampler, using only covariate $X_1$ in the treatment effect forest.
1011
+
1012
+ ``` {r}
1013
+ num_gfr <- 10
1014
+ num_burnin <- 0
1015
+ num_mcmc <- 1000
1016
+ num_samples <- num_gfr + num_burnin + num_mcmc
1017
+ bcf_model_warmstart <- bcf(
1018
+ X_train = X_train, Z_train = Z_train, y_train = y_train, pi_train = pi_train,
1019
+ X_test = X_test, Z_test = Z_test, pi_test = pi_test,
1020
+ num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc,
1021
+ sample_sigma_leaf_mu = F, sample_sigma_leaf_tau = F,
1022
+ keep_vars_tau = c("x1")
1023
+ )
1024
+ ```
1025
+
1026
+ Inspect the BART samples that were initialized with an XBART warm-start
1027
+
1028
+ ``` {r}
1029
+ plot(rowMeans(bcf_model_warmstart$mu_hat_test), mu_test,
1030
+ xlab = "predicted", ylab = "actual", main = "Prognostic function")
1031
+ abline(0,1,col="red",lty=3,lwd=3)
1032
+ plot(rowMeans(bcf_model_warmstart$tau_hat_test), tau_test,
1033
+ xlab = "predicted", ylab = "actual", main = "Treatment effect")
1034
+ abline(0,1,col="red",lty=3,lwd=3)
1035
+ sigma_observed <- var(y-E_XZ)
1036
+ plot_bounds <- c(min(c(bcf_model_warmstart$sigma2_samples, sigma_observed)),
1037
+ max(c(bcf_model_warmstart$sigma2_samples, sigma_observed)))
1038
+ plot(bcf_model_warmstart$sigma2_samples, ylim = plot_bounds,
1039
+ ylab = "sigma^2", xlab = "Sample", main = "Global variance parameter")
1040
+ abline(h = sigma_observed, lty=3, lwd = 3, col = "blue")
1041
+ ```
1042
+
1043
+ Examine test set interval coverage
1044
+
1045
+ ``` {r}
1046
+ test_lb <- apply(bcf_model_warmstart$tau_hat_test, 1, quantile, 0.025)
1047
+ test_ub <- apply(bcf_model_warmstart$tau_hat_test, 1, quantile, 0.975)
1048
+ cover <- (
1049
+ (test_lb <= tau_x[test_inds]) &
1050
+ (test_ub >= tau_x[test_inds])
1051
+ )
1052
+ mean(cover)
1053
+ ```
1054
+
1055
+ And test set RMSE
1056
+
1057
+ ``` {r}
1058
+ test_mean <- apply(bcf_model_warmstart$tau_hat_test, 1, mean)
1059
+ sqrt(mean((tau_x[test_inds] - test_mean)^2))
1060
+ ```
1061
+
795
1062
# Continuous Treatment
796
1063
797
1064
## Demo 1: Nonlinear Outcome Model, Heterogeneous Treatment Effect
0 commit comments