diff --git a/R/bcf.R b/R/bcf.R
index b9842c5d..cb441d36 100644
--- a/R/bcf.R
+++ b/R/bcf.R
@@ -1104,6 +1104,21 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
       global_model_config$update_global_error_variance(current_sigma2)
     }
   } else if (has_prev_model) {
+    if (adaptive_coding) {
+      if (!is.null(previous_b_1_samples)) {
+        current_b_1 <- previous_b_1_samples[previous_model_warmstart_sample_num]
+      }
+      if (!is.null(previous_b_0_samples)) {
+        current_b_0 <- previous_b_0_samples[previous_model_warmstart_sample_num]
+      }
+      tau_basis_train <- (1-Z_train)*current_b_0 + Z_train*current_b_1
+      forest_dataset_train$update_basis(tau_basis_train)
+      if (has_test) {
+        tau_basis_test <- (1-Z_test)*current_b_0 + Z_test*current_b_1
+        forest_dataset_test$update_basis(tau_basis_test)
+      }
+      forest_model_tau$propagate_basis_update(forest_dataset_train, outcome_train, active_forest_tau)
+    }
     resetActiveForest(active_forest_mu, previous_forest_samples_mu, previous_model_warmstart_sample_num - 1)
     resetForestModel(forest_model_mu, active_forest_mu, forest_dataset_train, outcome_train, TRUE)
     resetActiveForest(active_forest_tau, previous_forest_samples_tau, previous_model_warmstart_sample_num - 1)
@@ -1122,21 +1137,6 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
       current_leaf_scale_tau <- as.matrix(leaf_scale_tau_double)
       forest_model_config_tau$update_leaf_model_scale(current_leaf_scale_tau)
     }
-    if (adaptive_coding) {
-      if (!is.null(previous_b_1_samples)) {
-        current_b_1 <- previous_b_1_samples[previous_model_warmstart_sample_num]
-      }
-      if (!is.null(previous_b_0_samples)) {
-        current_b_0 <- previous_b_0_samples[previous_model_warmstart_sample_num]
-      }
-      tau_basis_train <- (1-Z_train)*current_b_0 + Z_train*current_b_1
-      forest_dataset_train$update_basis(tau_basis_train)
-      if (has_test) {
-        tau_basis_test <- (1-Z_test)*current_b_0 + Z_test*current_b_1
-        forest_dataset_test$update_basis(tau_basis_test)
-      }
-      forest_model_tau$propagate_basis_update(forest_dataset_train, outcome_train, active_forest_tau)
-    }
     if (has_rfx) {
       if (is.null(previous_rfx_samples)) {
         warning("`previous_model_json` did not have any random effects samples, so the RFX sampler will be run from scratch while the forests and any other parameters are warm started")
@@ -1618,6 +1618,8 @@ predict.bcfmodel <- function(object, X, Z, propensity = NULL, rfx_group_ids = NU
   # Add propensities to covariate set if necessary
   if (object$model_params$propensity_covariate != "none") {
     X_combined <- cbind(X, propensity)
+  } else {
+    X_combined <- X
   }
 
   # Create prediction datasets
diff --git a/tools/debug/bart_continue_sampler_debug.R b/tools/debug/bart_continue_sampler_debug.R
new file mode 100644
index 00000000..b921f979
--- /dev/null
+++ b/tools/debug/bart_continue_sampler_debug.R
@@ -0,0 +1,84 @@
+# Load libraries
+library(stochtree)
+
+# Sampler settings
+num_chains <- 1
+num_gfr <- 10
+num_burnin <- 0
+num_mcmc <- 20
+num_trees <- 100
+
+# Generate the data
+n <- 500
+p_x <- 10
+snr <- 2
+X <- matrix(runif(n*p_x), ncol = p_x)
+f_XW <- sin(4*pi*X[,1]) + cos(4*pi*X[,2]) + sin(4*pi*X[,3]) + cos(4*pi*X[,4])
+noise_sd <- sd(f_XW) / snr
+y <- f_XW + rnorm(n, 0, 1)*noise_sd
+
+# Split data into test and train sets
+test_set_pct <- 0.2
+n_test <- round(test_set_pct*n)
+n_train <- n - n_test
+test_inds <- sort(sample(1:n, n_test, replace = FALSE))
+train_inds <- (1:n)[!((1:n) %in% test_inds)]
+X_test <- as.data.frame(X[test_inds,])
+X_train <- as.data.frame(X[train_inds,])
+y_test <- y[test_inds]
+y_train <- y[train_inds]
+f_XW_test <- f_XW[test_inds]
+f_XW_train <- f_XW[train_inds]
+
+# Run the GFR algorithm
+general_params <- list(sample_sigma2_global = T)
+mean_forest_params <- list(num_trees = num_trees, alpha = 0.95,
+                           beta = 2.0, max_depth = -1,
+                           min_samples_leaf = 1,
+                           sample_sigma2_leaf = F,
+                           sigma2_leaf_init = 1.0/num_trees)
+xbart_model <- stochtree::bart(
+  X_train = X_train, y_train = y_train, X_test = X_test,
+  num_gfr = num_gfr, num_burnin = 0, num_mcmc = 0,
+  general_params = general_params,
+  mean_forest_params = mean_forest_params
+)
+
+# Inspect results
+plot(rowMeans(xbart_model$y_hat_test), y_test); abline(0,1)
+cat(paste0("RMSE = ", sqrt(mean((rowMeans(xbart_model$y_hat_test) - y_test)^2)), "\n"))
+cat(paste0("Interval coverage = ", mean((apply(xbart_model$y_hat_test, 1, quantile, probs=0.025) <= f_XW_test) & (apply(xbart_model$y_hat_test, 1, quantile, probs=0.975) >= f_XW_test)), "\n"))
+plot(xbart_model$sigma2_global_samples)
+xbart_model_string <- stochtree::saveBARTModelToJsonString(xbart_model)
+
+# Run the BART MCMC sampler, initialized from the XBART (GFR) sampler
+general_params <- list(sample_sigma2_global = T)
+mean_forest_params <- list(num_trees = num_trees, alpha = 0.95,
+                           beta = 2.0, max_depth = -1,
+                           min_samples_leaf = 1,
+                           sample_sigma2_leaf = F,
+                           sigma2_leaf_init = 1.0/num_trees)
+bart_model <- stochtree::bart(
+  X_train = X_train, y_train = y_train, X_test = X_test,
+  num_gfr = 0, num_burnin = num_burnin, num_mcmc = num_mcmc,
+  general_params = general_params, mean_forest_params = mean_forest_params,
+  previous_model_json = xbart_model_string,
+  previous_model_warmstart_sample_num = num_gfr
+)
+
+# Inspect the results
+plot(rowMeans(bart_model$y_hat_test), y_test); abline(0,1)
+cat(paste0("RMSE = ", sqrt(mean((rowMeans(bart_model$y_hat_test) - y_test)^2)), "\n"))
+cat(paste0("Interval coverage = ", mean((apply(bart_model$y_hat_test, 1, quantile, probs=0.025) <= f_XW_test) & (apply(bart_model$y_hat_test, 1, quantile, probs=0.975) >= f_XW_test)), "\n"))
+plot(bart_model$sigma2_global_samples)
+
+# Compare to a single chain of MCMC samples initialized at root
+bart_model_root <- stochtree::bart(
+  X_train = X_train, y_train = y_train, X_test = X_test,
+  num_gfr = 0, num_burnin = 0, num_mcmc = num_mcmc,
+  general_params = general_params, mean_forest_params = mean_forest_params
+)
+plot(rowMeans(bart_model_root$y_hat_test), y_test); abline(0,1)
+cat(paste0("RMSE = ", sqrt(mean((rowMeans(bart_model_root$y_hat_test) - y_test)^2)), "\n"))
+cat(paste0("Interval coverage = ", mean((apply(bart_model_root$y_hat_test, 1, quantile, probs=0.025) <= f_XW_test) & (apply(bart_model_root$y_hat_test, 1, quantile, probs=0.975) >= f_XW_test)), "\n"))
+plot(bart_model_root$sigma2_global_samples)
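+
+# `num_chains` is declared above but unused in this single-chain debug run.
+# A hedged sketch of the natural multi-chain extension follows: start each
+# MCMC chain from a different retained GFR draw and pool the test-set
+# predictions. It uses only the `bart()` warm-start arguments already
+# exercised above; `chain_preds` and `y_hat_test_pooled` are illustrative
+# names, not part of the original script.
+chain_preds <- lapply(1:num_chains, function(chain_num) {
+  chain_model <- stochtree::bart(
+    X_train = X_train, y_train = y_train, X_test = X_test,
+    num_gfr = 0, num_burnin = num_burnin, num_mcmc = num_mcmc,
+    general_params = general_params, mean_forest_params = mean_forest_params,
+    previous_model_json = xbart_model_string,
+    previous_model_warmstart_sample_num = num_gfr - chain_num + 1
+  )
+  chain_model$y_hat_test
+})
+y_hat_test_pooled <- do.call(cbind, chain_preds)
+cat(paste0("Pooled RMSE = ", sqrt(mean((rowMeans(y_hat_test_pooled) - y_test)^2)), "\n"))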
diff --git a/tools/debug/bcf_401k_data_debug.R b/tools/debug/bcf_401k_data_debug.R
new file mode 100644
index 00000000..29cd478e
--- /dev/null
+++ b/tools/debug/bcf_401k_data_debug.R
@@ -0,0 +1,210 @@
+################################################################################
+## Investigation of GFR vs MCMC fit issues on the 401k dataset
+################################################################################
+
+# Load libraries and set seed
+library(stochtree)
+library(DoubleML)
+library(BART)
+library(tidyverse)
+# seed = 102
+# set.seed(seed)
+
+# Load 401k data
+dat = DoubleML::fetch_401k(return_type = "data.frame")
+dat_orig = dat
+
+# Trim outliers
+dat = dat %>% filter(abs(inc) < ...)
+x = dat %>% dplyr::select(-c(e401, net_tfa))
+
+# Convert to df and define categorical data types
+xdf = data.frame(x)
+xdf_st = xdf %>%
+  mutate(age = factor(age, ordered = TRUE),
+         inc = factor(inc, ordered = TRUE),
+         educ = factor(educ, ordered = TRUE),
+         fsize = factor(fsize, ordered = TRUE),
+         marr = factor(marr, ordered = TRUE),
+         twoearn = factor(twoearn, ordered = TRUE),
+         db = factor(db, ordered = TRUE),
+         pira = factor(pira, ordered = TRUE),
+         hown = factor(hown, ordered = TRUE))
+
+# Isolate treatment and outcome
+z = dat %>% dplyr::select(e401) %>% as.matrix()
+y = dat %>% dplyr::select(net_tfa) %>% as.matrix()
+
+# Define a "jittered" version of the original (integer-valued) x columns
+# in which all categories are "upper-jittered" with uniform [0, eps] noise,
+# except for the largest category, which is "lower-jittered" with [-eps, 0] noise
+x_jitter = x
+for (j in 1:ncol(x)) {
+  min_diff <- min(diff(sort(x[,j]))[diff(sort(x[,j])) > 0])
+  jitter_param <- min_diff / 3.0
+  has_max_category <- x[,j] == max(x[,j])
+  x_jitter[has_max_category,j] <- x[has_max_category,j] + runif(sum(has_max_category), -jitter_param, 0.0)
+  x_jitter[!has_max_category,j] <- x[!has_max_category,j] + runif(sum(!has_max_category), 0.0, jitter_param)
+}
+# Visualize jitters
+# for (j in 1:ncol(x)) {
+#   plot(x[,j], x_jitter[,j], ylab = "jittered", xlab = "original")
+#   unique_xs <- unique(x[,j])
+#   for (i in unique_xs) {
+#     abline(h = unique_xs[i], col = "red", lty = 3)
+#   }
+# }
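+# Quick sanity check on the jitter construction (not part of the original
+# analysis): because each jitter magnitude is under a third of the smallest
+# between-category gap, sorting by a jittered column must also sort the
+# original column, i.e. jittering never reorders categories.
+for (j in 1:ncol(x)) {
+  stopifnot(all(x[order(x_jitter[,j]), j] == sort(x[,j])))
+}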
+
+# Fit a p(z = 1 | x) model for propensity features
+ps_fit = pbart(x.train = xdf,
+               y.train = z, ntree = 200, numcut = 1000, ndpost = 100,
+               usequants = TRUE, k = 2.0, nskip = 100, keepevery = 1)
+g = colMeans(pnorm(ps_fit$yhat.train))
+psf = pnorm(ps_fit$yhat.train)
+
+# Test-train split
+n <- nrow(x)
+test_set_pct <- 0.2
+n_test <- round(test_set_pct*n)
+n_train <- n - n_test
+test_inds <- sort(sample(1:n, n_test, replace = FALSE))
+train_inds <- (1:n)[!((1:n) %in% test_inds)]
+xdf_st_test <- xdf_st[test_inds,]
+xdf_st_train <- xdf_st[train_inds,]
+x_test <- x[test_inds,]
+x_train <- x[train_inds,]
+x_jitter_test <- x_jitter[test_inds,]
+x_jitter_train <- x_jitter[train_inds,]
+pi_test <- g[test_inds]
+pi_train <- g[train_inds]
+z_test <- z[test_inds,]
+z_train <- z[train_inds,]
+y_test <- y[test_inds,]
+y_train <- y[train_inds,]
+y_train_scale <- scale(y_train)
+y_train_sd <- attr(y_train_scale, "scaled:scale")
+y_train_mean <- attr(y_train_scale, "scaled:center")
+y_test_scale <- (y_test - y_train_mean) / y_train_sd
+var(y_train_scale)
+var(y_test_scale)
+
+# Fit BCF with the GFR algorithm and save the model to JSON
+num_gfr <- 1000
+general_params <- list(
+  adaptive_coding = FALSE, propensity_covariate = "none",
+  keep_every = 1, verbose = TRUE, keep_gfr = TRUE
+)
+bcf_model_gfr <- stochtree::bcf(
+  X_train = xdf_st_train, Z_train = c(z_train),
+  y_train = c(y_train_scale), propensity_train = pi_train,
+  X_test = xdf_st_test, Z_test = c(z_test),
+  propensity_test = pi_test, num_gfr = num_gfr, num_burnin = 0,
+  num_mcmc = 0, general_params = general_params
+)
+fit_json_gfr = saveBCFModelToJsonString(bcf_model_gfr)
+
+# Run an MCMC chain from the last GFR sample, setting the covariates either
+# to an interpolation between the original x and x_jitter (alpha = 0 is
+# 100% x_jitter and alpha = 1 is 100% x) or, as currently configured, to the
+# ordered-factor covariates used above
+# alpha <- 1.0
+# x_jitter_new_train <- (alpha) * x_train + (1-alpha) * x_jitter_train
+# x_jitter_new_test <- (alpha) * x_test + (1-alpha) * x_jitter_test
+x_jitter_new_train <- xdf_st_train
+x_jitter_new_test <- xdf_st_test
+num_mcmc <- 10000
+bcf_model_mcmc <- stochtree::bcf(
+  X_train = x_jitter_new_train, Z_train = c(z_train),
+  y_train = c(y_train_scale), propensity_train = pi_train,
+  X_test = x_jitter_new_test, Z_test = c(z_test),
+  propensity_test = pi_test,
+  num_gfr = 0, num_burnin = 0, num_mcmc = num_mcmc,
+  previous_model_json = fit_json_gfr,
+  previous_model_warmstart_sample_num = num_gfr,
+  general_params = general_params
+)
+
+# Inspect the "in-sample sigma" via the traceplot
+# of the global error variance parameter
+combined_sigma <- c(bcf_model_gfr$sigma2_global_samples,
+                    bcf_model_mcmc$sigma2_global_samples)
+plot(combined_sigma, ylab = "sigma2", xlab = "sample num",
+     main = "Global error var traceplot")
+
+# Inspect the "out-of-sample sigma" by computing the MSE
+# of yhat on the test set
+yhat_combined_train <- cbind(
+  bcf_model_gfr$y_hat_train,
+  bcf_model_mcmc$y_hat_train
+)
+yhat_combined_test <- cbind(
+  bcf_model_gfr$y_hat_test,
+  bcf_model_mcmc$y_hat_test
+)
+num_samples <- ncol(yhat_combined_train)
+train_mses <- rep(NA, num_samples)
+for (i in 1:num_samples) {
+  train_mses[i] <- mean((yhat_combined_train[,i] - y_train_scale)^2)
+}
+test_mses <- rep(NA, num_samples)
+for (i in 1:num_samples) {
+  test_mses[i] <- mean((yhat_combined_test[,i] - y_test_scale)^2)
+}
+max_y <- max(c(train_mses, test_mses))
+min_y <- min(c(train_mses, test_mses))
+plot(test_mses, ylab = "outcome MSE", xlab = "sample num",
+     main = "Outcome MSE Traceplot", ylim = c(min_y, max_y))
+points(train_mses, col = "blue")
+legend("right", legend = c("Out-of-Sample", "In-Sample"),
+       col = c("black", "blue"), pch = c(1,1))
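+# Aside: the two MSE loops above could equivalently be vectorized, since
+# subtracting a vector from a matrix in R recycles it down the columns:
+# train_mses <- colMeans((yhat_combined_train - c(y_train_scale))^2)
+# test_mses <- colMeans((yhat_combined_test - c(y_test_scale))^2)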
+
+# Run some one-off pred vs actual plots
+plot(yhat_combined_test[,11000], y_test_scale); abline(0,1,col="red",lty=3)
+plot(bcf_model_mcmc$y_hat_train[,10000], y_train_scale); abline(0,1,col="red",lty=3)
+plot(bcf_model_mcmc$y_hat_test[,10000], y_test_scale); abline(0,1,col="red",lty=3)
+plot(bcf_model_gfr$y_hat_train[,1000], y_train_scale); abline(0,1,col="red",lty=3)
+plot(bcf_model_gfr$y_hat_test[,1000], y_test_scale); abline(0,1,col="red",lty=3)
+plot(bcf_model_gfr$y_hat_train[,10], y_train_scale); abline(0,1,col="red",lty=3)
+plot(bcf_model_gfr$y_hat_test[,10], y_test_scale); abline(0,1,col="red",lty=3)
+
+# Run an MCMC chain from root
+num_mcmc <- 10000
+bcf_model_mcmc_root <- stochtree::bcf(
+  X_train = xdf_st_train, Z_train = c(z_train),
+  y_train = c(y_train_scale), propensity_train = pi_train,
+  X_test = xdf_st_test, Z_test = c(z_test),
+  propensity_test = pi_test,
+  num_gfr = 0, num_burnin = 0, num_mcmc = num_mcmc,
+  general_params = general_params
+)
+
+# Inspect the "in-sample sigma" via the traceplot
+# of the global error variance parameter
+sigma_trace <- bcf_model_mcmc_root$sigma2_global_samples
+plot(sigma_trace, ylab = "sigma2", xlab = "sample num",
+     main = "Global error var traceplot")
+
+# Inspect the "out-of-sample sigma" by computing the MSE
+# of yhat on the test set
+yhat_combined_train <- cbind(
+  bcf_model_mcmc_root$y_hat_train
+)
+yhat_combined_test <- cbind(
+  bcf_model_mcmc_root$y_hat_test
+)
+num_samples <- ncol(yhat_combined_train)
+train_mses <- rep(NA, num_samples)
+for (i in 1:num_samples) {
+  train_mses[i] <- mean((yhat_combined_train[,i] - y_train_scale)^2)
+}
+test_mses <- rep(NA, num_samples)
+for (i in 1:num_samples) {
+  test_mses[i] <- mean((yhat_combined_test[,i] - y_test_scale)^2)
+}
+max_y <- max(c(train_mses, test_mses))
+min_y <- min(c(train_mses, test_mses))
+plot(test_mses, ylab = "outcome MSE", xlab = "sample num",
+     main = "Test set outcome MSEs", ylim = c(min_y, max_y))
+points(train_mses, col = "blue")
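+
+# Added summary for a quick numeric side-by-side with the warm-started chain
+# above: average in- and out-of-sample MSE across the root chain's samples.
+cat(paste0("Mean train MSE (root chain) = ", mean(train_mses), "\n"))
+cat(paste0("Mean test MSE (root chain) = ", mean(test_mses), "\n"))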
diff --git a/tools/debug/bcf_continue_sampler_debug.R b/tools/debug/bcf_continue_sampler_debug.R
new file mode 100644
index 00000000..1c46db5c
--- /dev/null
+++ b/tools/debug/bcf_continue_sampler_debug.R
@@ -0,0 +1,92 @@
+# Load libraries
+library(stochtree)
+
+# Sampler settings
+num_chains <- 1
+num_gfr <- 10
+num_burnin <- 0
+num_mcmc <- 20
+num_trees <- 100
+
+# Generate the data
+n <- 500
+p <- 5
+snr <- 2
+X <- matrix(runif(n*p), ncol = p)
+mu_x <- (
+  ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) +
+  ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) +
+  ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) +
+  ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5)
+)
+pi_x <- (
+  ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) +
+  ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) +
+  ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) +
+  ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8)
+)
+tau_x <- (
+  ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) +
+  ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) +
+  ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) +
+  ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0)
+)
+Z <- rbinom(n, 1, pi_x)
+f_XZ <- mu_x + tau_x*Z
+noise_sd <- sd(f_XZ) / snr
+y <- f_XZ + rnorm(n, 0, 1)*noise_sd
+
+# Split data into test and train sets
+test_set_pct <- 0.2
+n_test <- round(test_set_pct*n)
+n_train <- n - n_test
+test_inds <- sort(sample(1:n, n_test, replace = FALSE))
+train_inds <- (1:n)[!((1:n) %in% test_inds)]
+X_test <- X[test_inds,]
+X_train <- X[train_inds,]
+pi_test <- pi_x[test_inds]
+pi_train <- pi_x[train_inds]
+Z_test <- Z[test_inds]
+Z_train <- Z[train_inds]
+y_test <- y[test_inds]
+y_train <- y[train_inds]
+mu_test <- mu_x[test_inds]
+mu_train <- mu_x[train_inds]
+tau_test <- tau_x[test_inds]
+tau_train <- tau_x[train_inds]
+
+# Run the GFR algorithm
+general_params <- list(sample_sigma2_global = T)
+xbcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train,
+                  propensity_train = pi_train, X_test = X_test, Z_test = Z_test,
+                  propensity_test = pi_test, num_gfr = num_gfr, num_burnin = 0,
+                  num_mcmc = 0, general_params = general_params)
+
+# Inspect results
+plot(rowMeans(xbcf_model$y_hat_test), y_test); abline(0,1)
+cat(paste0("RMSE = ", sqrt(mean((rowMeans(xbcf_model$y_hat_test) - y_test)^2)), "\n"))
+plot(xbcf_model$sigma2_global_samples)
+xbcf_model_string <- stochtree::saveBCFModelToJsonString(xbcf_model)
+
+# Run the BCF MCMC sampler, initialized from the XBCF (GFR) sampler
+general_params <- list(sample_sigma2_global = T)
+bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train,
+                 propensity_train = pi_train, X_test = X_test, Z_test = Z_test,
+                 propensity_test = pi_test, num_gfr = 0, num_burnin = num_burnin,
+                 num_mcmc = num_mcmc, general_params = general_params,
+                 previous_model_json = xbcf_model_string,
+                 previous_model_warmstart_sample_num = num_gfr)
+
+# Inspect the results
+plot(rowMeans(bcf_model$y_hat_test), y_test); abline(0,1)
+cat(paste0("RMSE = ", sqrt(mean((rowMeans(bcf_model$y_hat_test) - y_test)^2)), "\n"))
+plot(bcf_model$sigma2_global_samples)
+
+# Compare to a single chain of MCMC samples initialized at root
+bcf_model_root <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train,
+                      propensity_train = pi_train, X_test = X_test, Z_test = Z_test,
+                      propensity_test = pi_test, num_gfr = 0, num_burnin = num_burnin,
+                      num_mcmc = num_mcmc, general_params = general_params)
+plot(rowMeans(bcf_model_root$y_hat_test), y_test); abline(0,1)
+plot(bcf_model_root$sigma2_global_samples)
+cat(paste0("RMSE = ", sqrt(mean((rowMeans(bcf_model_root$y_hat_test) - y_test)^2)), "\n"))
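+
+# Check recovery of the simulated treatment effects (`tau_test` is split out
+# above but otherwise unused). This assumes the bcf fit object exposes
+# per-sample CATE predictions as `tau_hat_test`, as recent stochtree
+# releases do.
+plot(rowMeans(bcf_model$tau_hat_test), tau_test); abline(0,1)
+cat(paste0("CATE RMSE = ", sqrt(mean((rowMeans(bcf_model$tau_hat_test) - tau_test)^2)), "\n"))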