Skip to content

Commit f55bbb4

Browse files
authored
Merge pull request #166 from StochasticTree/bcf-heteroskedasticity-hotfix
Fixing bug in heteroskedastic BCF and standardizing the use of variance instead of standard deviation throughout the interface
2 parents 6dae5cb + e880e46 commit f55bbb4

16 files changed

+951
-605
lines changed

R/bart.R

Lines changed: 85 additions & 85 deletions
Large diffs are not rendered by default.

R/bcf.R

Lines changed: 155 additions & 155 deletions
Large diffs are not rendered by default.

R/kernel.R

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ computeForestLeafVariances <- function(model_object, forest_type, forest_inds=NU
157157
if (!model_object$model_params$include_mean_forest) {
158158
stop("Mean forest was not sampled in the bart model provided")
159159
}
160-
if (!model_object$model_params$sample_sigma_leaf) {
160+
if (!model_object$model_params$sample_sigma2_leaf) {
161161
stop("Leaf scale parameter was not sampled for the mean forest in the bart model provided")
162162
}
163163
leaf_scale_vector <- model_object$sigma2_leaf_samples
@@ -170,15 +170,15 @@ computeForestLeafVariances <- function(model_object, forest_type, forest_inds=NU
170170
} else {
171171
stopifnot(forest_type %in% c("prognostic", "treatment", "variance"))
172172
if (forest_type=="prognostic") {
173-
if (!model_object$model_params$sample_sigma_leaf_mu) {
173+
if (!model_object$model_params$sample_sigma2_leaf_mu) {
174174
stop("Leaf scale parameter was not sampled for the prognostic forest in the bcf model provided")
175175
}
176-
leaf_scale_vector <- model_object$sigma_leaf_mu_samples
176+
leaf_scale_vector <- model_object$sigma2_leaf_mu_samples
177177
} else if (forest_type=="treatment") {
178-
if (!model_object$model_params$sample_sigma_leaf_tau) {
178+
if (!model_object$model_params$sample_sigma2_leaf_tau) {
179179
stop("Leaf scale parameter was not sampled for the treatment effect forest in the bcf model provided")
180180
}
181-
leaf_scale_vector <- model_object$sigma_leaf_tau_samples
181+
leaf_scale_vector <- model_object$sigma2_leaf_tau_samples
182182
} else if (forest_type=="variance") {
183183
if (!model_object$model_params$include_variance_forest) {
184184
stop("Variance forest was not sampled in the bcf model provided")

R/serialization.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -545,7 +545,7 @@ loadRandomEffectSamplesCombinedJsonString <- function(json_string_list, json_rfx
545545
#' Load a vector from json
546546
#'
547547
#' @param json_object Object of class `CppJson`
548-
#' @param json_vector_label Label referring to a particular vector (i.e. "sigma2_samples") in the overall json hierarchy
548+
#' @param json_vector_label Label referring to a particular vector (i.e. "sigma2_global_samples") in the overall json hierarchy
549549
#' @param subfolder_name (Optional) Name of the subfolder / hierarchy under which vector sits
550550
#'
551551
#' @return R vector

R/stochtree-package.R

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
## usethis namespace: start
22
#' @importFrom stats coef
3+
#' @importFrom stats dnorm
34
#' @importFrom stats lm
45
#' @importFrom stats model.matrix
56
#' @importFrom stats predict
67
#' @importFrom stats qgamma
8+
#' @importFrom stats qnorm
9+
#' @importFrom stats pnorm
710
#' @importFrom stats resid
811
#' @importFrom stats rnorm
12+
#' @importFrom stats runif
913
#' @importFrom stats sd
1014
#' @importFrom stats sigma
1115
#' @importFrom stats var

demo/debug/causal_inference.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,8 @@
8787
plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3,3)))
8888
plt.show()
8989

90-
sigma_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(np.arange(bcf_model.num_samples),axis=1), np.expand_dims(bcf_model.global_var_samples,axis=1)), axis = 1), columns=["Sample", "Sigma"])
91-
sns.scatterplot(data=sigma_df_mcmc, x="Sample", y="Sigma")
90+
sigma_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(np.arange(bcf_model.num_samples),axis=1), np.expand_dims(bcf_model.global_var_samples,axis=1)), axis = 1), columns=["Sample", "Sigma^2"])
91+
sns.scatterplot(data=sigma_df_mcmc, x="Sample", y="Sigma^2")
9292
plt.show()
9393

