Skip to content

Commit 424dbad

Browse files
authored
Merge pull request #40 from StochasticTree/rfx_serialization
Added json serialization code for random effects
2 parents b984dae + 809f558 commit 424dbad

File tree

2 files changed

+29
-0
lines changed

2 files changed

+29
-0
lines changed

include/stochtree/random_effects.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,9 @@ class LabelMapper {
7979
}
8080
std::vector<int32_t>& Keys() {return keys_;}
8181
std::map<int32_t, int32_t>& Map() {return label_map_;}
82+
void Reset() {label_map_.clear(); keys_.clear();}
83+
nlohmann::json to_json();
84+
void from_json(const nlohmann::json& rfx_label_mapper_json);
8285
private:
8386
std::map<int32_t, int32_t> label_map_;
8487
std::vector<int32_t> keys_;

src/random_effects.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,32 @@ RandomEffectsTracker::RandomEffectsTracker(std::vector<int32_t>& group_indices)
1111
rfx_predictions_.resize(num_observations_, 0.);
1212
}
1313

14+
nlohmann::json LabelMapper::to_json() {
15+
json output_obj;
16+
// Initialize a map with names of the node vectors and empty json arrays
17+
std::map<std::string, json> label_map_arrays;
18+
label_map_arrays.emplace(std::pair("keys", json::array()));
19+
label_map_arrays.emplace(std::pair("values", json::array()));
20+
for (const auto& [key, value] : label_map_) {
21+
label_map_arrays["keys"].emplace_back(key);
22+
label_map_arrays["values"].emplace_back(value);
23+
}
24+
for (auto& pair : label_map_arrays) {
25+
output_obj.emplace(pair);
26+
}
27+
return output_obj;
28+
}
29+
30+
void LabelMapper::from_json(const nlohmann::json& rfx_label_mapper_json) {
31+
int num_keys = rfx_label_mapper_json.at("keys").size();
32+
int num_values = rfx_label_mapper_json.at("values").size();
33+
CHECK_EQ(num_keys, num_values);
34+
for (int i = 0; i < num_keys; i++) {
35+
keys_.push_back(rfx_label_mapper_json.at("keys").at(i));
36+
label_map_.insert({rfx_label_mapper_json.at("keys").at(i), rfx_label_mapper_json.at("values").at(i)});
37+
}
38+
}
39+
1440
void MultivariateRegressionRandomEffectsModel::SampleRandomEffects(RandomEffectsDataset& dataset, ColumnVector& residual, RandomEffectsTracker& rfx_tracker,
1541
double global_variance, std::mt19937& gen) {
1642
// Update partial residual to add back in the random effects

0 commit comments

Comments
 (0)