Skip to content

Commit 9ee9d20

Browse files
committed
Adding json serialization to random effects container
1 parent 596e960 commit 9ee9d20

File tree

4 files changed

+155
-10
lines changed

4 files changed

+155
-10
lines changed

debug/api_debug.cpp

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ void RunAPI() {
280280
}
281281

282282
// Run the MCMC sampler
283-
int num_mcmc_samples = 100;
283+
int num_mcmc_samples = 10;
284284
for (int i = num_gfr_samples; i < num_gfr_samples + num_mcmc_samples; i++) {
285285
if (i == 0) {
286286
global_variance = global_variance_init;
@@ -313,17 +313,16 @@ void RunAPI() {
313313
std::vector<double> rfx_predictions(n*num_samples);
314314
rfx_container.Predict(rfx_dataset, label_mapper, rfx_predictions);
315315

316-
// // Write model to a file
317-
// std::string filename = "model.json";
318-
// forest_samples.SaveToJsonFile(filename);
316+
// Write model to a file
317+
std::string filename = "model.json";
318+
forest_samples.SaveToJsonFile(filename);
319319

320-
// // Read and parse json from file
321-
// ForestContainer forest_samples_parsed = ForestContainer(num_trees, output_dimension, is_leaf_constant);
322-
// forest_samples_parsed.LoadFromJsonFile(filename);
320+
// Read and parse json from file
321+
ForestContainer forest_samples_parsed = ForestContainer(num_trees, output_dimension, is_leaf_constant);
322+
forest_samples_parsed.LoadFromJsonFile(filename);
323323

324-
// // Make sure we can predict from both the original and parsed forest containers
325-
// std::vector<double> pred_orig = forest_samples.Predict(dataset);
326-
// std::vector<double> pred_parsed = forest_samples_parsed.Predict(dataset);
324+
// Make sure we can predict from both the original (above) and parsed forest containers
325+
std::vector<double> pred_parsed = forest_samples_parsed.Predict(dataset);
327326
}
328327

329328
} // namespace StochTree

include/stochtree/random_effects.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include <stochtree/normal_sampler.h>
1515
#include <stochtree/partition_tracker.h>
1616
#include <stochtree/prior.h>
17+
#include <nlohmann/json.hpp>
1718
#include <Eigen/Dense>
1819

1920
#include <cmath>
@@ -247,6 +248,11 @@ class RandomEffectsContainer {
247248
num_groups_ = num_groups;
248249
num_samples_ = 0;
249250
}
251+
RandomEffectsContainer() {
252+
num_components_ = 0;
253+
num_groups_ = 0;
254+
num_samples_ = 0;
255+
}
250256
~RandomEffectsContainer() {}
251257
void AddSample(MultivariateRegressionRandomEffectsModel& model);
252258
void Predict(RandomEffectsDataset& dataset, LabelMapper& label_mapper, std::vector<double>& output);
@@ -257,6 +263,8 @@ class RandomEffectsContainer {
257263
std::vector<double>& GetAlpha() {return alpha_;}
258264
std::vector<double>& GetXi() {return xi_;}
259265
std::vector<double>& GetSigma() {return sigma_xi_;}
266+
nlohmann::json to_json();
267+
void from_json(const nlohmann::json& rfx_container_json);
260268
private:
261269
int num_samples_;
262270
int num_components_;

src/random_effects.cpp

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,4 +199,72 @@ void RandomEffectsContainer::Predict(RandomEffectsDataset& dataset, LabelMapper&
199199
}
200200
}
201201

202+
nlohmann::json RandomEffectsContainer::to_json() {
203+
json result_obj;
204+
// Store the non-array fields in json
205+
result_obj.emplace("num_samples", num_samples_);
206+
result_obj.emplace("num_components", num_components_);
207+
result_obj.emplace("num_groups", num_groups_);
208+
209+
// Store some meta-level information about the containers
210+
int beta_size = num_groups_*num_components_*num_samples_;
211+
int alpha_size = num_components_*num_samples_;
212+
result_obj.emplace("beta_size", beta_size);
213+
result_obj.emplace("alpha_size", alpha_size);
214+
215+
// Initialize a map with names of the node vectors and empty json arrays
216+
std::map<std::string, json> tree_array_map;
217+
tree_array_map.emplace(std::pair("beta", json::array()));
218+
tree_array_map.emplace(std::pair("xi", json::array()));
219+
tree_array_map.emplace(std::pair("alpha", json::array()));
220+
tree_array_map.emplace(std::pair("sigma_xi", json::array()));
221+
222+
// Unpack beta and xi into json arrays
223+
for (int i = 0; i < beta_size; i++) {
224+
tree_array_map["beta"].emplace_back(beta_.at(i));
225+
tree_array_map["xi"].emplace_back(xi_.at(i));
226+
}
227+
228+
// Unpack alpha and sigma into json arrays
229+
for (int i = 0; i < alpha_size; i++) {
230+
tree_array_map["alpha"].emplace_back(alpha_.at(i));
231+
tree_array_map["sigma_xi"].emplace_back(sigma_xi_.at(i));
232+
}
233+
234+
// Unpack the map into the reference JSON object
235+
for (auto& pair : tree_array_map) {
236+
result_obj.emplace(pair);
237+
}
238+
239+
return result_obj;
240+
}
241+
242+
void RandomEffectsContainer::from_json(const nlohmann::json& rfx_container_json) {
243+
int beta_size = rfx_container_json.at("beta_size");
244+
int alpha_size = rfx_container_json.at("alpha_size");
245+
246+
// Clear all internal arrays
247+
beta_.clear();
248+
xi_.clear();
249+
alpha_.clear();
250+
sigma_xi_.clear();
251+
252+
// Unpack internal counts
253+
this->num_samples_ = rfx_container_json.at("num_samples");
254+
this->num_components_ = rfx_container_json.at("num_components");
255+
this->num_groups_ = rfx_container_json.at("num_groups");
256+
257+
// Unpack beta and xi
258+
for (int i = 0; i < beta_size; i++) {
259+
beta_.push_back(rfx_container_json.at("beta").at(i));
260+
xi_.push_back(rfx_container_json.at("xi").at(i));
261+
}
262+
263+
// Unpack alpha and sigma_xi
264+
for (int i = 0; i < alpha_size; i++) {
265+
alpha_.push_back(rfx_container_json.at("alpha").at(i));
266+
sigma_xi_.push_back(rfx_container_json.at("sigma_xi").at(i));
267+
}
268+
}
269+
202270
} // namespace StochTree