9494
b_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(np.arange(bcf_model.num_samples),axis=1), np.expand_dims(bcf_model.b0_samples,axis=1), np.expand_dims(bcf_model.b1_samples,axis=1)), axis = 1), columns=["Sample", "Beta_0", "Beta_1"])

demo/debug/supervised_learning.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,8 @@ def outcome_mean(X, W):
6666
plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3,3)))
6767
plt.show()
6868

69-
sigma_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(np.arange(bart_model.num_samples),axis=1), np.expand_dims(bart_model.global_var_samples,axis=1)), axis = 1), columns=["Sample", "Sigma"])
70-
sns.scatterplot(data=sigma_df_mcmc, x="Sample", y="Sigma")
69+
sigma_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(np.arange(bart_model.num_samples),axis=1), np.expand_dims(bart_model.global_var_samples,axis=1)), axis = 1), columns=["Sample", "Sigma^2"])
70+
sns.scatterplot(data=sigma_df_mcmc, x="Sample", y="Sigma^2")
7171
plt.show()
7272

7373
# Compute the test set RMSE
@@ -89,8 +89,8 @@ def outcome_mean(X, W):
8989
plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3,3)))
9090
plt.show()
9191

92-
sigma_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(np.arange(bart_model.num_samples),axis=1), np.expand_dims(bart_model.global_var_samples,axis=1)), axis = 1), columns=["Sample", "Sigma"])
93-
sns.scatterplot(data=sigma_df_mcmc, x="Sample", y="Sigma")
92+
sigma_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(np.arange(bart_model.num_samples),axis=1), np.expand_dims(bart_model.global_var_samples,axis=1)), axis = 1), columns=["Sample", "Sigma^2"])
93+
sns.scatterplot(data=sigma_df_mcmc, x="Sample", y="Sigma^2")
9494
plt.show()
9595

9696
# Compute the test set RMSE
@@ -110,8 +110,8 @@ def outcome_mean(X, W):
110110
plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3,3)))
111111
plt.show()
112112

113-
sigma_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(np.arange(bart_model.num_samples),axis=1), np.expand_dims(bart_model.global_var_samples,axis=1)), axis = 1), columns=["Sample", "Sigma"])
114-
sns.scatterplot(data=sigma_df_mcmc, x="Sample", y="Sigma")
113+
sigma_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(np.arange(bart_model.num_samples),axis=1), np.expand_dims(bart_model.global_var_samples,axis=1)), axis = 1), columns=["Sample", "Sigma^2"])
114+
sns.scatterplot(data=sigma_df_mcmc, x="Sample", y="Sigma^2")
115115
plt.show()
116116

117117
# Compute the test set RMSE

