Skip to content

Commit 990c0a7

Browse files
authored
Merge pull request #34 from StochasticTree/prediction_cache
Added tree prediction caching to the ForestTracker
2 parents fc67100 + 0755fd0 commit 990c0a7

File tree

3 files changed

+87
-9
lines changed

3 files changed

+87
-9
lines changed

include/stochtree/partition_tracker.h

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ namespace StochTree {
4141

4242
/*! \brief Forward declarations of component classes */
4343
class SampleNodeMapper;
44+
class SamplePredMapper;
4445
class UnsortedNodeSampleTracker;
4546
class SortedNodeSampleTracker;
4647
class FeaturePresortRootContainer;
@@ -52,9 +53,13 @@ class ForestTracker {
5253
~ForestTracker() {}
5354
void AssignAllSamplesToRoot();
5455
void AssignAllSamplesToRoot(int32_t tree_num);
56+
void AssignAllSamplesToConstantPrediction(double value);
57+
void AssignAllSamplesToConstantPrediction(int32_t tree_num, double value);
5558
void ResetRoot(Eigen::MatrixXd& covariates, std::vector<FeatureType>& feature_types, int32_t tree_num);
5659
void AddSplit(Eigen::MatrixXd& covariates, TreeSplit& split, int32_t split_feature, int32_t tree_id, int32_t split_node_id, int32_t left_node_id, int32_t right_node_id, bool keep_sorted = false);
5760
void RemoveSplit(Eigen::MatrixXd& covariates, Tree* tree, int32_t tree_id, int32_t split_node_id, int32_t left_node_id, int32_t right_node_id, bool keep_sorted = false);
61+
double GetTreeSamplePrediction(data_size_t sample_id, int tree_id);
62+
void SetTreeSamplePrediction(data_size_t sample_id, int tree_id, double value);
5863
data_size_t GetNodeId(int observation_num, int tree_num);
5964
data_size_t UnsortedNodeBegin(int tree_id, int node_id);
6065
data_size_t UnsortedNodeEnd(int tree_id, int node_id);
@@ -66,11 +71,14 @@ class ForestTracker {
6671
std::vector<data_size_t>::iterator UnsortedNodeEndIterator(int tree_id, int node_id);
6772
std::vector<data_size_t>::iterator SortedNodeBeginIterator(int node_id, int feature_id);
6873
std::vector<data_size_t>::iterator SortedNodeEndIterator(int node_id, int feature_id);
74+
SamplePredMapper* GetSamplePredMapper() {return sample_pred_mapper_.get();}
6975
SampleNodeMapper* GetSampleNodeMapper() {return sample_node_mapper_.get();}
7076
UnsortedNodeSampleTracker* GetUnsortedNodeSampleTracker() {return unsorted_node_sample_tracker_.get();}
7177
SortedNodeSampleTracker* GetSortedNodeSampleTracker() {return sorted_node_sample_tracker_.get();}
7278

7379
private:
80+
/*! \brief Mapper from observations to predicted values for every tree in a forest */
81+
std::unique_ptr<SamplePredMapper> sample_pred_mapper_;
7482
/*! \brief Mapper from observations to leaf node indices for every tree in a forest */
7583
std::unique_ptr<SampleNodeMapper> sample_node_mapper_;
7684
/*! \brief Data structure tracking / updating observations available in each node for every tree in a forest
@@ -88,6 +96,47 @@ class ForestTracker {
8896
int num_features_;
8997
};
9098

99+
/*! \brief Class storing sample-prediction map for each tree in an ensemble */
100+
class SamplePredMapper {
101+
public:
102+
SamplePredMapper(int num_trees, data_size_t num_observations) {
103+
num_trees_ = num_trees;
104+
num_observations_ = num_observations;
105+
// Initialize the vector of vectors of leaf indices for each tree
106+
tree_preds_.resize(num_trees_);
107+
for (int j = 0; j < num_trees_; j++) {
108+
tree_preds_[j].resize(num_observations_);
109+
}
110+
}
111+
112+
inline double GetPred(data_size_t sample_id, int tree_id) {
113+
CHECK_LT(sample_id, num_observations_);
114+
CHECK_LT(tree_id, num_trees_);
115+
return tree_preds_[tree_id][sample_id];
116+
}
117+
118+
inline void SetPred(data_size_t sample_id, int tree_id, double value) {
119+
CHECK_LT(sample_id, num_observations_);
120+
CHECK_LT(tree_id, num_trees_);
121+
tree_preds_[tree_id][sample_id] = value;
122+
}
123+
124+
inline int NumTrees() {return num_trees_;}
125+
126+
inline int NumObservations() {return num_observations_;}
127+
128+
inline void AssignAllSamplesToConstantPrediction(int tree_id, double value) {
129+
for (data_size_t i = 0; i < num_observations_; i++) {
130+
tree_preds_[tree_id][i] = value;
131+
}
132+
}
133+
134+
private:
135+
std::vector<std::vector<double>> tree_preds_;
136+
int num_trees_;
137+
data_size_t num_observations_;
138+
};
139+
91140
/*! \brief Class storing sample-node map for each tree in an ensemble */
92141
class SampleNodeMapper {
93142
public:

include/stochtree/tree_sampler.h

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -162,18 +162,28 @@ static double ComputeMeanOutcome(ColumnVector& residual) {
162162
return total_outcome / static_cast<double>(n);
163163
}
164164

165-
static void UpdateResidualTree(ForestTracker& tracker, ForestDataset& dataset, ColumnVector& residual, Tree* tree, int tree_num, bool requires_basis, std::function<double(double, double)> op) {
165+
static void UpdateResidualTree(ForestTracker& tracker, ForestDataset& dataset, ColumnVector& residual, Tree* tree, int tree_num, bool requires_basis, std::function<double(double, double)> op, bool tree_new) {
166166
data_size_t n = dataset.GetCovariates().rows();
167167
double pred_value;
168168
int32_t leaf_pred;
169169
double new_resid;
170170
for (data_size_t i = 0; i < n; i++) {
171-
leaf_pred = tracker.GetNodeId(i, tree_num);
172-
if (requires_basis) {
173-
pred_value = tree->PredictFromNode(leaf_pred, dataset.GetBasis(), i);
171+
if (tree_new) {
172+
// If the tree has been newly sampled or adjusted, we must rerun the prediction
173+
// method and update the SamplePredMapper stored in tracker
174+
leaf_pred = tracker.GetNodeId(i, tree_num);
175+
if (requires_basis) {
176+
pred_value = tree->PredictFromNode(leaf_pred, dataset.GetBasis(), i);
177+
} else {
178+
pred_value = tree->PredictFromNode(leaf_pred);
179+
}
180+
tracker.SetTreeSamplePrediction(i, tree_num, pred_value);
174181
} else {
175-
pred_value = tree->PredictFromNode(leaf_pred);
182+
// If the tree has not yet been modified via a sampling step,
183+
// we can query its prediction directly from the SamplePredMapper stored in tracker
184+
pred_value = tracker.GetTreeSamplePrediction(i, tree_num);
176185
}
186+
// Run op (either plus or minus) on the residual and the new prediction
177187
new_resid = op(residual.GetElement(i), pred_value);
178188
residual.SetElement(i, new_resid);
179189
}
@@ -210,7 +220,7 @@ class MCMCForestSampler {
210220
for (int i = 0; i < num_trees; i++) {
211221
// Add tree i's predictions back to the residual (thus, training a model on the "partial residual")
212222
tree = ensemble->GetTree(i);
213-
UpdateResidualTree(tracker, dataset, residual, tree, i, leaf_model.RequiresBasis(), plus_op_);
223+
UpdateResidualTree(tracker, dataset, residual, tree, i, leaf_model.RequiresBasis(), plus_op_, false);
214224

215225
// Sample tree i
216226
tree = ensemble->GetTree(i);
@@ -222,7 +232,7 @@ class MCMCForestSampler {
222232

223233
// Subtract tree i's predictions back out of the residual
224234
tree = ensemble->GetTree(i);
225-
UpdateResidualTree(tracker, dataset, residual, tree, i, leaf_model.RequiresBasis(), minus_op_);
235+
UpdateResidualTree(tracker, dataset, residual, tree, i, leaf_model.RequiresBasis(), minus_op_, true);
226236
}
227237
}
228238

@@ -477,7 +487,7 @@ class GFRForestSampler {
477487
for (int i = 0; i < num_trees; i++) {
478488
// Add tree i's predictions back to the residual (thus, training a model on the "partial residual")
479489
Tree* tree = ensemble->GetTree(i);
480-
UpdateResidualTree(tracker, dataset, residual, tree, i, leaf_model.RequiresBasis(), plus_op_);
490+
UpdateResidualTree(tracker, dataset, residual, tree, i, leaf_model.RequiresBasis(), plus_op_, false);
481491

482492
// Reset the tree and sample trackers
483493
ensemble->ResetInitTree(i);
@@ -492,7 +502,7 @@ class GFRForestSampler {
492502
leaf_model.SampleLeafParameters(dataset, tracker, residual, tree, i, global_variance, gen);
493503

494504
// Subtract tree i's predictions back out of the residual
495-
UpdateResidualTree(tracker, dataset, residual, tree, i, leaf_model.RequiresBasis(), minus_op_);
505+
UpdateResidualTree(tracker, dataset, residual, tree, i, leaf_model.RequiresBasis(), minus_op_, true);
496506
}
497507
}
498508

src/partition_tracker.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
namespace StochTree {
1616

1717
ForestTracker::ForestTracker(Eigen::MatrixXd& covariates, std::vector<FeatureType>& feature_types, int num_trees, int num_observations) {
18+
sample_pred_mapper_ = std::make_unique<SamplePredMapper>(num_trees, num_observations);
1819
sample_node_mapper_ = std::make_unique<SampleNodeMapper>(num_trees, num_observations);
1920
unsorted_node_sample_tracker_ = std::make_unique<UnsortedNodeSampleTracker>(num_observations, num_trees);
2021
presort_container_ = std::make_unique<FeaturePresortRootContainer>(covariates, feature_types);
@@ -69,6 +70,16 @@ void ForestTracker::AssignAllSamplesToRoot(int32_t tree_num) {
6970
sample_node_mapper_->AssignAllSamplesToRoot(tree_num);
7071
}
7172

73+
void ForestTracker::AssignAllSamplesToConstantPrediction(double value) {
74+
for (int i = 0; i < num_trees_; i++) {
75+
sample_pred_mapper_->AssignAllSamplesToConstantPrediction(i, value);
76+
}
77+
}
78+
79+
void ForestTracker::AssignAllSamplesToConstantPrediction(int32_t tree_num, double value) {
80+
sample_pred_mapper_->AssignAllSamplesToConstantPrediction(tree_num, value);
81+
}
82+
7283
void ForestTracker::AddSplit(Eigen::MatrixXd& covariates, TreeSplit& split, int32_t split_feature, int32_t tree_id, int32_t split_node_id, int32_t left_node_id, int32_t right_node_id, bool keep_sorted) {
7384
sample_node_mapper_->AddSplit(covariates, split, split_feature, tree_id, split_node_id, left_node_id, right_node_id);
7485
unsorted_node_sample_tracker_->PartitionTreeNode(covariates, tree_id, split_node_id, left_node_id, right_node_id, split_feature, split);
@@ -83,6 +94,14 @@ void ForestTracker::RemoveSplit(Eigen::MatrixXd& covariates, Tree* tree, int32_t
8394
// TODO: WARN if this is called from the GFR Tree Sampler
8495
}
8596

97+
/*! \brief Look up the cached prediction of tree `tree_id` for sample `sample_id` */
double ForestTracker::GetTreeSamplePrediction(data_size_t sample_id, int tree_id) {
  // The value comes straight from the prediction mapper; nothing is recomputed here
  const double cached_pred = sample_pred_mapper_->GetPred(sample_id, tree_id);
  return cached_pred;
}
100+
101+
/*! \brief Store `value` as the cached prediction of tree `tree_id` for sample `sample_id` */
void ForestTracker::SetTreeSamplePrediction(data_size_t sample_id, int tree_id, double value) {
  // Pure delegation: the prediction mapper owns the per-tree cache
  SamplePredMapper& pred_cache = *sample_pred_mapper_;
  pred_cache.SetPred(sample_id, tree_id, value);
}
104+
86105
FeatureUnsortedPartition::FeatureUnsortedPartition(data_size_t n) {
87106
indices_.resize(n);
88107
std::iota(indices_.begin(), indices_.end(), 0);

0 commit comments

Comments
 (0)