Skip to content

Commit 560922a

Browse files
committed
Added convenience method to random effects and category tracker classes
1 parent 06af0f9 commit 560922a

File tree

2 files changed

+8
-0
lines changed

2 files changed

+8
-0
lines changed

include/stochtree/category_tracker.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ class CategorySampleTracker {
9696
if (i > 0) new_group_cond = group_indices[indices_[i]] != group_indices[indices_[i-1]];
9797
if (start_cond || new_group_cond) {
9898
category_id_map_.insert({group_indices[indices_[i]], category_count_});
99+
unique_category_ids_.push_back(group_indices[indices_[i]]);
99100
node_index_vector_.emplace_back();
100101
if (i == 0) {
101102
category_begin_.push_back(i);
@@ -163,11 +164,14 @@ class CategorySampleTracker {
163164
/*! \brief Returns label index map */
164165
std::map<int32_t, int32_t>& GetLabelMap() {return category_id_map_;}
165166

167+
std::vector<int32_t>& GetUniqueGroupIds() {return unique_category_ids_;}
168+
166169
private:
167170
// Vectors tracking indices in each node
168171
std::vector<data_size_t> category_begin_;
169172
std::vector<data_size_t> category_length_;
170173
std::map<int32_t, int32_t> category_id_map_;
174+
std::vector<int32_t> unique_category_ids_;
171175
std::vector<std::vector<data_size_t>> node_index_vector_;
172176
int32_t category_count_;
173177
};

include/stochtree/random_effects.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ class RandomEffectsTracker {
4343
std::vector<data_size_t>::iterator UnsortedNodeBeginIterator(int category_id);
4444
std::vector<data_size_t>::iterator UnsortedNodeEndIterator(int category_id);
4545
std::map<int32_t, int32_t>& GetLabelMap() {return category_sample_tracker_->GetLabelMap();}
46+
std::vector<int32_t>& GetUniqueGroupIds() {return category_sample_tracker_->GetUniqueGroupIds();}
4647
std::vector<data_size_t>& NodeIndices(int category_id) {return category_sample_tracker_->NodeIndices(category_id);}
4748
std::vector<data_size_t>& NodeIndicesInternalIndex(int internal_category_id) {return category_sample_tracker_->NodeIndicesInternalIndex(internal_category_id);}
4849
double GetPrediction(data_size_t observation_num) {return rfx_predictions_.at(observation_num);}
@@ -248,6 +249,9 @@ class RandomEffectsContainer {
248249
~RandomEffectsContainer() {}
249250
void AddSample(MultivariateRegressionRandomEffectsModel& model);
250251
void Predict(RandomEffectsDataset& dataset, LabelMapper& label_mapper, std::vector<double>& output);
252+
int NumSamples() {return num_samples_;}
253+
int NumComponents() {return num_components_;}
254+
int NumGroups() {return num_groups_;}
251255
std::vector<double> GetBeta() {return beta_;}
252256
std::vector<double> GetAlpha() {return alpha_;}
253257
std::vector<double> GetXi() {return xi_;}

0 commit comments

Comments
 (0)