demo/notebooks/prototype_interface.ipynb

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -346,9 +346,9 @@
346346
"forest_preds_mcmc = forest_preds[:, num_warmstart:num_samples]\n",
347347
"\n",
348348
"# Global error variance\n",
349-
"sigma_samples = np.sqrt(global_var_samples) * y_std\n",
350-
"sigma_samples_gfr = sigma_samples[:num_warmstart]\n",
351-
"sigma_samples_mcmc = sigma_samples[num_warmstart:num_samples]"
349+
"sigma2_samples = global_var_samples * y_std * y_std\n",
350+
"sigma2_samples_gfr = sigma2_samples[:num_warmstart]\n",
351+
"sigma2_samples_mcmc = sigma2_samples[num_warmstart:num_samples]"
352352
]
353353
},
354354
{
@@ -384,13 +384,13 @@
384384
" np.concatenate(\n",
385385
" (\n",
386386
" np.expand_dims(np.arange(num_warmstart), axis=1),\n",
387-
" np.expand_dims(sigma_samples_gfr, axis=1),\n",
387+
" np.expand_dims(sigma2_samples_gfr, axis=1),\n",
388388
" ),\n",
389389
" axis=1,\n",
390390
" ),\n",
391-
" columns=[\"Sample\", \"Sigma\"],\n",
391+
" columns=[\"Sample\", \"Sigma^2\"],\n",
392392
")\n",
393-
"sns.scatterplot(data=sigma_df_gfr, x=\"Sample\", y=\"Sigma\")\n",
393+
"sns.scatterplot(data=sigma_df_gfr, x=\"Sample\", y=\"Sigma^2\")\n",
394394
"plt.show()"
395395
]
396396
},
@@ -427,13 +427,13 @@
427427
" np.concatenate(\n",
428428
" (\n",
429429
" np.expand_dims(np.arange(num_samples - num_warmstart), axis=1),\n",
430-
" np.expand_dims(sigma_samples_mcmc, axis=1),\n",
430+
" np.expand_dims(sigma2_samples_mcmc, axis=1),\n",
431431
" ),\n",
432432
" axis=1,\n",
433433
" ),\n",
434-
" columns=[\"Sample\", \"Sigma\"],\n",
434+
" columns=[\"Sample\", \"Sigma^2\"],\n",
435435
")\n",
436-
"sns.scatterplot(data=sigma_df_mcmc, x=\"Sample\", y=\"Sigma\")\n",
436+
"sns.scatterplot(data=sigma_df_mcmc, x=\"Sample\", y=\"Sigma^2\")\n",
437437
"plt.show()"
438438
]
439439
},
@@ -909,9 +909,9 @@
909909
"forest_preds_tau_mcmc = forest_preds_tau[:, num_warmstart:num_samples]\n",
910910
"\n",
911911
"# Global error variance\n",
912-
"sigma_samples = np.sqrt(global_var_samples) * y_std\n",
913-
"sigma_samples_gfr = sigma_samples[:num_warmstart]\n",
914-
"sigma_samples_mcmc = sigma_samples[num_warmstart:num_samples]\n",
912+
"sigma2_samples = global_var_samples * y_std * y_std\n",
913+
"sigma2_samples_gfr = sigma2_samples[:num_warmstart]\n",
914+
"sigma2_samples_mcmc = sigma2_samples[num_warmstart:num_samples]\n",
915915
"\n",
916916
"# Adaptive coding parameters\n",
917917
"b_1_samples_gfr = b_1_samples[:num_warmstart] * y_std\n",
@@ -969,13 +969,13 @@
969969
" np.concatenate(\n",
970970
" (\n",
971971
" np.expand_dims(np.arange(num_warmstart), axis=1),\n",
972-
" np.expand_dims(sigma_samples_gfr, axis=1),\n",
972+
" np.expand_dims(sigma2_samples_gfr, axis=1),\n",
973973
" ),\n",
974974
" axis=1,\n",
975975
" ),\n",
976-
" columns=[\"Sample\", \"Sigma\"],\n",
976+
" columns=[\"Sample\", \"Sigma^2\"],\n",
977977
")\n",
978-
"sns.scatterplot(data=sigma_df_gfr, x=\"Sample\", y=\"Sigma\")\n",
978+
"sns.scatterplot(data=sigma_df_gfr, x=\"Sample\", y=\"Sigma^2\")\n",
979979
"plt.show()"
980980
]
981981
},
@@ -1050,13 +1050,13 @@
10501050
" np.concatenate(\n",
10511051
" (\n",
10521052
" np.expand_dims(np.arange(num_samples - num_warmstart), axis=1),\n",
1053-
" np.expand_dims(sigma_samples_mcmc, axis=1),\n",
1053+
" np.expand_dims(sigma2_samples_mcmc, axis=1),\n",
10541054
" ),\n",
10551055
" axis=1,\n",
10561056
" ),\n",
1057-
" columns=[\"Sample\", \"Sigma\"],\n",
1057+
" columns=[\"Sample\", \"Sigma^2\"],\n",
10581058
")\n",
1059-
"sns.scatterplot(data=sigma_df_mcmc, x=\"Sample\", y=\"Sigma\")\n",
1059+
"sns.scatterplot(data=sigma_df_mcmc, x=\"Sample\", y=\"Sigma^2\")\n",
10601060
"plt.show()"
10611061
]
10621062
},

