Skip to content

Commit 63edb57

Browse files
committed
Fixed several bugs in RFX sampler and improved R interface to RFX
1 parent b790783 commit 63edb57

File tree

6 files changed

+170
-22
lines changed

6 files changed

+170
-22
lines changed

R/bcf.R

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
#' @param adaptive_coding Whether or not to use an "adaptive coding" scheme in which a binary treatment variable is not coded manually as (0,1) or (-1,1) but learned via parameters `b_0` and `b_1` that attach to the outcome model `[b_0 (1-Z) + b_1 Z] tau(X)`. This is ignored when Z is not binary. Default: T.
5656
#' @param b_0 Initial value of the "control" group coding parameter. This is ignored when Z is not binary. Default: -0.5.
5757
#' @param b_1 Initial value of the "treatment" group coding parameter. This is ignored when Z is not binary. Default: 0.5.
58+
#' @param rfx_prior_var Diagonal elements of the prior covariance matrix of the random effects model. Must be a numeric vector of length `ncol(rfx_basis_train)`. Default: `rep(1, ncol(rfx_basis_train))`.
5859
#' @param random_seed Integer parameterizing the C++ random number generator. If not specified, the C++ random number generator is seeded according to `std::random_device`.
5960
#' @param keep_burnin Whether or not "burnin" samples should be included in cached predictions. Default FALSE. Ignored if num_mcmc = 0.
6061
#' @param keep_gfr Whether or not "grow-from-root" samples should be included in cached predictions. Default FALSE. Ignored if num_mcmc = 0.
@@ -119,8 +120,8 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU
119120
q = 0.9, sigma2 = NULL, variable_weights = NULL, keep_vars_mu = NULL, drop_vars_mu = NULL,
120121
keep_vars_tau = NULL, drop_vars_tau = NULL, num_trees_mu = 250, num_trees_tau = 50,
121122
num_gfr = 5, num_burnin = 0, num_mcmc = 100, sample_sigma_global = T, sample_sigma_leaf_mu = T,
122-
sample_sigma_leaf_tau = F, propensity_covariate = "mu", adaptive_coding = T,
123-
b_0 = -0.5, b_1 = 0.5, random_seed = -1, keep_burnin = F, keep_gfr = F, verbose = F) {
123+
sample_sigma_leaf_tau = F, propensity_covariate = "mu", adaptive_coding = T, b_0 = -0.5,
124+
b_1 = 0.5, rfx_prior_var = NULL, random_seed = -1, keep_burnin = F, keep_gfr = F, verbose = F) {
124125
# Variable weight preprocessing (and initialization if necessary)
125126
if (is.null(variable_weights)) {
126127
variable_weights = rep(1/ncol(X_train), ncol(X_train))
@@ -294,6 +295,16 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU
294295
}
295296
}
296297
}
298+
299+
# Random effects covariance prior
300+
if (has_rfx) {
301+
if (is.null(rfx_prior_var)) {
302+
rfx_prior_var <- rep(1, ncol(rfx_basis_train))
303+
} else {
304+
if ((!is.integer(rfx_prior_var)) && (!is.numeric(rfx_prior_var))) stop("rfx_prior_var must be a numeric vector")
305+
if (length(rfx_prior_var) != ncol(rfx_basis_train)) stop("length(rfx_prior_var) must equal ncol(rfx_basis_train)")
306+
}
307+
}
297308

298309
# Update variable weights
299310
variable_weights_adj <- 1/sapply(original_var_indices, function(x) sum(original_var_indices == x))
@@ -342,7 +353,10 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU
342353

343354
# Check whether treatment is binary (specifically 0-1 binary)
344355
binary_treatment <- length(unique(Z_train)) == 2
345-
if (!(all(sort(unique(Z_train)) == c(0,1)))) binary_treatment <- F
356+
if (binary_treatment) {
357+
unique_treatments <- sort(unique(Z_train))
358+
if (!(all(unique_treatments == c(0,1)))) binary_treatment <- F
359+
}
346360

