File tree Expand file tree Collapse file tree 2 files changed +8
-0
lines changed Expand file tree Collapse file tree 2 files changed +8
-0
lines changed Original file line number Diff line number Diff line change @@ -96,6 +96,7 @@ class CategorySampleTracker {
96
96
if (i > 0 ) new_group_cond = group_indices[indices_[i]] != group_indices[indices_[i-1 ]];
97
97
if (start_cond || new_group_cond) {
98
98
category_id_map_.insert ({group_indices[indices_[i]], category_count_});
99
+ unique_category_ids_.push_back (group_indices[indices_[i]]);
99
100
node_index_vector_.emplace_back ();
100
101
if (i == 0 ) {
101
102
category_begin_.push_back (i);
@@ -163,11 +164,14 @@ class CategorySampleTracker {
163
164
/* ! \brief Returns label index map */
164
165
std::map<int32_t , int32_t >& GetLabelMap () {return category_id_map_;}
165
166
167
+ std::vector<int32_t >& GetUniqueGroupIds () {return unique_category_ids_;}
168
+
166
169
private:
167
170
// Vectors tracking indices in each node
168
171
std::vector<data_size_t > category_begin_;
169
172
std::vector<data_size_t > category_length_;
170
173
std::map<int32_t , int32_t > category_id_map_;
174
+ std::vector<int32_t > unique_category_ids_;
171
175
std::vector<std::vector<data_size_t >> node_index_vector_;
172
176
int32_t category_count_;
173
177
};
Original file line number Diff line number Diff line change @@ -43,6 +43,7 @@ class RandomEffectsTracker {
43
43
std::vector<data_size_t >::iterator UnsortedNodeBeginIterator (int category_id);
44
44
std::vector<data_size_t >::iterator UnsortedNodeEndIterator (int category_id);
45
45
std::map<int32_t , int32_t >& GetLabelMap () {return category_sample_tracker_->GetLabelMap ();}
46
+ std::vector<int32_t >& GetUniqueGroupIds () {return category_sample_tracker_->GetUniqueGroupIds ();}
46
47
std::vector<data_size_t >& NodeIndices (int category_id) {return category_sample_tracker_->NodeIndices (category_id);}
47
48
std::vector<data_size_t >& NodeIndicesInternalIndex (int internal_category_id) {return category_sample_tracker_->NodeIndicesInternalIndex (internal_category_id);}
48
49
double GetPrediction (data_size_t observation_num) {return rfx_predictions_.at (observation_num);}
@@ -248,6 +249,9 @@ class RandomEffectsContainer {
248
249
~RandomEffectsContainer () {}
249
250
void AddSample (MultivariateRegressionRandomEffectsModel& model);
250
251
void Predict (RandomEffectsDataset& dataset, LabelMapper& label_mapper, std::vector<double >& output);
252
+ int NumSamples () {return num_samples_;}
253
+ int NumComponents () {return num_components_;}
254
+ int NumGroups () {return num_groups_;}
251
255
std::vector<double > GetBeta () {return beta_;}
252
256
std::vector<double > GetAlpha () {return alpha_;}
253
257
std::vector<double > GetXi () {return xi_;}
You can’t perform that action at this time.
0 commit comments