demo/notebooks/serialization.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,9 +173,9 @@
173173
" ),\n",
174174
" axis=1,\n",
175175
" ),\n",
176-
" columns=[\"Sample\", \"Sigma\"],\n",
176+
" columns=[\"Sample\", \"Sigma^2\"],\n",
177177
")\n",
178-
"sns.scatterplot(data=sigma_df_mcmc, x=\"Sample\", y=\"Sigma\")\n",
178+
"sns.scatterplot(data=sigma_df_mcmc, x=\"Sample\", y=\"Sigma^2\")\n",
179179
"plt.show()"
180180
]
181181
},

demo/notebooks/supervised_learning.ipynb

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -172,9 +172,9 @@
172172
" ),\n",
173173
" axis=1,\n",
174174
" ),\n",
175-
" columns=[\"Sample\", \"Sigma\"],\n",
175+
" columns=[\"Sample\", \"Sigma^2\"],\n",
176176
")\n",
177-
"sns.scatterplot(data=sigma_df_mcmc, x=\"Sample\", y=\"Sigma\")\n",
177+
"sns.scatterplot(data=sigma_df_mcmc, x=\"Sample\", y=\"Sigma^2\")\n",
178178
"plt.show()"
179179
]
180180
},
@@ -260,9 +260,9 @@
260260
" ),\n",
261261
" axis=1,\n",
262262
" ),\n",
263-
" columns=[\"Sample\", \"Sigma\"],\n",
263+
" columns=[\"Sample\", \"Sigma^2\"],\n",
264264
")\n",
265-
"sns.scatterplot(data=sigma_df_mcmc, x=\"Sample\", y=\"Sigma\")\n",
265+
"sns.scatterplot(data=sigma_df_mcmc, x=\"Sample\", y=\"Sigma^2\")\n",
266266
"plt.show()"
267267
]
268268
},
@@ -346,9 +346,9 @@
346346
" ),\n",
347347
" axis=1,\n",
348348
" ),\n",
349-
" columns=[\"Sample\", \"Sigma\"],\n",
349+
" columns=[\"Sample\", \"Sigma^2\"],\n",
350350
")\n",
351-
"sns.scatterplot(data=sigma_df_mcmc, x=\"Sample\", y=\"Sigma\")\n",
351+
"sns.scatterplot(data=sigma_df_mcmc, x=\"Sample\", y=\"Sigma^2\")\n",
352352
"plt.show()"
353353
]
354354
},
@@ -371,7 +371,7 @@
371371
],
372372
"metadata": {
373373
"kernelspec": {
374-
"display_name": "venv",
374+
"display_name": "stochtree-dev",
375375
"language": "python",
376376
"name": "python3"
377377
},
@@ -385,7 +385,7 @@
385385
"name": "python",
386386
"nbconvert_exporter": "python",
387387
"pygments_lexer": "ipython3",
388-
"version": "3.12.9"
388+
"version": "3.10.16"
389389
}
390390
},
391391
"nbformat": 4,

0 commit comments

Comments
 (0)