347361
# Adaptive coding will be ignored for continuous / ordered categorical treatments
348362
if ((!binary_treatment) && (adaptive_coding)) {
@@ -413,16 +427,22 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU
413427

414428
# Random effects prior parameters
415429
if (has_rfx) {
416-
if (num_rfx_components == 1) {
417-
alpha_init <- c(1)
418-
} else if (num_rfx_components > 1) {
419-
alpha_init <- c(1,rep(0,num_rfx_components-1))
420-
} else {
430+
# Initialize the working parameter to 1
431+
if (num_rfx_components < 1) {
421432
stop("There must be at least 1 random effect component")
422433
}
423-
xi_init <- matrix(rep(alpha_init, num_rfx_groups),num_rfx_components,num_rfx_groups)
434+
alpha_init <- rep(1,num_rfx_components)
435+
# Initialize each group parameter based on a regression of outcome on basis in that group
436+
xi_init <- matrix(0,num_rfx_components,num_rfx_groups)
437+
for (i in 1:num_rfx_groups) {
438+
group_subset_indices <- group_ids_train == i
439+
basis_group <- rfx_basis_train[group_subset_indices,]
440+
resid_group <- resid_train[group_subset_indices]
441+
rfx_group_model <- lm(resid_group ~ 0+basis_group)
442+
xi_init[,i] <- unname(coef(rfx_group_model))
443+
}
424444
sigma_alpha_init <- diag(1,num_rfx_components,num_rfx_components)
425-
sigma_xi_init <- diag(1,num_rfx_components,num_rfx_components)
445+
sigma_xi_init <- diag(rfx_prior_var)
426446
sigma_xi_shape <- 1
427447
sigma_xi_scale <- 1
428448
}

include/stochtree/random_effects.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ class LabelMapper {
7575
auto pos = label_map_.find(category_id);
7676
return pos != label_map_.end();
7777
}
78-
bool CategoryNumber(int32_t category_id) {
78+
int32_t CategoryNumber(int32_t category_id) {
7979
return label_map_[category_id];
8080
}
8181
std::vector<int32_t>& Keys() {return keys_;}
@@ -99,7 +99,7 @@ class MultivariateRegressionRandomEffectsModel {
9999
working_parameter_ = Eigen::VectorXd(num_components_);
100100
group_parameters_ = Eigen::MatrixXd(num_components_, num_groups_);
101101
group_parameter_covariance_ = Eigen::MatrixXd(num_components_, num_components_);
102-
working_parameter_covariance_ = Eigen::VectorXd(num_components_, num_components_);
102+
working_parameter_covariance_ = Eigen::MatrixXd(num_components_, num_components_);
103103
}
104104
~MultivariateRegressionRandomEffectsModel() {}
105105

@@ -206,7 +206,7 @@ class MultivariateRegressionRandomEffectsModel {
206206
tracker.SetPrediction(i, new_pred);
207207
}
208208
}
209-
private:
209+
210210
/*! \brief Compute the posterior mean of the working parameter, conditional on the group parameters and the variance components */
211211
Eigen::VectorXd WorkingParameterMean(RandomEffectsDataset& dataset, ColumnVector& residual, RandomEffectsTracker& rfx_tracker, double global_variance);
212212
/*! \brief Compute the posterior covariance of the working parameter, conditional on the group parameters and the variance components */
@@ -219,7 +219,8 @@ class MultivariateRegressionRandomEffectsModel {
219219
double VarianceComponentShape(RandomEffectsDataset& dataset, ColumnVector& residual, RandomEffectsTracker& rfx_tracker, double global_variance, int32_t component_id);
220220
/*! \brief Compute the posterior scale of the group variance component, conditional on the working and group parameters */
221221
double VarianceComponentScale(RandomEffectsDataset& dataset, ColumnVector& residual, RandomEffectsTracker& rfx_tracker, double global_variance, int32_t component_id);
222-
222+
223+
private:
223224
/*! \brief Samplers */
224225
MultivariateNormalSampler normal_sampler_;
225226
InverseGammaSampler ig_sampler_;

src/random_effects.cpp

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,8 @@ void MultivariateRegressionRandomEffectsModel::SampleRandomEffects(RandomEffects
4545
AddCurrentPredictionToResidual(dataset, rfx_tracker, residual);
4646

4747
// Sample random effects
48-
SampleWorkingParameter(dataset, residual, rfx_tracker, global_variance, gen);
4948
SampleGroupParameters(dataset, residual, rfx_tracker, global_variance, gen);
49+
SampleWorkingParameter(dataset, residual, rfx_tracker, global_variance, gen);
5050
SampleVarianceComponents(dataset, residual, rfx_tracker, global_variance, gen);
5151

5252
// Update partial residual to remove the random effects
@@ -104,8 +104,8 @@ Eigen::VectorXd MultivariateRegressionRandomEffectsModel::WorkingParameterMean(R
104104
X_group = X(observation_indices, Eigen::all);
105105
y_group = y(observation_indices, Eigen::all);
106106
xi_group = xi(Eigen::all, i);
107-
posterior_denominator += (xi_group).asDiagonal() * X_group.transpose() * X_group * (xi_group).asDiagonal();
108-
posterior_numerator += (xi_group).asDiagonal() * X_group.transpose() * y_group;
107+
posterior_denominator += ((xi_group).asDiagonal() * X_group.transpose() * X_group * (xi_group).asDiagonal()) / global_variance;
108+
posterior_numerator += (xi_group).asDiagonal() * X_group.transpose() * y_group / global_variance;
109109
}
110110
return posterior_denominator.inverse() * posterior_numerator;
111111
}
@@ -127,8 +127,7 @@ Eigen::MatrixXd MultivariateRegressionRandomEffectsModel::WorkingParameterVarian
127127
X_group = X(observation_indices, Eigen::all);
128128
y_group = y(observation_indices, Eigen::all);
129129
xi_group = xi(Eigen::all, i);
130-
posterior_denominator += (xi_group).asDiagonal() * X_group.transpose() * X_group * (xi_group).asDiagonal();
131-
posterior_numerator += (xi_group).asDiagonal() * X_group.transpose() * y_group;
130+
posterior_denominator += ((xi_group).asDiagonal() * X_group.transpose() * X_group * (xi_group).asDiagonal()) / (global_variance);
132131
}
133132
return posterior_denominator.inverse();
134133
}
@@ -144,8 +143,8 @@ Eigen::VectorXd MultivariateRegressionRandomEffectsModel::GroupParameterMean(Ran
144143
std::vector<data_size_t> observation_indices = rfx_tracker.NodeIndicesInternalIndex(group_id);
145144
Eigen::MatrixXd X_group = X(observation_indices, Eigen::all);
146145
Eigen::VectorXd y_group = y(observation_indices, Eigen::all);
147-
posterior_denominator += (alpha).asDiagonal() * X_group.transpose() * X_group * (alpha).asDiagonal();
148-
posterior_numerator += (alpha).asDiagonal() * X_group.transpose() * y_group;
146+
posterior_denominator += ((alpha).asDiagonal() * X_group.transpose() * X_group * (alpha).asDiagonal()) / (global_variance);
147+
posterior_numerator += (alpha).asDiagonal() * X_group.transpose() * y_group / global_variance;
149148
return posterior_denominator.inverse() * posterior_numerator;
150149
}
151150

@@ -160,7 +159,7 @@ Eigen::MatrixXd MultivariateRegressionRandomEffectsModel::GroupParameterVariance
160159
std::vector<data_size_t> observation_indices = rfx_tracker.NodeIndicesInternalIndex(group_id);
161160
Eigen::MatrixXd X_group = X(observation_indices, Eigen::all);
162161
// Eigen::VectorXd y_group = y(observation_indices, Eigen::all);
163-
posterior_denominator += (alpha).asDiagonal() * X_group.transpose() * X_group * (alpha).asDiagonal();
162+
posterior_denominator += ((alpha).asDiagonal() * X_group.transpose() * X_group * (alpha).asDiagonal()) / (global_variance);
164163
// posterior_numerator += (alpha).asDiagonal() * X_group.transpose() * y_group;
165164
return posterior_denominator.inverse();
166165
}

