Skip to content

Commit 2c8c6df

Browse files
committed
Updated predict functions for BART and BCF in R
1 parent 3051bd7 commit 2c8c6df

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

R/bart.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1172,9 +1172,9 @@ predict.bartmodel <- function(object, X, leaf_basis = NULL, rfx_group_ids = NULL
11721172
if (object$model_params$include_variance_forest) {
11731173
if (object$model_params$sample_sigma2_global) {
11741174
sigma2_global_samples <- object$sigma2_global_samples
1175-
variance_forest_predictions <- sapply(1:num_samples, function(i) sqrt(s_x_raw[,i]*sigma2_global_samples[i]))
1175+
variance_forest_predictions <- sapply(1:num_samples, function(i) s_x_raw[,i]*sigma2_global_samples[i])
11761176
} else {
1177-
variance_forest_predictions <- sqrt(s_x_raw*sigma2_init)*y_std
1177+
variance_forest_predictions <- s_x_raw*sigma2_init*y_std*y_std
11781178
}
11791179
}
11801180

R/bcf.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1652,9 +1652,9 @@ predict.bcfmodel <- function(object, X, Z, propensity = NULL, rfx_group_ids = NU
16521652
if (object$model_params$include_variance_forest) {
16531653
if (object$model_params$sample_sigma2_global) {
16541654
sigma2_global_samples <- object$sigma2_global_samples
1655-
variance_forest_predictions <- sapply(1:num_samples, function(i) sqrt(s_x_raw[,i]*sigma2_global_samples[i]))
1655+
variance_forest_predictions <- sapply(1:num_samples, function(i) s_x_raw[,i]*sigma2_global_samples[i])
16561656
} else {
1657-
variance_forest_predictions <- sqrt(s_x_raw*initial_sigma2)*y_std
1657+
variance_forest_predictions <- s_x_raw*initial_sigma2*y_std*y_std
16581658
}
16591659
}
16601660

0 commit comments

Comments
 (0)