Skip to content

Commit 0642930

Browse files
committed
Initial working version of leaf index sparse matrix representation
1 parent 95421bc commit 0642930

File tree

4 files changed

+96
-0
lines changed

4 files changed

+96
-0
lines changed

include/stochtree/ensemble.h

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,45 @@ class TreeEnsemble {
220220
}
221221
}
222222

223+
/*!
224+
* \brief Obtain a 0-based leaf index for every tree in an ensemble and for each
225+
* observation in a ForestDataset. Internally, trees are stored as essentially
226+
* vectors of node information, and the leaves_ vector gives us node IDs for every
227+
* leaf in the tree. Here, we would like to know, for every observation in a dataset,
228+
* which leaf number it is mapped to. Since the leaf numbers themselves
229+
* do not carry any information, we renumber them from 0 to `leaves_.size()-1`.
230+
* We compute this at the tree-level and coordinate this computation at the
231+
* ensemble level.
232+
*
233+
* Note: this assumes the creation of a vector of column indices of size
234+
* `dataset.NumObservations()` x `ensemble.NumTrees()`
235+
* \param ForestDataset Dataset with which to predict leaf indices from the tree
236+
* \param output Vector of length num_trees*n which stores the leaf node prediction
237+
* \param num_trees Number of trees in an ensemble
238+
* \param n Size of dataset
239+
*/
240+
void PredictLeafIndicesInplace(ForestDataset* dataset, std::vector<int32_t>& output, int num_trees, data_size_t n) {
241+
CHECK_GE(output.size(), num_trees*n);
242+
int offset = 0;
243+
for (int j = 0; j < num_trees; j++) {
244+
auto &tree = *trees_[j];
245+
tree.PredictLeafIndexInplace(dataset, output, offset);
246+
offset += n;
247+
}
248+
}
249+
250+
/*!
251+
* \brief Same as `PredictLeafIndicesInplace` but assumes responsibility for allocating and returning output vector.
252+
* \param ForestDataset Dataset with which to predict leaf indices from the tree
253+
*/
254+
std::vector<int32_t> PredictLeafIndices(ForestDataset* dataset) {
255+
int num_trees = num_trees_;
256+
data_size_t n = dataset->NumObservations();
257+
std::vector<int32_t> output(n*num_trees);
258+
PredictLeafIndicesInplace(dataset, output, num_trees, n);
259+
return output;
260+
}
261+
223262
/*! \brief Save to JSON */
224263
json to_json() {
225264
json result_obj;

include/stochtree/tree.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#define STOCHTREE_TREE_H_
88

99
#include <nlohmann/json.hpp>
10+
#include <stochtree/data.h>
1011
#include <stochtree/log.h>
1112
#include <stochtree/meta.h>
1213
#include <Eigen/Dense>
@@ -586,6 +587,25 @@ class Tree {
586587
*/
587588
void SetLeafVector(std::int32_t nid, std::vector<double> const& leaf_vector);
588589

590+
/*!
591+
* \brief Obtain a 0-based leaf index for each observation in a ForestDataset.
592+
* Internally, trees are stored as essentially vectors of node information,
593+
* and the leaves_ vector gives us node IDs for every leaf in the tree.
594+
* Here, we would like to know, for every observation in a dataset,
595+
* which leaf number it is mapped to. Since the leaf numbers themselves
596+
* do not carry any information, we renumber them from 0 to `leaves_.size()-1`.
597+
*
598+
* Note: this is a tree-level helper function for an ensemble-level function.
599+
* It assumes the creation of:
600+
* (a) a vector of column indices of size `dataset.NumObservations()` x `ensemble.NumTrees()`, and
601+
* (b) a running counter of the number of tree-observations already indexed in the ensemble
602+
* (used as offsets for the leaf number computed and returned here)
603+
* Users running this function for a single tree may simply pre-allocate an output vector as
604+
* std::vector<int32_t> output(dataset->NumObservations()) and set the offset to 0.
605+
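* \param dataset Dataset with which to predict leaf indices from the tree
* \param output Pre-allocated vector that receives the leaf indices; must hold at least `offset + dataset->NumObservations()` entries
* \param offset Position in `output` at which the leaf index of this tree's first observation is written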
606+
*/
607+
void PredictLeafIndexInplace(ForestDataset* dataset, std::vector<int32_t>& output, int32_t offset);
608+
589609
// Node info
590610
std::vector<TreeNodeType> node_type_;
591611
std::vector<std::int32_t> parent_;

src/tree.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -420,6 +420,20 @@ void Tree::SetLeafVector(std::int32_t nid, std::vector<double> const& node_leaf_
420420
node_type_.at(nid) = TreeNodeType::kLeafNode;
421421
}
422422

423+
/*!
 * Write a 0-based leaf index for every observation in `dataset` into
 * `output[offset], ..., output[offset + n - 1]`, where `n = dataset->NumObservations()`.
 *
 * Node IDs returned by `EvaluateTree` are translated to their position in `leaves_`,
 * so the values written range over `0 .. leaves_.size()-1`.
 *
 * \param dataset Dataset whose covariates are routed through the tree
 * \param output Pre-allocated vector with at least `offset + n` entries
 * \param offset Non-negative position in `output` of the first value written
 */
void Tree::PredictLeafIndexInplace(ForestDataset* dataset, std::vector<int32_t>& output, int32_t offset) {
  const int n = dataset->NumObservations();
  // Validate in unsigned arithmetic: avoids a signed/unsigned comparison and int overflow in offset + n
  CHECK_GE(offset, 0);
  CHECK_GE(n, 0);
  CHECK_GE(output.size(), static_cast<std::size_t>(offset) + static_cast<std::size_t>(n));

  // Map each leaf's node ID to its position in leaves_ (the 0-based leaf number)
  std::map<int32_t, int32_t> renumber_map;
  for (std::size_t leaf_pos = 0; leaf_pos < leaves_.size(); ++leaf_pos) {
    renumber_map.emplace(leaves_[leaf_pos], static_cast<int32_t>(leaf_pos));
  }

  // The covariates do not change between rows, so fetch them once
  auto&& covariates = dataset->GetCovariates();
  for (int i = 0; i < n; i++) {
    const int32_t node_id = EvaluateTree(*this, covariates, i);
    // map::at throws if the evaluated node is not a registered leaf
    output.at(offset + i) = renumber_map.at(node_id);
  }
}
436+
423437
void TreeNodeVectorsToJson(json& obj, Tree* tree) {
424438
// Initialize a map with names of the node vectors and empty json arrays
425439
std::map<std::string, json> tree_array_map;

test/test_tree.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
*/
55
#include <gtest/gtest.h>
66
#include <testutils.h>
7+
#include <stochtree/data.h>
78
#include <stochtree/log.h>
89
#include <stochtree/tree.h>
910
#include <iostream>
@@ -153,3 +154,25 @@ TEST(Tree, MultivariateTreeCategoricalSplitConstruction) {
153154
ASSERT_TRUE(tree.IsLeaf(1));
154155
ASSERT_TRUE(tree.IsLeaf(2));
155156
}
157+
158+
TEST(Tree, SparseLeafRepresentation) {
  // Start from a freshly initialized tree and expand node 0 once
  StochTree::Tree tree;
  tree.Init(1);
  tree.ExpandNode(0, 0, 0.5, 0., 0.);

  // Fetch the small univariate-basis fixture
  StochTree::TestUtils::TestDataset fixture = StochTree::TestUtils::LoadSmallDatasetUnivariateBasis();
  int num_rows = fixture.n;

  // Wrap the fixture's covariates in a ForestDataset
  auto forest_data = StochTree::ForestDataset();
  forest_data.AddCovariates(fixture.covariates.data(), num_rows, fixture.x_cols, fixture.row_major);

  // Single-tree usage: one output slot per observation, offset 0
  std::vector<int32_t> actual_leaf_indices(num_rows);
  tree.PredictLeafIndexInplace(&forest_data, actual_leaf_indices, 0);

  // Compare against the expected leaf index of each fixture row
  std::vector<int32_t> expected_leaf_indices{1,1,0,1,1,1,1,1,0,0};
  ASSERT_EQ(expected_leaf_indices, actual_leaf_indices);
}

0 commit comments

Comments
 (0)