Skip to content

Commit a65f606

Browse files
committed
Updated docs and incorporated CRAN check feedback
1 parent e64a104 commit a65f606

17 files changed

+120
-62
lines changed

R/bart.R

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1133,8 +1133,10 @@ predict.bartmodel <- function(object, X, leaf_basis = NULL, rfx_group_ids = NULL
11331133
#' rfx_term_test <- rfx_term[test_inds]
11341134
#' rfx_term_train <- rfx_term[train_inds]
11351135
#' bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test,
1136-
#' rfx_group_ids_train = rfx_group_ids_train, rfx_group_ids_test = rfx_group_ids_test,
1137-
#' rfx_basis_train = rfx_basis_train, rfx_basis_test = rfx_basis_test,
1136+
#' rfx_group_ids_train = rfx_group_ids_train,
1137+
#' rfx_group_ids_test = rfx_group_ids_test,
1138+
#' rfx_basis_train = rfx_basis_train,
1139+
#' rfx_basis_test = rfx_basis_test,
11381140
#' num_gfr = 100, num_burnin = 0, num_mcmc = 100)
11391141
#' rfx_samples <- getRandomEffectSamples(bart_model)
11401142
getRandomEffectSamples.bartmodel <- function(object, ...){
@@ -1190,6 +1192,10 @@ getRandomEffectSamples.bartmodel <- function(object, ...){
11901192
saveBARTModelToJson <- function(object){
11911193
jsonobj <- createCppJson()
11921194

1195+
if (!inherits(object, "bartmodel")) {
1196+
stop("`object` must be a BART model")
1197+
}
1198+
11931199
if (is.null(object$model_params)) {
11941200
stop("This BCF model has not yet been sampled")
11951201
}

R/bcf.R

Lines changed: 35 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1409,7 +1409,8 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
14091409
#' mu_train <- mu_x[train_inds]
14101410
#' tau_test <- tau_x[test_inds]
14111411
#' tau_train <- tau_x[train_inds]
1412-
#' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, propensity_train = pi_train)
1412+
#' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train,
1413+
#' propensity_train = pi_train)
14131414
#' preds <- predict(bcf_model, X_test, Z_test, pi_test)
14141415
#' plot(rowMeans(preds$mu_hat), mu_test, xlab = "predicted",
14151416
#' ylab = "actual", main = "Prognostic function")
@@ -1597,9 +1598,11 @@ predict.bcfmodel <- function(object, X, Z, propensity = NULL, rfx_group_ids = NU
15971598
#' mu_params <- list(sample_sigma_leaf = TRUE)
15981599
#' tau_params <- list(sample_sigma_leaf = FALSE)
15991600
#' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train,
1600-
#' propensity_train = pi_train, rfx_group_ids_train = rfx_group_ids_train,
1601+
#' propensity_train = pi_train,
1602+
#' rfx_group_ids_train = rfx_group_ids_train,
16011603
#' rfx_basis_train = rfx_basis_train, X_test = X_test,
1602-
#' Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test,
1604+
#' Z_test = Z_test, propensity_test = pi_test,
1605+
#' rfx_group_ids_test = rfx_group_ids_test,
16031606
#' rfx_basis_test = rfx_basis_test,
16041607
#' num_gfr = 100, num_burnin = 0, num_mcmc = 100,
16051608
#' mu_forest_params = mu_params,
@@ -1686,9 +1689,11 @@ getRandomEffectSamples.bcfmodel <- function(object, ...){
16861689
#' mu_params <- list(sample_sigma_leaf = TRUE)
16871690
#' tau_params <- list(sample_sigma_leaf = FALSE)
16881691
#' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train,
1689-
#' propensity_train = pi_train, rfx_group_ids_train = rfx_group_ids_train,
1692+
#' propensity_train = pi_train,
1693+
#' rfx_group_ids_train = rfx_group_ids_train,
16901694
#' rfx_basis_train = rfx_basis_train, X_test = X_test,
1691-
#' Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test,
1695+
#' Z_test = Z_test, propensity_test = pi_test,
1696+
#' rfx_group_ids_test = rfx_group_ids_test,
16921697
#' rfx_basis_test = rfx_basis_test,
16931698
#' num_gfr = 100, num_burnin = 0, num_mcmc = 100,
16941699
#' mu_forest_params = mu_params,
@@ -1697,7 +1702,7 @@ getRandomEffectSamples.bcfmodel <- function(object, ...){
16971702
saveBCFModelToJson <- function(object){
16981703
jsonobj <- createCppJson()
16991704

1700-
if (class(object) != "bcfmodel") {
1705+
if (!inherits(object, "bcfmodel")) {
17011706
stop("`object` must be a BCF model")
17021707
}
17031708

@@ -1849,9 +1854,11 @@ saveBCFModelToJson <- function(object){
18491854
#' mu_params <- list(sample_sigma_leaf = TRUE)
18501855
#' tau_params <- list(sample_sigma_leaf = FALSE)
18511856
#' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train,
1852-
#' propensity_train = pi_train, rfx_group_ids_train = rfx_group_ids_train,
1857+
#' propensity_train = pi_train,
1858+
#' rfx_group_ids_train = rfx_group_ids_train,
18531859
#' rfx_basis_train = rfx_basis_train, X_test = X_test,
1854-
#' Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test,
1860+
#' Z_test = Z_test, propensity_test = pi_test,
1861+
#' rfx_group_ids_test = rfx_group_ids_test,
18551862
#' rfx_basis_test = rfx_basis_test,
18561863
#' num_gfr = 100, num_burnin = 0, num_mcmc = 100,
18571864
#' mu_forest_params = mu_params,
@@ -2003,9 +2010,11 @@ saveBCFModelToJsonString <- function(object){
20032010
#' mu_params <- list(sample_sigma_leaf = TRUE)
20042011
#' tau_params <- list(sample_sigma_leaf = FALSE)
20052012
#' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train,
2006-
#' propensity_train = pi_train, rfx_group_ids_train = rfx_group_ids_train,
2013+
#' propensity_train = pi_train,
2014+
#' rfx_group_ids_train = rfx_group_ids_train,
20072015
#' rfx_basis_train = rfx_basis_train, X_test = X_test,
2008-
#' Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test,
2016+
#' Z_test = Z_test, propensity_test = pi_test,
2017+
#' rfx_group_ids_test = rfx_group_ids_test,
20092018
#' rfx_basis_test = rfx_basis_test,
20102019
#' num_gfr = 100, num_burnin = 0, num_mcmc = 100,
20112020
#' mu_forest_params = mu_params,
@@ -2166,9 +2175,11 @@ createBCFModelFromJson <- function(json_object){
21662175
#' mu_params <- list(sample_sigma_leaf = TRUE)
21672176
#' tau_params <- list(sample_sigma_leaf = FALSE)
21682177
#' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train,
2169-
#' propensity_train = pi_train, rfx_group_ids_train = rfx_group_ids_train,
2178+
#' propensity_train = pi_train,
2179+
#' rfx_group_ids_train = rfx_group_ids_train,
21702180
#' rfx_basis_train = rfx_basis_train, X_test = X_test,
2171-
#' Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test,
2181+
#' Z_test = Z_test, propensity_test = pi_test,
2182+
#' rfx_group_ids_test = rfx_group_ids_test,
21722183
#' rfx_basis_test = rfx_basis_test,
21732184
#' num_gfr = 100, num_burnin = 0, num_mcmc = 100,
21742185
#' mu_forest_params = mu_params,
@@ -2245,9 +2256,11 @@ createBCFModelFromJsonFile <- function(json_filename){
22452256
#' rfx_term_test <- rfx_term[test_inds]
22462257
#' rfx_term_train <- rfx_term[train_inds]
22472258
#' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train,
2248-
#' propensity_train = pi_train, rfx_group_ids_train = rfx_group_ids_train,
2259+
#' propensity_train = pi_train,
2260+
#' rfx_group_ids_train = rfx_group_ids_train,
22492261
#' rfx_basis_train = rfx_basis_train, X_test = X_test,
2250-
#' Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test,
2262+
#' Z_test = Z_test, propensity_test = pi_test,
2263+
#' rfx_group_ids_test = rfx_group_ids_test,
22512264
#' rfx_basis_test = rfx_basis_test,
22522265
#' num_gfr = 100, num_burnin = 0, num_mcmc = 100)
22532266
#' # bcf_json <- saveBCFModelToJsonString(bcf_model)
@@ -2323,9 +2336,11 @@ createBCFModelFromJsonString <- function(json_string){
23232336
#' rfx_term_test <- rfx_term[test_inds]
23242337
#' rfx_term_train <- rfx_term[train_inds]
23252338
#' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train,
2326-
#' propensity_train = pi_train, rfx_group_ids_train = rfx_group_ids_train,
2339+
#' propensity_train = pi_train,
2340+
#' rfx_group_ids_train = rfx_group_ids_train,
23272341
#' rfx_basis_train = rfx_basis_train, X_test = X_test,
2328-
#' Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test,
2342+
#' Z_test = Z_test, propensity_test = pi_test,
2343+
#' rfx_group_ids_test = rfx_group_ids_test,
23292344
#' rfx_basis_test = rfx_basis_test,
23302345
#' num_gfr = 100, num_burnin = 0, num_mcmc = 100)
23312346
#' # bcf_json_list <- list(saveBCFModelToJson(bcf_model))
@@ -2533,9 +2548,11 @@ createBCFModelFromCombinedJson <- function(json_object_list){
25332548
#' rfx_term_test <- rfx_term[test_inds]
25342549
#' rfx_term_train <- rfx_term[train_inds]
25352550
#' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train,
2536-
#' propensity_train = pi_train, rfx_group_ids_train = rfx_group_ids_train,
2551+
#' propensity_train = pi_train,
2552+
#' rfx_group_ids_train = rfx_group_ids_train,
25372553
#' rfx_basis_train = rfx_basis_train, X_test = X_test,
2538-
#' Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test,
2554+
#' Z_test = Z_test, propensity_test = pi_test,
2555+
#' rfx_group_ids_test = rfx_group_ids_test,
25392556
#' rfx_basis_test = rfx_basis_test,
25402557
#' num_gfr = 100, num_burnin = 0, num_mcmc = 100)
25412558
#' # bcf_json_string_list <- list(saveBCFModelToJsonString(bcf_model))

R/forest.R

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -889,17 +889,24 @@ resetActiveForest <- function(active_forest, forest_samples=NULL, forest_num=NUL
889889
#' outcome <- createOutcome(y)
890890
#' rng <- createCppRNG(1234)
891891
#' global_model_config <- createGlobalModelConfig(global_error_variance=sigma2)
892-
#' forest_model_config <- createForestModelConfig(feature_types=feature_types, num_trees=num_trees, num_observations=n, num_features=p,
893-
#' alpha=alpha, beta=beta, min_samples_leaf=min_samples_leaf, max_depth=max_depth,
894-
#' variable_weights=variable_weights, cutpoint_grid_size=cutpoint_grid_size,
895-
#' leaf_model_type=leaf_model, leaf_model_scale=leaf_scale)
892+
#' forest_model_config <- createForestModelConfig(feature_types=feature_types,
893+
#' num_trees=num_trees, num_observations=n,
894+
#' num_features=p, alpha=alpha, beta=beta,
895+
#' min_samples_leaf=min_samples_leaf,
896+
#' max_depth=max_depth,
897+
#' variable_weights=variable_weights,
898+
#' cutpoint_grid_size=cutpoint_grid_size,
899+
#' leaf_model_type=leaf_model,
900+
#' leaf_model_scale=leaf_scale)
896901
#' forest_model <- createForestModel(forest_dataset, forest_model_config, global_model_config)
897902
#' active_forest <- createForest(num_trees, leaf_dimension, is_leaf_constant, is_exponentiated)
898-
#' forest_samples <- createForestSamples(num_trees, leaf_dimension, is_leaf_constant, is_exponentiated)
903+
#' forest_samples <- createForestSamples(num_trees, leaf_dimension,
904+
#' is_leaf_constant, is_exponentiated)
899905
#' active_forest$prepare_for_sampler(forest_dataset, outcome, forest_model, 0, 0.)
900906
#' forest_model$sample_one_iteration(
901907
#' forest_dataset, outcome, forest_samples, active_forest,
902-
#' rng, forest_model_config, global_model_config, keep_forest = TRUE, gfr = FALSE
908+
#' rng, forest_model_config, global_model_config,
909+
#' keep_forest = TRUE, gfr = FALSE
903910
#' )
904911
#' resetActiveForest(active_forest, forest_samples, 0)
905912
#' resetForestModel(forest_model, active_forest, forest_dataset, outcome, TRUE)

R/kernel.R

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,7 @@
4848
#' computeForestLeafIndices(bart_model, X, "mean", c(1,3,9))
4949
computeForestLeafIndices <- function(model_object, covariates, forest_type=NULL, forest_inds=NULL) {
5050
# Extract relevant forest container
51-
object_name <- class(model_object)[1]
52-
stopifnot(object_name %in% c("bartmodel", "bcfmodel", "ForestSamples"))
51+
stopifnot(any(c(inherits(model_object, "bartmodel"), inherits(model_object, "bcfmodel"), inherits(model_object, "ForestSamples"))))
5352
model_type <- ifelse(object_name=="bartmodel", "bart", ifelse(object_name=="bcfmodel", "bcf", "forest_samples"))
5453
if (model_type == "bart") {
5554
stopifnot(forest_type %in% c("mean", "variance"))
@@ -143,8 +142,8 @@ computeForestLeafIndices <- function(model_object, covariates, forest_type=NULL,
143142
#' computeForestLeafVariances(bart_model, "mean", c(1,3,5))
144143
computeForestLeafVariances <- function(model_object, forest_type, forest_inds=NULL) {
145144
# Extract relevant forest container
146-
stopifnot(class(model_object) %in% c("bartmodel", "bcfmodel"))
147-
model_type <- ifelse(class(model_object)=="bartmodel", "bart", "bcf")
145+
stopifnot(any(c(inherits(model_object, "bartmodel"), inherits(model_object, "bcfmodel"))))
146+
model_type <- ifelse(inherits(model_object, "bartmodel"), "bart", "bcf")
148147
if (model_type == "bart") {
149148
stopifnot(forest_type %in% c("mean", "variance"))
150149
if (forest_type=="mean") {
@@ -234,9 +233,8 @@ computeForestLeafVariances <- function(model_object, forest_type, forest_inds=NU
234233
#' computeForestMaxLeafIndex(bart_model, X, "mean", c(1,3,9))
235234
computeForestMaxLeafIndex <- function(model_object, covariates, forest_type=NULL, forest_inds=NULL) {
236235
# Extract relevant forest container
237-
object_name <- class(model_object)[1]
238-
stopifnot(object_name %in% c("bartmodel", "bcfmodel", "ForestSamples"))
239-
model_type <- ifelse(object_name=="bartmodel", "bart", ifelse(object_name=="bcfmodel", "bcf", "forest_samples"))
236+
stopifnot(any(c(inherits(model_object, "bartmodel"), inherits(model_object, "bcfmodel"), inherits(model_object, "ForestSamples"))))
237+
model_type <- ifelse(inherits(model_object, "bartmodel"), "bart", ifelse(inherits(model_object, "bcfmodel"), "bcf", "forest_samples"))
240238
if (model_type == "bart") {
241239
stopifnot(forest_type %in% c("mean", "variance"))
242240
if (forest_type=="mean") {

R/model.R

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -205,8 +205,10 @@ createCppRNG <- function(random_seed = -1){
205205
#' feature_types <- as.integer(rep(0, p))
206206
#' X <- matrix(runif(n*p), ncol = p)
207207
#' forest_dataset <- createForestDataset(X)
208-
#' forest_model_config <- createForestModelConfig(feature_types=feature_types, num_trees=num_trees, num_features=p,
209-
#' num_observations=n, alpha=alpha, beta=beta, min_samples_leaf=min_samples_leaf,
208+
#' forest_model_config <- createForestModelConfig(feature_types=feature_types,
209+
#' num_trees=num_trees, num_features=p,
210+
#' num_observations=n, alpha=alpha, beta=beta,
211+
#' min_samples_leaf=min_samples_leaf,
210212
#' max_depth=max_depth, leaf_model_type=1)
211213
#' global_model_config <- createGlobalModelConfig(global_error_variance=1.0)
212214
#' forest_model <- createForestModel(forest_dataset, forest_model_config, global_model_config)

man/createBCFModelFromCombinedJson.Rd

Lines changed: 4 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/createBCFModelFromCombinedJsonString.Rd

Lines changed: 4 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/createBCFModelFromJson.Rd

Lines changed: 4 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/createBCFModelFromJsonFile.Rd

Lines changed: 4 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/createBCFModelFromJsonString.Rd

Lines changed: 4 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)