@@ -162,18 +162,28 @@ static double ComputeMeanOutcome(ColumnVector& residual) {
162
162
return total_outcome / static_cast <double >(n);
163
163
}
164
164
165
- static void UpdateResidualTree (ForestTracker& tracker, ForestDataset& dataset, ColumnVector& residual, Tree* tree, int tree_num, bool requires_basis, std::function<double (double , double )> op) {
165
+ static void UpdateResidualTree (ForestTracker& tracker, ForestDataset& dataset, ColumnVector& residual, Tree* tree, int tree_num, bool requires_basis, std::function<double (double , double )> op, bool tree_new ) {
166
166
data_size_t n = dataset.GetCovariates ().rows ();
167
167
double pred_value;
168
168
int32_t leaf_pred;
169
169
double new_resid;
170
170
for (data_size_t i = 0 ; i < n; i++) {
171
- leaf_pred = tracker.GetNodeId (i, tree_num);
172
- if (requires_basis) {
173
- pred_value = tree->PredictFromNode (leaf_pred, dataset.GetBasis (), i);
171
+ if (tree_new) {
172
+ // If the tree has been newly sampled or adjusted, we must rerun the prediction
173
+ // method and update the SamplePredMapper stored in tracker
174
+ leaf_pred = tracker.GetNodeId (i, tree_num);
175
+ if (requires_basis) {
176
+ pred_value = tree->PredictFromNode (leaf_pred, dataset.GetBasis (), i);
177
+ } else {
178
+ pred_value = tree->PredictFromNode (leaf_pred);
179
+ }
180
+ tracker.SetTreeSamplePrediction (i, tree_num, pred_value);
174
181
} else {
175
- pred_value = tree->PredictFromNode (leaf_pred);
182
+ // If the tree has not yet been modified via a sampling step,
183
+ // we can query its prediction directly from the SamplePredMapper stored in tracker
184
+ pred_value = tracker.GetTreeSamplePrediction (i, tree_num);
176
185
}
186
+ // Run op (either plus or minus) on the residual and the new prediction
177
187
new_resid = op (residual.GetElement (i), pred_value);
178
188
residual.SetElement (i, new_resid);
179
189
}
@@ -210,7 +220,7 @@ class MCMCForestSampler {
210
220
for (int i = 0 ; i < num_trees; i++) {
211
221
// Add tree i's predictions back to the residual (thus, training a model on the "partial residual")
212
222
tree = ensemble->GetTree (i);
213
- UpdateResidualTree (tracker, dataset, residual, tree, i, leaf_model.RequiresBasis (), plus_op_);
223
+ UpdateResidualTree (tracker, dataset, residual, tree, i, leaf_model.RequiresBasis (), plus_op_, false );
214
224
215
225
// Sample tree i
216
226
tree = ensemble->GetTree (i);
@@ -222,7 +232,7 @@ class MCMCForestSampler {
222
232
223
233
// Subtract tree i's predictions back out of the residual
224
234
tree = ensemble->GetTree (i);
225
- UpdateResidualTree (tracker, dataset, residual, tree, i, leaf_model.RequiresBasis (), minus_op_);
235
+ UpdateResidualTree (tracker, dataset, residual, tree, i, leaf_model.RequiresBasis (), minus_op_, true );
226
236
}
227
237
}
228
238
@@ -477,7 +487,7 @@ class GFRForestSampler {
477
487
for (int i = 0 ; i < num_trees; i++) {
478
488
// Add tree i's predictions back to the residual (thus, training a model on the "partial residual")
479
489
Tree* tree = ensemble->GetTree (i);
480
- UpdateResidualTree (tracker, dataset, residual, tree, i, leaf_model.RequiresBasis (), plus_op_);
490
+ UpdateResidualTree (tracker, dataset, residual, tree, i, leaf_model.RequiresBasis (), plus_op_, false );
481
491
482
492
// Reset the tree and sample trackers
483
493
ensemble->ResetInitTree (i);
@@ -492,7 +502,7 @@ class GFRForestSampler {
492
502
leaf_model.SampleLeafParameters (dataset, tracker, residual, tree, i, global_variance, gen);
493
503
494
504
// Subtract tree i's predictions back out of the residual
495
- UpdateResidualTree (tracker, dataset, residual, tree, i, leaf_model.RequiresBasis (), minus_op_);
505
+ UpdateResidualTree (tracker, dataset, residual, tree, i, leaf_model.RequiresBasis (), minus_op_, true );
496
506
}
497
507
}
498
508
0 commit comments