Skip to content

Commit 95421bc

Browse files
authored
Merge pull request #35 from StochasticTree/expand_initialization_api
Expand initialization api
2 parents 990c0a7 + 36fffa4 commit 95421bc

File tree

7 files changed

+113
-33
lines changed

7 files changed

+113
-33
lines changed

include/stochtree/container.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ class ForestContainer {
2626
ForestContainer(int num_samples, int num_trees, int output_dimension = 1, bool is_leaf_constant = true);
2727
~ForestContainer() {}
2828

29+
void InitializeRoot(double leaf_value);
30+
void InitializeRoot(std::vector<double>& leaf_vector);
2931
void AddSamples(int num_samples);
3032
void CopyFromPreviousSample(int new_sample_id, int previous_sample_id);
3133
std::vector<double> Predict(ForestDataset& dataset);
@@ -39,8 +41,13 @@ class ForestContainer {
3941
inline int32_t NumLeaves(int ensemble_num) {return forests_[ensemble_num]->NumLeaves();}
4042
inline int32_t OutputDimension() {return output_dimension_;}
4143
inline int32_t OutputDimension(int ensemble_num) {return forests_[ensemble_num]->OutputDimension();}
44+
inline bool IsLeafConstant() {return is_leaf_constant_;}
4245
inline bool IsLeafConstant(int ensemble_num) {return forests_[ensemble_num]->IsLeafConstant();}
43-
46+
inline bool AllRoots(int ensemble_num) {return forests_[ensemble_num]->AllRoots();}
47+
inline void SetLeafValue(int ensemble_num, double leaf_value) {forests_[ensemble_num]->SetLeafValue(leaf_value);}
48+
inline void SetLeafVector(int ensemble_num, std::vector<double>& leaf_vector) {forests_[ensemble_num]->SetLeafVector(leaf_vector);}
49+
inline void IncrementSampleCount() {num_samples_++;}
50+
4451
void SaveToJsonFile(std::string filename) {
4552
nlohmann::json model_json = this->to_json();
4653
std::ofstream output_file(filename);

include/stochtree/ensemble.h

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,31 @@ class TreeEnsemble {
195195
return is_leaf_constant_;
196196
}
197197

198+
inline bool AllRoots() {
199+
for (int i = 0; i < num_trees_; i++) {
200+
if (!trees_[i]->IsRoot()) {
201+
return false;
202+
}
203+
}
204+
return true;
205+
}
206+
207+
inline void SetLeafValue(double leaf_value) {
208+
CHECK_EQ(output_dimension_, 1);
209+
for (int i = 0; i < num_trees_; i++) {
210+
CHECK(trees_[i]->IsRoot());
211+
trees_[i]->SetLeaf(0, leaf_value);
212+
}
213+
}
214+
215+
inline void SetLeafVector(std::vector<double>& leaf_vector) {
216+
CHECK_EQ(output_dimension_, leaf_vector.size());
217+
for (int i = 0; i < num_trees_; i++) {
218+
CHECK(trees_[i]->IsRoot());
219+
trees_[i]->SetLeafVector(0, leaf_vector);
220+
}
221+
}
222+
198223
/*! \brief Save to JSON */
199224
json to_json() {
200225
json result_obj;

include/stochtree/tree.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,9 @@ class Tree {
9595
void ExpandNode(std::int32_t nid, int split_index, TreeSplit& split, double left_value, double right_value);
9696
/*! \brief Expand a node based on a generic split rule */
9797
void ExpandNode(std::int32_t nid, int split_index, TreeSplit& split, std::vector<double> left_value_vector, std::vector<double> right_value_vector);
98+
99+
/*! \brief Whether or not a tree is a "stump" consisting of a single root node */
100+
inline bool IsRoot() {return leaves_.size() == 1;}
98101

99102
/*! \brief Save to JSON */
100103
json to_json();

include/stochtree/tree_sampler.h

Lines changed: 48 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,31 @@ static double ComputeMeanOutcome(ColumnVector& residual) {
162162
return total_outcome / static_cast<double>(n);
163163
}
164164

165+
/*!
 * \brief Adjust every residual element by the prediction of an entire forest.
 *
 * For each observation i, the prediction of each tree j is evaluated at the node that
 * `tracker` currently assigns to (i, j), stored back into `tracker` as that tree's
 * prediction for the observation, and summed across trees. The residual element is
 * then replaced by `op(residual_i, forest_prediction_i)`.
 *
 * \param tracker        Supplies node ids per (observation, tree) and receives per-tree predictions
 * \param dataset        Supplies the covariates (row count) and, when `requires_basis` is true, the basis
 * \param residual       Residual vector updated in place
 * \param forest         Ensemble whose trees are evaluated
 * \param requires_basis Whether the basis-aware overload of Tree::PredictFromNode is used
 * \param op             Binary operation (either plus or minus) combining the old residual and the forest prediction
 */
static void UpdateResidualEntireForest(ForestTracker& tracker, ForestDataset& dataset, ColumnVector& residual, TreeEnsemble* forest, bool requires_basis, std::function<double(double, double)> op) {
  data_size_t n = dataset.GetCovariates().rows();
  // The tree count does not change while residuals are being updated
  const auto num_trees = forest->NumTrees();
  for (data_size_t i = 0; i < n; i++) {
    // Forest prediction for observation i only: it must restart at zero for every
    // observation, otherwise each residual absorbs the predictions of all earlier rows
    double pred_value = 0.;
    for (int j = 0; j < num_trees; j++) {
      Tree* tree = forest->GetTree(j);
      int32_t leaf_pred = tracker.GetNodeId(i, j);
      // Prediction of tree j alone: assigned rather than accumulated, so the value
      // cached in the tracker is this tree's contribution and not a running sum
      double tree_pred;
      if (requires_basis) {
        tree_pred = tree->PredictFromNode(leaf_pred, dataset.GetBasis(), i);
      } else {
        tree_pred = tree->PredictFromNode(leaf_pred);
      }
      tracker.SetTreeSamplePrediction(i, j, tree_pred);
      pred_value += tree_pred;
    }

    // Run op (either plus or minus) on the residual and the new prediction
    double new_resid = op(residual.GetElement(i), pred_value);
    residual.SetElement(i, new_resid);
  }
}
189+
165190
static void UpdateResidualTree(ForestTracker& tracker, ForestDataset& dataset, ColumnVector& residual, Tree* tree, int tree_num, bool requires_basis, std::function<double(double, double)> op, bool tree_new) {
166191
data_size_t n = dataset.GetCovariates().rows();
167192
double pred_value;
@@ -196,21 +221,27 @@ class MCMCForestSampler {
196221
~MCMCForestSampler() {}
197222

198223
void SampleOneIter(ForestTracker& tracker, ForestContainer& forests, LeafModel& leaf_model, ForestDataset& dataset,
199-
ColumnVector& residual, TreePrior& tree_prior, std::mt19937& gen, std::vector<double> variable_weights, double global_variance) {
200-
224+
ColumnVector& residual, TreePrior& tree_prior, std::mt19937& gen, std::vector<double> variable_weights,
225+
double global_variance, bool pre_initialized = false) {
201226
// Previous number of samples
202227
int prev_num_samples = forests.NumSamples();
203-
// Add new forest to the container
204-
forests.AddSamples(1);
205228

206-
if (prev_num_samples == 0) {
229+
if ((prev_num_samples == 0) && (!pre_initialized)) {
230+
// Add new forest to the container
231+
forests.AddSamples(1);
232+
207233
// Set initial value for each leaf in the forest
208234
double root_pred = ComputeMeanOutcome(residual) / static_cast<double>(forests.NumTrees());
209235
TreeEnsemble* ensemble = forests.GetEnsemble(0);
210236
leaf_model.SetEnsembleRootPredictedValue(dataset, ensemble, root_pred);
211-
} else {
237+
} else if (prev_num_samples > 0) {
238+
// Add new forest to the container
239+
forests.AddSamples(1);
240+
212241
// Copy previous forest
213242
forests.CopyFromPreviousSample(prev_num_samples, prev_num_samples - 1);
243+
} else {
244+
forests.IncrementSampleCount();
214245
}
215246

216247
// Run the MCMC algorithm for each tree
@@ -462,23 +493,29 @@ class GFRForestSampler {
462493

463494
void SampleOneIter(ForestTracker& tracker, ForestContainer& forests, LeafModel& leaf_model, ForestDataset& dataset,
464495
ColumnVector& residual, TreePrior& tree_prior, std::mt19937& gen, std::vector<double> variable_weights,
465-
double global_variance, std::vector<FeatureType>& feature_types) {
496+
double global_variance, std::vector<FeatureType>& feature_types, bool pre_initialized = false) {
466497
// Previous number of samples
467498
int prev_num_samples = forests.NumSamples();
468-
// Add new forest to the container
469-
forests.AddSamples(1);
470499

471-
if (prev_num_samples == 0) {
500+
if ((prev_num_samples == 0) && (!pre_initialized)) {
501+
// Add new forest to the container
502+
forests.AddSamples(1);
503+
472504
// Set initial value for each leaf in the forest
473505
double root_pred = ComputeMeanOutcome(residual) / static_cast<double>(forests.NumTrees());
474506
TreeEnsemble* ensemble = forests.GetEnsemble(0);
475507
leaf_model.SetEnsembleRootPredictedValue(dataset, ensemble, root_pred);
476-
} else {
508+
} else if (prev_num_samples > 0) {
509+
// Add new forest to the container
510+
forests.AddSamples(1);
511+
477512
// NOTE: only doing this for the simplicity of the partial residual step
478513
// We could alternatively "reach back" to the tree predictions from a previous
479514
// sample (whenever there is more than one sample). This is cleaner / quicker
480515
// to implement during this refactor.
481516
forests.CopyFromPreviousSample(prev_num_samples, prev_num_samples - 1);
517+
} else {
518+
forests.IncrementSampleCount();
482519
}
483520

484521
// Run the GFR algorithm for each tree

src/container.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,37 @@ ForestContainer::ForestContainer(int num_samples, int num_trees, int output_dime
2323
num_trees_ = num_trees;
2424
output_dimension_ = output_dimension;
2525
is_leaf_constant_ = is_leaf_constant;
26+
initialized_ = true;
2627
}
2728

2829
void ForestContainer::CopyFromPreviousSample(int new_sample_id, int previous_sample_id) {
2930
forests_[new_sample_id].reset(new TreeEnsemble(*forests_[previous_sample_id]));
3031
}
3132

33+
void ForestContainer::InitializeRoot(double leaf_value) {
34+
CHECK(initialized_);
35+
CHECK_EQ(num_samples_, 0);
36+
CHECK_EQ(forests_.size(), 0);
37+
forests_.resize(1);
38+
forests_[0].reset(new TreeEnsemble(num_trees_, output_dimension_, is_leaf_constant_));
39+
// NOTE: not setting num_samples = 1, since we are just initializing constant root
40+
// nodes and the forest still needs to be sampled by either MCMC or GFR
41+
num_samples_ = 0;
42+
SetLeafValue(0, leaf_value);
43+
}
44+
45+
void ForestContainer::InitializeRoot(std::vector<double>& leaf_vector) {
46+
CHECK(initialized_);
47+
CHECK_EQ(num_samples_, 0);
48+
CHECK_EQ(forests_.size(), 0);
49+
forests_.resize(1);
50+
forests_[0].reset(new TreeEnsemble(num_trees_, output_dimension_, is_leaf_constant_));
51+
// NOTE: not setting num_samples = 1, since we are just initializing constant root
52+
// nodes and the forest still needs to be sampled by either MCMC or GFR
53+
num_samples_ = 0;
54+
SetLeafVector(0, leaf_vector);
55+
}
56+
3257
void ForestContainer::AddSamples(int num_samples) {
3358
CHECK(initialized_);
3459
int total_new_samples = num_samples + num_samples_;

src/leaf_model.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -691,9 +691,10 @@ void GaussianMultivariateRegressionLeafModel::SetEnsembleRootPredictedValue(Fore
691691
Log::Fatal("For multivariate leaf regression, outcomes should be centered / scaled so that the root coefficients can be initialized to 0");
692692
}
693693

694+
std::vector<double> root_pred_vector(ensemble->OutputDimension(), root_pred_value);
694695
for (int i = 0; i < num_trees; i++) {
695696
Tree* tree = ensemble->GetTree(i);
696-
tree->SetLeaf(0, root_pred_value);
697+
tree->SetLeafVector(0, root_pred_vector);
697698
}
698699
}
699700

src/tree.cpp

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -19,29 +19,11 @@ constexpr std::int32_t Tree::kDeletedNodeMarker;
1919
constexpr std::int32_t Tree::kRoot;
2020

2121
std::int32_t Tree::NumLeaves() const {
22-
std::int32_t leaves { 0 };
23-
auto const& self = *this;
24-
this->WalkTree([&leaves, &self](std::int32_t nidx) {
25-
if (self.IsLeaf(nidx)) {
26-
leaves++;
27-
}
28-
return true;
29-
});
30-
return leaves;
22+
return leaves_.size();
3123
}
3224

3325
std::int32_t Tree::NumLeafParents() const {
34-
std::int32_t leaf_parents { 0 };
35-
auto const& self = *this;
36-
this->WalkTree([&leaf_parents, &self](std::int32_t nidx) {
37-
if (!self.IsLeaf(nidx)){
38-
if ((self.IsLeaf(self.LeftChild(nidx))) && (self.IsLeaf(self.RightChild(nidx)))){
39-
leaf_parents++;
40-
}
41-
}
42-
return true;
43-
});
44-
return leaf_parents;
26+
return leaf_parents_.size();
4527
}
4628

4729
std::int32_t Tree::NumSplitNodes() const {

0 commit comments

Comments
 (0)