test/cpp/test_random_effects.cpp

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,63 @@ TEST(RandomEffects, Construction) {
9999
}
100100
}
101101

102+
TEST(RandomEffects, Computation) {
103+
// Load test data
104+
StochTree::TestUtils::TestDataset test_dataset;
105+
test_dataset = StochTree::TestUtils::LoadSmallRFXDatasetMultivariateBasis();
106+
std::vector<StochTree::FeatureType> feature_types(test_dataset.x_cols, StochTree::FeatureType::kNumeric);
107+
108+
// Construct dataset
109+
int n = test_dataset.n;
110+
StochTree::ColumnVector residual = StochTree::ColumnVector(test_dataset.outcome.data(), test_dataset.n);
111+
StochTree::RandomEffectsDataset dataset = StochTree::RandomEffectsDataset();
112+
dataset.AddBasis(test_dataset.rfx_basis.data(), test_dataset.n, test_dataset.rfx_basis_cols, test_dataset.row_major);
113+
dataset.AddGroupLabels(test_dataset.rfx_groups);
114+
115+
// Construct tracker, model state, and container
116+
StochTree::RandomEffectsTracker tracker = StochTree::RandomEffectsTracker(test_dataset.rfx_groups);
117+
StochTree::MultivariateRegressionRandomEffectsModel model = StochTree::MultivariateRegressionRandomEffectsModel(test_dataset.rfx_basis_cols, test_dataset.rfx_num_groups);
118+
StochTree::RandomEffectsContainer container = StochTree::RandomEffectsContainer(test_dataset.rfx_basis_cols, test_dataset.rfx_num_groups);
119+
StochTree::LabelMapper label_mapper = StochTree::LabelMapper(tracker.GetLabelMap());
120+
121+
// Set the values of alpha, xi and sigma in the model state (rather than simulating)
122+
Eigen::VectorXd alpha(test_dataset.rfx_basis_cols);
123+
Eigen::MatrixXd xi(test_dataset.rfx_basis_cols, test_dataset.rfx_num_groups);
124+
Eigen::MatrixXd sigma(test_dataset.rfx_basis_cols, test_dataset.rfx_basis_cols);
125+
alpha << 1., 1.;
126+
xi << 1., 1., 1., 1., 1., 1.;
127+
Eigen::VectorXd xi0 = xi(Eigen::all, 0);
128+
Eigen::VectorXd xi1 = xi(Eigen::all, 1);
129+
Eigen::VectorXd xi2 = xi(Eigen::all, 2);
130+
sigma << 1, 0, 0, 1;
131+
model.SetWorkingParameter(alpha);
132+
model.SetGroupParameter(xi0, 0);
133+
model.SetGroupParameter(xi1, 1);
134+
model.SetGroupParameter(xi2, 2);
135+
model.SetGroupParameterCovariance(sigma);
136+
double sigma2 = 1.;
137+
138+
// Compute the posterior mean for the group parameters
139+
Eigen::VectorXd xi0_mean = model.GroupParameterMean(dataset, residual, tracker, sigma2, 0);
140+
Eigen::VectorXd xi1_mean = model.GroupParameterMean(dataset, residual, tracker, sigma2, 1);
141+
Eigen::VectorXd xi2_mean = model.GroupParameterMean(dataset, residual, tracker, sigma2, 2);
142+
143+
// Check the posterior means of the group parameters against the expected values
144+
std::vector<double> xi_mean_expected(test_dataset.rfx_basis_cols);
145+
xi_mean_expected = {0.6979496, 0.3316027};
146+
for (int i = 0; i < xi_mean_expected.size(); i++) {
147+
ASSERT_NEAR(xi0_mean(i), xi_mean_expected[i], 0.001);
148+
}
149+
xi_mean_expected = {0.65744523, 0.00639347};
150+
for (int i = 0; i < xi_mean_expected.size(); i++) {
151+
ASSERT_NEAR(xi1_mean(i), xi_mean_expected[i], 0.001);
152+
}
153+
xi_mean_expected = {0.8763421, 0.3414047};
154+
for (int i = 0; i < xi_mean_expected.size(); i++) {
155+
ASSERT_NEAR(xi2_mean(i), xi_mean_expected[i], 0.001);
156+
}
157+
}
158+
102159
TEST(RandomEffects, Predict) {
103160
// Load test data
104161
StochTree::TestUtils::TestDataset test_dataset;

test/cpp/testutils.cpp

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,74 @@ TestDataset LoadSmallDatasetMultivariateBasis() {
118118
return output;
119119
}
120120

121+
TestDataset LoadSmallRFXDatasetMultivariateBasis() {
122+
TestDataset output;
123+
124+
// Data dimensions
125+
output.n = 10;
126+
output.x_cols = 5;
127+
output.omega_cols = 2;
128+
output.rfx_basis_cols = 2;
129+
output.covariates.resize(output.n, output.x_cols);
130+
output.omega.resize(output.n, output.omega_cols);
131+
output.rfx_basis.resize(output.n, output.rfx_basis_cols);
132+
output.rfx_groups.resize(output.n);
133+
output.outcome.resize(output.n);
134+
135+
// Covariates
136+
output.covariates << 0.766969853, 0.83894646, 0.63649772, 0.6747788934, 0.27398269,
137+
0.634970996, 0.15237997, 0.3800786, 0.6457891271, 0.21604451,
138+
0.229598754, 0.12461481, 0.81407372, 0.364336529, 0.45160373,
139+
0.741084778, 0.53356288, 0.58940162, 0.9995219493, 0.19142269,
140+
0.618177813, 0.88876378, 0.51174404, 0.8827708189, 0.12730742,
141+
0.858657839, 0.9271676, 0.5115294, 0.67865624, 0.28658962,
142+
0.719224842, 0.0546961, 0.42850897, 0.260336376, 0.1371501,
143+
0.747422328, 0.87172033, 0.98791964, 0.4018020707, 0.29145664,
144+
0.3158837, 0.39253551, 0.83610831, 0.0101785748, 0.1955386,
145+
0.419554105, 0.5586495, 0.19908607, 0.4873921743, 0.35568569;
146+
147+
// Leaf regression basis
148+
output.omega << 0.97801674, 0.3707159,
149+
0.34045661, 0.1312134,
150+
0.20528387, 0.5614470,
151+
0.76230322, 0.2276504,
152+
0.63244655, 0.9029984,
153+
0.61225851, 0.7448547,
154+
0.40492125, 0.2549813,
155+
0.33112223, 0.5295535,
156+
0.86917047, 0.5584614,
157+
0.58444831, 0.2365117;
158+
159+
// Outcome
160+
output.outcome << 2.158854445, 1.175387297, 0.40481061, 1.751578365, 0.299641379,
161+
0.347249942, 0.546179903, 1.164750138, 3.389946886, -0.605464414;
162+
163+
// Random effects regression basis (two columns: a constant intercept term and one varying covariate)
164+
output.rfx_basis << 1, 0.3707159,
165+
1, 0.1312134,
166+
1, 0.5614470,
167+
1, 0.2276504,
168+
1, 0.9029984,
169+
1, 0.7448547,
170+
1, 0.2549813,
171+
1, 0.5295535,
172+
1, 0.5584614,
173+
1, 0.2365117;
174+
175+
// Random effects group labels
176+
output.rfx_groups = {1,2,3,1,2,3,1,2,3,1};
177+
// for (int i = 0; i < output.n; i++) {
178+
// if (i % 2 == 0) {
179+
// output.rfx_groups[i] = 1;
180+
// } else {
181+
// output.rfx_groups[i] = 2;
182+
// }
183+
// }
184+
output.rfx_num_groups = 3;
185+
186+
return output;
187+
}
188+
121189
TestDataset LoadMediumDatasetUnivariateBasis() {
122190
TestDataset output;
123191

test/cpp/testutils.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@ TestDataset LoadSmallDatasetUnivariateBasis();
3333
/*! Creates a small dataset (10 observations) with a multivariate basis for leaf regression applications */
3434
TestDataset LoadSmallDatasetMultivariateBasis();
3535

36+
/*! Creates a small dataset (10 observations) with a multivariate basis and several random effects terms */
37+
TestDataset LoadSmallRFXDatasetMultivariateBasis();
38+
3639
/*! Creates a modest dataset (100 observations) */
3740
TestDataset LoadMediumDatasetUnivariateBasis();
3841

0 commit comments

Comments
 (0)