@@ -169,10 +169,34 @@ void RunAPI() {
169
169
double outcome_scale;
170
170
OutcomeOffsetScale (residual, outcome_offset, outcome_scale);
171
171
172
- // // Construct a random effects dataset
173
- // RandomEffectsDataset rfx_dataset = RandomEffectsDataset();
174
- // rfx_dataset.AddBasis(rfx_basis_raw.data(), n, rfx_basis_cols, row_major);
175
- // rfx_dataset.AddGroupLabels(rfx_groups);
172
+ // Construct a random effects dataset
173
+ RandomEffectsDataset rfx_dataset = RandomEffectsDataset ();
174
+ rfx_dataset.AddBasis (rfx_basis_raw.data (), n, rfx_basis_cols, true );
175
+ rfx_dataset.AddGroupLabels (rfx_groups);
176
+
177
+ // Construct random effects tracker / model / container
178
+ RandomEffectsTracker rfx_tracker = RandomEffectsTracker (rfx_groups);
179
+ MultivariateRegressionRandomEffectsModel rfx_model = MultivariateRegressionRandomEffectsModel (rfx_basis_cols, num_rfx_groups);
180
+ RandomEffectsContainer rfx_container = RandomEffectsContainer (rfx_basis_cols, num_rfx_groups);
181
+ LabelMapper label_mapper = LabelMapper (rfx_tracker.GetLabelMap ());
182
+
183
+ // Set random effects model parameters
184
+ Eigen::VectorXd working_param_init (rfx_basis_cols);
185
+ Eigen::MatrixXd group_param_init (rfx_basis_cols, num_rfx_groups);
186
+ Eigen::MatrixXd working_param_cov_init (rfx_basis_cols, rfx_basis_cols);
187
+ Eigen::MatrixXd group_param_cov_init (rfx_basis_cols, rfx_basis_cols);
188
+ double variance_prior_shape = 1 .;
189
+ double variance_prior_scale = 1 .;
190
+ working_param_init << 1 .;
191
+ group_param_init << 1 ., 1 .;
192
+ working_param_cov_init << 1 ;
193
+ group_param_cov_init << 1 ;
194
+ rfx_model.SetWorkingParameter (working_param_init);
195
+ rfx_model.SetGroupParameters (group_param_init);
196
+ rfx_model.SetWorkingParameterCovariance (working_param_cov_init);
197
+ rfx_model.SetGroupParameterCovariance (group_param_cov_init);
198
+ rfx_model.SetVariancePriorShape (variance_prior_shape);
199
+ rfx_model.SetVariancePriorScale (variance_prior_scale);
176
200
177
201
// Initialize an ensemble
178
202
int num_trees = 100 ;
@@ -244,6 +268,10 @@ void RunAPI() {
244
268
sampleGFR (tracker, tree_prior, forest_samples, dataset, residual, rng, feature_types, variable_weights,
245
269
leaf_model_type, leaf_scale_matrix, global_variance, leaf_scale, cutpoint_grid_size);
246
270
271
+ // Sample random effects
272
+ rfx_model.SampleRandomEffects (rfx_dataset, residual, rfx_tracker, global_variance, rng);
273
+ rfx_container.AddSample (rfx_model);
274
+
247
275
// Sample leaf node variance
248
276
leaf_variance_samples.push_back (leaf_var_model.SampleVarianceParameter (forest_samples.GetEnsemble (i), a_leaf, b_leaf, rng));
249
277
@@ -266,24 +294,36 @@ void RunAPI() {
266
294
sampleMCMC (tracker, tree_prior, forest_samples, dataset, residual, rng, feature_types, variable_weights,
267
295
leaf_model_type, leaf_scale_matrix, global_variance, leaf_scale, cutpoint_grid_size);
268
296
297
+ // Sample random effects
298
+ rfx_model.SampleRandomEffects (rfx_dataset, residual, rfx_tracker, global_variance, rng);
299
+ rfx_container.AddSample (rfx_model);
300
+
269
301
// Sample leaf node variance
270
302
leaf_variance_samples.push_back (leaf_var_model.SampleVarianceParameter (forest_samples.GetEnsemble (i), a_leaf, b_leaf, rng));
271
303
272
304
// Sample global variance
273
305
global_variance_samples.push_back (global_var_model.SampleVarianceParameter (residual.GetData (), nu, nu*lamb, rng));
274
306
}
275
307
276
- // Write model to a file
277
- std::string filename = " model.json" ;
278
- forest_samples.SaveToJsonFile (filename);
308
+ // Predict from the tree ensemble
309
+ std::vector<double > pred_orig = forest_samples.Predict (dataset);
310
+
311
+ // Predict from the random effects dataset
312
+ int num_samples = num_gfr_samples + num_mcmc_samples;
313
+ std::vector<double > rfx_predictions (n*num_samples);
314
+ rfx_container.Predict (rfx_dataset, label_mapper, rfx_predictions);
279
315
280
- // Read and parse json from file
281
- ForestContainer forest_samples_parsed = ForestContainer (num_trees, output_dimension, is_leaf_constant);
282
- forest_samples_parsed.LoadFromJsonFile (filename);
316
+ // // Write model to a file
317
+ // std::string filename = "model.json";
318
+ // forest_samples.SaveToJsonFile(filename);
319
+
320
+ // // Read and parse json from file
321
+ // ForestContainer forest_samples_parsed = ForestContainer(num_trees, output_dimension, is_leaf_constant);
322
+ // forest_samples_parsed.LoadFromJsonFile(filename);
283
323
284
- // Make sure we can predict from both the original and parsed forest containers
285
- std::vector<double > pred_orig = forest_samples.Predict (dataset);
286
- std::vector<double > pred_parsed = forest_samples_parsed.Predict (dataset);
324
+ // // Make sure we can predict from both the original and parsed forest containers
325
+ // std::vector<double> pred_orig = forest_samples.Predict(dataset);
326
+ // std::vector<double> pred_parsed = forest_samples_parsed.Predict(dataset);
287
327
}
288
328
289
329
} // namespace StochTree
0 commit comments