Skip to content

Commit 566a75c

Browse files
committed
Merge branch 'main' into sparse_leaf_matrix
2 parents 0642930 + 44e3be1 commit 566a75c

File tree

9 files changed

+497
-196
lines changed

9 files changed

+497
-196
lines changed

debug/api_debug.cpp

Lines changed: 53 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -169,10 +169,34 @@ void RunAPI() {
169169
double outcome_scale;
170170
OutcomeOffsetScale(residual, outcome_offset, outcome_scale);
171171

172-
// // Construct a random effects dataset
173-
// RandomEffectsDataset rfx_dataset = RandomEffectsDataset();
174-
// rfx_dataset.AddBasis(rfx_basis_raw.data(), n, rfx_basis_cols, row_major);
175-
// rfx_dataset.AddGroupLabels(rfx_groups);
172+
// Construct a random effects dataset
173+
RandomEffectsDataset rfx_dataset = RandomEffectsDataset();
174+
rfx_dataset.AddBasis(rfx_basis_raw.data(), n, rfx_basis_cols, true);
175+
rfx_dataset.AddGroupLabels(rfx_groups);
176+
177+
// Construct random effects tracker / model / container
178+
RandomEffectsTracker rfx_tracker = RandomEffectsTracker(rfx_groups);
179+
MultivariateRegressionRandomEffectsModel rfx_model = MultivariateRegressionRandomEffectsModel(rfx_basis_cols, num_rfx_groups);
180+
RandomEffectsContainer rfx_container = RandomEffectsContainer(rfx_basis_cols, num_rfx_groups);
181+
LabelMapper label_mapper = LabelMapper(rfx_tracker.GetLabelMap());
182+
183+
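// Set random effects model parameters.
// NOTE(review): the comma initializers below supply 1, 2, 1 and 1 coefficients respectively,
// so they presumably rely on rfx_basis_cols == 1 and num_rfx_groups == 2 -- confirm against
// the values set earlier in RunAPI.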
184+
Eigen::VectorXd working_param_init(rfx_basis_cols);
185+
Eigen::MatrixXd group_param_init(rfx_basis_cols, num_rfx_groups);
186+
Eigen::MatrixXd working_param_cov_init(rfx_basis_cols, rfx_basis_cols);
187+
Eigen::MatrixXd group_param_cov_init(rfx_basis_cols, rfx_basis_cols);
188+
double variance_prior_shape = 1.;
189+
double variance_prior_scale = 1.;
190+
working_param_init << 1.;
191+
group_param_init << 1., 1.;
192+
working_param_cov_init << 1;
193+
group_param_cov_init << 1;
194+
rfx_model.SetWorkingParameter(working_param_init);
195+
rfx_model.SetGroupParameters(group_param_init);
196+
rfx_model.SetWorkingParameterCovariance(working_param_cov_init);
197+
rfx_model.SetGroupParameterCovariance(group_param_cov_init);
198+
rfx_model.SetVariancePriorShape(variance_prior_shape);
199+
rfx_model.SetVariancePriorScale(variance_prior_scale);
176200

177201
// Initialize an ensemble
178202
int num_trees = 100;
@@ -244,6 +268,10 @@ void RunAPI() {
244268
sampleGFR(tracker, tree_prior, forest_samples, dataset, residual, rng, feature_types, variable_weights,
245269
leaf_model_type, leaf_scale_matrix, global_variance, leaf_scale, cutpoint_grid_size);
246270

271+
// Sample random effects
272+
rfx_model.SampleRandomEffects(rfx_dataset, residual, rfx_tracker, global_variance, rng);
273+
rfx_container.AddSample(rfx_model);
274+
247275
// Sample leaf node variance
248276
leaf_variance_samples.push_back(leaf_var_model.SampleVarianceParameter(forest_samples.GetEnsemble(i), a_leaf, b_leaf, rng));
249277

@@ -266,24 +294,36 @@ void RunAPI() {
266294
sampleMCMC(tracker, tree_prior, forest_samples, dataset, residual, rng, feature_types, variable_weights,
267295
leaf_model_type, leaf_scale_matrix, global_variance, leaf_scale, cutpoint_grid_size);
268296

297+
// Sample random effects
298+
rfx_model.SampleRandomEffects(rfx_dataset, residual, rfx_tracker, global_variance, rng);
299+
rfx_container.AddSample(rfx_model);
300+
269301
// Sample leaf node variance
270302
leaf_variance_samples.push_back(leaf_var_model.SampleVarianceParameter(forest_samples.GetEnsemble(i), a_leaf, b_leaf, rng));
271303

272304
// Sample global variance
273305
global_variance_samples.push_back(global_var_model.SampleVarianceParameter(residual.GetData(), nu, nu*lamb, rng));
274306
}
275307

276-
// Write model to a file
277-
std::string filename = "model.json";
278-
forest_samples.SaveToJsonFile(filename);
308+
// Predict from the tree ensemble
309+
std::vector<double> pred_orig = forest_samples.Predict(dataset);
310+
311+
// Predict from the random effects dataset
312+
int num_samples = num_gfr_samples + num_mcmc_samples;
313+
std::vector<double> rfx_predictions(n*num_samples);
314+
rfx_container.Predict(rfx_dataset, label_mapper, rfx_predictions);
279315

280-
// Read and parse json from file
281-
ForestContainer forest_samples_parsed = ForestContainer(num_trees, output_dimension, is_leaf_constant);
282-
forest_samples_parsed.LoadFromJsonFile(filename);
316+
// // Write model to a file
317+
// std::string filename = "model.json";
318+
// forest_samples.SaveToJsonFile(filename);
319+
320+
// // Read and parse json from file
321+
// ForestContainer forest_samples_parsed = ForestContainer(num_trees, output_dimension, is_leaf_constant);
322+
// forest_samples_parsed.LoadFromJsonFile(filename);
283323

284-
// Make sure we can predict from both the original and parsed forest containers
285-
std::vector<double> pred_orig = forest_samples.Predict(dataset);
286-
std::vector<double> pred_parsed = forest_samples_parsed.Predict(dataset);
324+
// // Make sure we can predict from both the original and parsed forest containers
325+
// std::vector<double> pred_orig = forest_samples.Predict(dataset);
326+
// std::vector<double> pred_parsed = forest_samples_parsed.Predict(dataset);
287327
}
288328

289329
} // namespace StochTree

include/stochtree/category_tracker.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ class CategorySampleTracker {
9696
if (i > 0) new_group_cond = group_indices[indices_[i]] != group_indices[indices_[i-1]];
9797
if (start_cond || new_group_cond) {
9898
category_id_map_.insert({group_indices[indices_[i]], category_count_});
99+
unique_category_ids_.push_back(group_indices[indices_[i]]);
99100
node_index_vector_.emplace_back();
100101
if (i == 0) {
101102
category_begin_.push_back(i);
@@ -115,6 +116,11 @@ class CategorySampleTracker {
115116
}
116117
}
117118

119+
/*! \brief Zero-indexed numeric index that category_id is remapped to internally */
120+
inline int32_t CategoryNumber(int category_id) {
121+
return category_id_map_[category_id];
122+
}
123+
118124
/*!
 * \brief First index of data points contained in the category identified by category_id
 * \param category_id Category label; translated to an internal index via category_id_map_
 * \return The category_begin_ entry at that internal index
 *
 * NOTE(review): the translation uses std::map::operator[], which inserts an
 * entry for a label that is not already present -- confirm callers only pass known labels.
 */
119125
inline data_size_t CategoryBegin(int category_id) {return category_begin_[category_id_map_[category_id]];}
120126

@@ -144,15 +150,28 @@ class CategorySampleTracker {
144150
// return output;
145151
return node_index_vector_[id];
146152
}
153+
154+
/*!
 * \brief Data indices for a category, addressed by its internal index
 * \param internal_category_id Position in node_index_vector_; used as-is, with no
 *        translation through category_id_map_ and no bounds check
 * \return Mutable reference to the stored index vector (not a copy)
 */
155+
std::vector<data_size_t>& NodeIndicesInternalIndex(int internal_category_id) {
156+
// Disabled copy-based variant, kept for reference. It refers to `category_id`,
// which is not a parameter of this method.
// int32_t id = category_id_map_[category_id];
157+
// std::vector<data_size_t>::iterator start = indices_.begin() + category_begin_[id];
158+
// std::vector<data_size_t>::iterator end = indices_.begin() + category_begin_[id] + category_length_[id];
159+
// std::vector<data_size_t> output(start, end);
160+
// return output;
161+
return node_index_vector_[internal_category_id];
162+
}
147163

148164
/*!
 * \brief Returns label index map
 * \return Mutable reference to category_id_map_, which maps each category label
 *         to the internal index assigned to it
 */
149165
std::map<int32_t, int32_t>& GetLabelMap() {return category_id_map_;}
150166

167+
/*!
 * \brief Returns the distinct category labels recorded by the tracker
 * \return Mutable reference to unique_category_ids_; one label is appended each
 *         time a new entry is inserted into category_id_map_
 */
std::vector<int32_t>& GetUniqueGroupIds() {return unique_category_ids_;}
168+
151169
private:
152170
// Vectors tracking indices in each node
153171
std::vector<data_size_t> category_begin_;
154172
std::vector<data_size_t> category_length_;
155173
std::map<int32_t, int32_t> category_id_map_;
174+
std::vector<int32_t> unique_category_ids_;
156175
std::vector<std::vector<data_size_t>> node_index_vector_;
157176
int32_t category_count_;
158177
};

include/stochtree/data.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ class RandomEffectsDataset {
118118
group_labels_ = group_labels;
119119
has_group_labels_ = true;
120120
}
121+
/*!
 * \brief Number of observations, reported as the row count of the basis
 * NOTE(review): has_basis_ is not checked here; presumably only meaningful
 * after a basis has been added -- confirm.
 */
inline data_size_t NumObservations() {return basis_.NumRows();}
121122
/*! \brief Returns the has_basis_ flag */
inline bool HasBasis() {return has_basis_;}
122123
/*! \brief Returns the has_var_weights_ flag */
inline bool HasVarWeights() {return has_var_weights_;}
123124
/*! \brief Returns the has_group_labels_ flag, which is set to true when group_labels_ is assigned */
inline bool HasGroupLabels() {return has_group_labels_;}

0 commit comments

Comments
 (0)