test/test_random_effects.cpp

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,3 +160,73 @@ TEST(RandomEffects, Predict) {
160160
ASSERT_EQ(output[i], output_expected[i]);
161161
}
162162
}
163+
164+
TEST(RandomEffects, Serialization) {
165+
// Load test data
166+
StochTree::TestUtils::TestDataset test_dataset;
167+
test_dataset = StochTree::TestUtils::LoadSmallDatasetUnivariateBasis();
168+
std::vector<StochTree::FeatureType> feature_types(test_dataset.x_cols, StochTree::FeatureType::kNumeric);
169+
170+
// Construct dataset
171+
int n = test_dataset.n;
172+
StochTree::RandomEffectsDataset dataset = StochTree::RandomEffectsDataset();
173+
dataset.AddBasis(test_dataset.rfx_basis.data(), test_dataset.n, test_dataset.rfx_basis_cols, test_dataset.row_major);
174+
dataset.AddGroupLabels(test_dataset.rfx_groups);
175+
176+
// Construct tracker, model state, and container
177+
StochTree::RandomEffectsTracker tracker = StochTree::RandomEffectsTracker(test_dataset.rfx_groups);
178+
StochTree::MultivariateRegressionRandomEffectsModel model = StochTree::MultivariateRegressionRandomEffectsModel(test_dataset.rfx_basis_cols, test_dataset.rfx_num_groups);
179+
StochTree::RandomEffectsContainer container = StochTree::RandomEffectsContainer(test_dataset.rfx_basis_cols, test_dataset.rfx_num_groups);
180+
StochTree::LabelMapper label_mapper = StochTree::LabelMapper(tracker.GetLabelMap());
181+
182+
// Set the values of alpha, xi and sigma in the model state (rather than simulating)
183+
Eigen::VectorXd alpha(test_dataset.rfx_basis_cols);
184+
Eigen::MatrixXd xi(test_dataset.rfx_basis_cols, test_dataset.rfx_num_groups);
185+
Eigen::MatrixXd sigma(test_dataset.rfx_basis_cols, test_dataset.rfx_basis_cols);
186+
alpha << 1.5;
187+
xi << 2, 4;
188+
Eigen::VectorXd xi0 = xi(Eigen::all, 0);
189+
Eigen::VectorXd xi1 = xi(Eigen::all, 1);
190+
sigma << 1;
191+
model.SetWorkingParameter(alpha);
192+
model.SetGroupParameter(xi0, 0);
193+
model.SetGroupParameter(xi1, 1);
194+
model.SetGroupParameterCovariance(sigma);
195+
196+
// Push to the container
197+
container.AddSample(model);
198+
199+
// Change values and push a second "sample" to the container
200+
alpha << 2.0;
201+
xi << 1, 3;
202+
xi0 = xi(Eigen::all, 0);
203+
xi1 = xi(Eigen::all, 1);
204+
sigma << 1;
205+
model.SetWorkingParameter(alpha);
206+
model.SetGroupParameter(xi0, 0);
207+
model.SetGroupParameter(xi1, 1);
208+
model.SetGroupParameterCovariance(sigma);
209+
container.AddSample(model);
210+
211+
// Json round trip
212+
nlohmann::json container_json = container.to_json();
213+
StochTree::RandomEffectsContainer container_deserialized = StochTree::RandomEffectsContainer();
214+
container_deserialized.from_json(container_json);
215+
216+
// Check data in the container
217+
std::vector<double> alpha_original = container.GetAlpha();
218+
std::vector<double> alpha_deserialized = container_deserialized.GetAlpha();
219+
for (int i = 0; i < alpha_deserialized.size(); i++) {
220+
ASSERT_EQ(alpha_original[i], alpha_deserialized[i]);
221+
}
222+
std::vector<double> xi_original = container.GetXi();
223+
std::vector<double> xi_deserialized = container_deserialized.GetXi();
224+
for (int i = 0; i < xi_deserialized.size(); i++) {
225+
ASSERT_EQ(xi_original[i], xi_deserialized[i]);
226+
}
227+
std::vector<double> beta_original = container.GetBeta();
228+
std::vector<double> beta_deserialized = container_deserialized.GetBeta();
229+
for (int i = 0; i < beta_deserialized.size(); i++) {
230+
ASSERT_EQ(beta_original[i], beta_deserialized[i]);
231+
}
232+
}

0 commit comments

Comments
 (0)