Skip to content

Commit 596e960

Browse files
authored
Merge pull request #38 from StochasticTree/sparse_leaf_matrix
High-level procedures for working with leaf membership information (for e.g. kernels)
2 parents 44e3be1 + 27bf2dd commit 596e960

File tree

5 files changed

+485
-0
lines changed

5 files changed

+485
-0
lines changed

include/stochtree/ensemble.h

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,99 @@ class TreeEnsemble {
220220
}
221221
}
222222

223+
/*!
224+
* \brief Obtain a 0-based leaf index for every tree in an ensemble and for each
225+
* observation in a ForestDataset. Internally, trees are stored as essentially
226+
* vectors of node information, and the leaves_ vector gives us node IDs for every
227+
* leaf in the tree. Here, we would like to know, for every observation in a dataset,
228+
* which leaf number it is mapped to. Since the leaf numbers themselves
229+
* do not carry any information, we renumber them from 0 to `leaves_.size()-1`.
230+
* We compute this at the tree-level and coordinate this computation at the
231+
* ensemble level.
232+
*
233+
* Note: this assumes the creation of a vector of column indices of size
234+
* `dataset.NumObservations()` x `ensemble.NumTrees()`
235+
* \param ForestDataset Dataset with which to predict leaf indices from the tree
236+
* \param output Vector of length num_trees*n which stores the leaf node prediction
237+
* \param num_trees Number of trees in an ensemble
238+
* \param n Size of dataset
239+
*/
240+
void PredictLeafIndicesInplace(ForestDataset* dataset, std::vector<int32_t>& output, int num_trees, data_size_t n) {
241+
PredictLeafIndicesInplace(dataset->GetCovariates(), output, num_trees, n);
242+
}
243+
244+
/*!
245+
* \brief Obtain a 0-based leaf index for every tree in an ensemble and for each
246+
* observation in a ForestDataset. Internally, trees are stored as essentially
247+
* vectors of node information, and the leaves_ vector gives us node IDs for every
248+
* leaf in the tree. Here, we would like to know, for every observation in a dataset,
249+
* which leaf number it is mapped to. Since the leaf numbers themselves
250+
* do not carry any information, we renumber them from 0 to `leaves_.size()-1`.
251+
* We compute this at the tree-level and coordinate this computation at the
252+
* ensemble level.
253+
*
254+
* Note: this assumes the creation of a vector of column indices of size
255+
* `dataset.NumObservations()` x `ensemble.NumTrees()`
256+
* \param ForestDataset Dataset with which to predict leaf indices from the tree
257+
* \param output Vector of length num_trees*n which stores the leaf node prediction
258+
* \param num_trees Number of trees in an ensemble
259+
* \param n Size of dataset
260+
*/
261+
void PredictLeafIndicesInplace(Eigen::Map<Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>>& covariates, std::vector<int32_t>& output, int num_trees, data_size_t n) {
262+
CHECK_GE(output.size(), num_trees*n);
263+
int offset = 0;
264+
int max_leaf = 0;
265+
for (int j = 0; j < num_trees; j++) {
266+
auto &tree = *trees_[j];
267+
int num_leaves = tree.NumLeaves();
268+
tree.PredictLeafIndexInplace(covariates, output, offset, max_leaf);
269+
offset += n;
270+
max_leaf += num_leaves;
271+
}
272+
}
273+
274+
/*!
275+
* \brief Obtain a 0-based leaf index for every tree in an ensemble and for each
276+
* observation in a ForestDataset. Internally, trees are stored as essentially
277+
* vectors of node information, and the leaves_ vector gives us node IDs for every
278+
* leaf in the tree. Here, we would like to know, for every observation in a dataset,
279+
* which leaf number it is mapped to. Since the leaf numbers themselves
280+
* do not carry any information, we renumber them from 0 to `leaves_.size()-1`.
281+
* We compute this at the tree-level and coordinate this computation at the
282+
* ensemble level.
283+
*
284+
* Note: this assumes the creation of a vector of column indices of size
285+
* `dataset.NumObservations()` x `ensemble.NumTrees()`
286+
* \param ForestDataset Dataset with which to predict leaf indices from the tree
287+
* \param output Vector of length num_trees*n which stores the leaf node prediction
288+
* \param num_trees Number of trees in an ensemble
289+
* \param n Size of dataset
290+
*/
291+
void PredictLeafIndicesInplace(Eigen::MatrixXd& covariates, std::vector<int32_t>& output, int num_trees, data_size_t n) {
292+
CHECK_GE(output.size(), num_trees*n);
293+
int offset = 0;
294+
int max_leaf = 0;
295+
for (int j = 0; j < num_trees; j++) {
296+
auto &tree = *trees_[j];
297+
int num_leaves = tree.NumLeaves();
298+
tree.PredictLeafIndexInplace(covariates, output, offset, max_leaf);
299+
offset += n;
300+
max_leaf += num_leaves;
301+
}
302+
}
303+
304+
/*!
305+
* \brief Same as `PredictLeafIndicesInplace` but assumes responsibility for allocating and returning output vector.
306+
* \param ForestDataset Dataset with which to predict leaf indices from the tree
307+
*/
308+
std::vector<int32_t> PredictLeafIndices(ForestDataset* dataset) {
309+
int num_trees = num_trees_;
310+
data_size_t n = dataset->NumObservations();
311+
std::vector<int32_t> output(n*num_trees);
312+
PredictLeafIndicesInplace(dataset, output, num_trees, n);
313+
return output;
314+
}
315+
223316
/*! \brief Save to JSON */
224317
json to_json() {
225318
json result_obj;

include/stochtree/kernel.h

Lines changed: 255 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,255 @@
1+
/*!
2+
* Copyright (c) 2024 stochtree authors. All rights reserved.
3+
* Licensed under the MIT License. See LICENSE file in the project root for license information.
4+
*/
5+
#ifndef STOCHTREE_TREE_KERNEL_H_
6+
#define STOCHTREE_TREE_KERNEL_H_
7+
8+
#include <stochtree/data.h>
#include <stochtree/ensemble.h>
#include <Eigen/Dense>
#include <Eigen/Sparse>

#include <algorithm>
#include <cmath>
#include <map>
#include <memory>
#include <random>
#include <set>
#include <string>
#include <type_traits>
#include <vector>
21+
22+
namespace StochTree {
23+
24+
/*! \brief Non-owning, column-major view over a dynamically sized block of doubles; used for covariates and for externally allocated kernel outputs. */
using KernelMatrixType = Eigen::Map<Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>>;
25+
26+
class ForestKernel {
27+
public:
28+
ForestKernel() {}
29+
~ForestKernel() {}
30+
31+
void ComputeLeafIndices(Eigen::MatrixXd& covariates, TreeEnsemble& forest) {
32+
num_train_observations_ = covariates.rows();
33+
num_trees_ = forest.NumTrees();
34+
train_leaf_index_vector_.resize(num_train_observations_*num_trees_);
35+
forest.PredictLeafIndicesInplace(covariates, train_leaf_index_vector_, num_trees_, num_train_observations_);
36+
int max_cols = *std::max_element(train_leaf_index_vector_.begin(), train_leaf_index_vector_.end());
37+
train_leaf_index_matrix_ = Eigen::SparseMatrix<double>(num_train_observations_,max_cols+1);
38+
int col_num;
39+
for (data_size_t i = 0; i < num_train_observations_; i++) {
40+
for (int j = 0; j < num_trees_; j++) {
41+
col_num = train_leaf_index_vector_.at(j*num_train_observations_ + i);
42+
train_leaf_index_matrix_.insert(i,col_num) = 1.;
43+
}
44+
}
45+
train_leaf_indices_stored_ = true;
46+
}
47+
48+
void ComputeLeafIndices(KernelMatrixType& covariates, TreeEnsemble& forest) {
49+
num_train_observations_ = covariates.rows();
50+
num_trees_ = forest.NumTrees();
51+
train_leaf_index_vector_.resize(num_train_observations_*num_trees_);
52+
forest.PredictLeafIndicesInplace(covariates, train_leaf_index_vector_, num_trees_, num_train_observations_);
53+
int max_cols = *std::max_element(train_leaf_index_vector_.begin(), train_leaf_index_vector_.end());
54+
train_leaf_index_matrix_ = Eigen::SparseMatrix<double>(num_train_observations_,max_cols+1);
55+
int col_num;
56+
for (data_size_t i = 0; i < num_train_observations_; i++) {
57+
for (int j = 0; j < num_trees_; j++) {
58+
col_num = train_leaf_index_vector_.at(j*num_train_observations_ + i);
59+
train_leaf_index_matrix_.insert(i,col_num) = 1.;
60+
}
61+
}
62+
train_leaf_indices_stored_ = true;
63+
}
64+
65+
void ComputeLeafIndices(Eigen::MatrixXd& covariates_train, Eigen::MatrixXd& covariates_test, TreeEnsemble& forest) {
66+
CHECK_EQ(covariates_train.cols(), covariates_test.cols());
67+
num_train_observations_ = covariates_train.rows();
68+
num_test_observations_ = covariates_test.rows();
69+
num_trees_ = forest.NumTrees();
70+
train_leaf_index_vector_.resize(num_train_observations_*num_trees_);
71+
test_leaf_index_vector_.resize(num_test_observations_*num_trees_);
72+
forest.PredictLeafIndicesInplace(covariates_train, train_leaf_index_vector_, num_trees_, num_train_observations_);
73+
forest.PredictLeafIndicesInplace(covariates_test, test_leaf_index_vector_, num_trees_, num_test_observations_);
74+
int max_cols_train = *std::max_element(train_leaf_index_vector_.begin(), train_leaf_index_vector_.end());
75+
int max_cols_test = *std::max_element(test_leaf_index_vector_.begin(), test_leaf_index_vector_.end());
76+
int max_cols = max_cols_train > max_cols_test ? max_cols_train : max_cols_test;
77+
train_leaf_index_matrix_ = Eigen::SparseMatrix<double>(num_train_observations_,max_cols+1);
78+
test_leaf_index_matrix_ = Eigen::SparseMatrix<double>(num_test_observations_,max_cols+1);
79+
int col_num;
80+
for (data_size_t i = 0; i < num_train_observations_; i++) {
81+
for (int j = 0; j < num_trees_; j++) {
82+
col_num = train_leaf_index_vector_.at(j*num_train_observations_ + i);
83+
train_leaf_index_matrix_.insert(i,col_num) = 1.;
84+
}
85+
}
86+
train_leaf_indices_stored_ = true;
87+
for (data_size_t i = 0; i < num_test_observations_; i++) {
88+
for (int j = 0; j < num_trees_; j++) {
89+
col_num = test_leaf_index_vector_.at(j*num_test_observations_ + i);
90+
test_leaf_index_matrix_.insert(i,col_num) = 1.;
91+
}
92+
}
93+
test_leaf_indices_stored_ = true;
94+
}
95+
96+
void ComputeLeafIndices(KernelMatrixType& covariates_train, KernelMatrixType& covariates_test, TreeEnsemble& forest) {
97+
CHECK_EQ(covariates_train.cols(), covariates_test.cols());
98+
num_train_observations_ = covariates_train.rows();
99+
num_test_observations_ = covariates_test.rows();
100+
num_trees_ = forest.NumTrees();
101+
train_leaf_index_vector_.resize(num_train_observations_*num_trees_);
102+
test_leaf_index_vector_.resize(num_test_observations_*num_trees_);
103+
forest.PredictLeafIndicesInplace(covariates_train, train_leaf_index_vector_, num_trees_, num_train_observations_);
104+
forest.PredictLeafIndicesInplace(covariates_test, test_leaf_index_vector_, num_trees_, num_test_observations_);
105+
int max_cols_train = *std::max_element(train_leaf_index_vector_.begin(), train_leaf_index_vector_.end());
106+
int max_cols_test = *std::max_element(test_leaf_index_vector_.begin(), test_leaf_index_vector_.end());
107+
int max_cols = max_cols_train > max_cols_test ? max_cols_train : max_cols_test;
108+
train_leaf_index_matrix_ = Eigen::SparseMatrix<double>(num_train_observations_,max_cols+1);
109+
test_leaf_index_matrix_ = Eigen::SparseMatrix<double>(num_test_observations_,max_cols+1);
110+
int col_num;
111+
for (data_size_t i = 0; i < num_train_observations_; i++) {
112+
for (int j = 0; j < num_trees_; j++) {
113+
col_num = train_leaf_index_vector_.at(j*num_train_observations_ + i);
114+
train_leaf_index_matrix_.insert(i,col_num) = 1.;
115+
}
116+
}
117+
train_leaf_indices_stored_ = true;
118+
for (data_size_t i = 0; i < num_test_observations_; i++) {
119+
for (int j = 0; j < num_trees_; j++) {
120+
col_num = test_leaf_index_vector_.at(j*num_test_observations_ + i);
121+
test_leaf_index_matrix_.insert(i,col_num) = 1.;
122+
}
123+
}
124+
test_leaf_indices_stored_ = true;
125+
}
126+
127+
void ComputeKernel(Eigen::MatrixXd& covariates, TreeEnsemble& forest) {
128+
ComputeLeafIndices(covariates, forest);
129+
tree_kernel_train_ = train_leaf_index_matrix_ * train_leaf_index_matrix_.transpose();
130+
train_kernel_stored_ = true;
131+
}
132+
133+
void ComputeKernel(KernelMatrixType& covariates, TreeEnsemble& forest) {
134+
ComputeLeafIndices(covariates, forest);
135+
tree_kernel_train_ = train_leaf_index_matrix_ * train_leaf_index_matrix_.transpose();
136+
train_kernel_stored_ = true;
137+
}
138+
139+
void ComputeKernelExternal(Eigen::MatrixXd& covariates, TreeEnsemble& forest, KernelMatrixType& kernel_map) {
140+
ComputeLeafIndices(covariates, forest);
141+
kernel_map = train_leaf_index_matrix_ * train_leaf_index_matrix_.transpose();
142+
}
143+
144+
void ComputeKernelExternal(KernelMatrixType& covariates, TreeEnsemble& forest, KernelMatrixType& kernel_map) {
145+
ComputeLeafIndices(covariates, forest);
146+
kernel_map = train_leaf_index_matrix_ * train_leaf_index_matrix_.transpose();
147+
}
148+
149+
void ComputeKernel(Eigen::MatrixXd& covariates_train, Eigen::MatrixXd& covariates_test, TreeEnsemble& forest) {
150+
ComputeLeafIndices(covariates_train, covariates_test, forest);
151+
tree_kernel_train_ = train_leaf_index_matrix_ * train_leaf_index_matrix_.transpose();
152+
train_kernel_stored_ = true;
153+
tree_kernel_test_train_ = test_leaf_index_matrix_ * train_leaf_index_matrix_.transpose();
154+
tree_kernel_test_ = test_leaf_index_matrix_ * test_leaf_index_matrix_.transpose();
155+
test_kernel_stored_ = true;
156+
}
157+
158+
void ComputeKernel(KernelMatrixType& covariates_train, KernelMatrixType& covariates_test, TreeEnsemble& forest) {
159+
ComputeLeafIndices(covariates_train, covariates_test, forest);
160+
tree_kernel_train_ = train_leaf_index_matrix_ * train_leaf_index_matrix_.transpose();
161+
train_kernel_stored_ = true;
162+
tree_kernel_test_train_ = test_leaf_index_matrix_ * train_leaf_index_matrix_.transpose();
163+
tree_kernel_test_ = test_leaf_index_matrix_ * test_leaf_index_matrix_.transpose();
164+
test_kernel_stored_ = true;
165+
}
166+
167+
void ComputeKernelExternal(Eigen::MatrixXd& covariates_train, Eigen::MatrixXd& covariates_test, TreeEnsemble& forest,
168+
KernelMatrixType& kernel_map_train, KernelMatrixType& kernel_map_test_train, KernelMatrixType& kernel_map_test) {
169+
ComputeLeafIndices(covariates_train, covariates_test, forest);
170+
kernel_map_train = train_leaf_index_matrix_ * train_leaf_index_matrix_.transpose();
171+
kernel_map_test_train = test_leaf_index_matrix_ * train_leaf_index_matrix_.transpose();
172+
kernel_map_test = test_leaf_index_matrix_ * test_leaf_index_matrix_.transpose();
173+
}
174+
175+
void ComputeKernelExternal(KernelMatrixType& covariates_train, KernelMatrixType& covariates_test, TreeEnsemble& forest,
176+
KernelMatrixType& kernel_map_train, KernelMatrixType& kernel_map_test_train, KernelMatrixType& kernel_map_test) {
177+
ComputeLeafIndices(covariates_train, covariates_test, forest);
178+
kernel_map_train = train_leaf_index_matrix_ * train_leaf_index_matrix_.transpose();
179+
kernel_map_test_train = test_leaf_index_matrix_ * train_leaf_index_matrix_.transpose();
180+
kernel_map_test = test_leaf_index_matrix_ * test_leaf_index_matrix_.transpose();
181+
}
182+
183+
std::vector<int32_t>& GetTrainLeafIndices() {
184+
CHECK(train_leaf_indices_stored_);
185+
return train_leaf_index_vector_;
186+
}
187+
188+
std::vector<int32_t>& GetTestLeafIndices() {
189+
CHECK(test_leaf_indices_stored_);
190+
return test_leaf_index_vector_;
191+
}
192+
193+
Eigen::MatrixXd& GetTrainKernel() {
194+
CHECK(train_kernel_stored_);
195+
return tree_kernel_train_;
196+
}
197+
198+
Eigen::MatrixXd& GetTestTrainKernel() {
199+
CHECK(test_kernel_stored_);
200+
return tree_kernel_test_train_;
201+
}
202+
203+
Eigen::MatrixXd& GetTestKernel() {
204+
CHECK(test_kernel_stored_);
205+
return tree_kernel_test_;
206+
}
207+
208+
data_size_t NumTrainObservations() {
209+
return num_train_observations_;
210+
}
211+
212+
data_size_t NumTestObservations() {
213+
return num_test_observations_;
214+
}
215+
216+
int NumTrees() {
217+
return num_trees_;
218+
}
219+
220+
bool HasTrainLeafIndices() {
221+
return train_leaf_indices_stored_;
222+
}
223+
224+
bool HasTestLeafIndices() {
225+
return test_leaf_indices_stored_;
226+
}
227+
228+
bool HasTrainKernel() {
229+
return train_kernel_stored_;
230+
}
231+
232+
bool HasTestKernel() {
233+
return test_kernel_stored_;
234+
}
235+
236+
private:
237+
data_size_t num_train_observations_{0};
238+
data_size_t num_test_observations_{0};
239+
int num_trees_{0};
240+
std::vector<int32_t> train_leaf_index_vector_;
241+
std::vector<int32_t> test_leaf_index_vector_;
242+
Eigen::SparseMatrix<double> train_leaf_index_matrix_;
243+
Eigen::SparseMatrix<double> test_leaf_index_matrix_;
244+
Eigen::MatrixXd tree_kernel_train_;
245+
Eigen::MatrixXd tree_kernel_test_train_;
246+
Eigen::MatrixXd tree_kernel_test_;
247+
bool train_leaf_indices_stored_{false};
248+
bool test_leaf_indices_stored_{false};
249+
bool train_kernel_stored_{false};
250+
bool test_kernel_stored_{false};
251+
};
252+
253+
} // namespace StochTree
254+
255+
#endif // STOCHTREE_TREE_KERNEL_H_

0 commit comments

Comments
 (0)