Skip to content

Commit 795b1e8

Browse files
authored
Merge pull request #172 from StochasticTree/merge-forests
Merge multiple forests into a single forest and perform arithmetic operations on forests
2 parents f55bbb4 + a22ca1c commit 795b1e8

29 files changed

+1710
-17
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ cpp_docs/doxyoutput/html
1212
cpp_docs/doxyoutput/xml
1313
cpp_docs/doxyoutput/latex
1414
stochtree_cran
15+
*.trace
1516

1617
## R gitignore
1718

NAMESPACE

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,12 +66,16 @@ export(saveBCFModelToJsonString)
6666
export(savePreprocessorToJsonString)
6767
importFrom(R6,R6Class)
6868
importFrom(stats,coef)
69+
importFrom(stats,dnorm)
6970
importFrom(stats,lm)
7071
importFrom(stats,model.matrix)
72+
importFrom(stats,pnorm)
7173
importFrom(stats,predict)
7274
importFrom(stats,qgamma)
75+
importFrom(stats,qnorm)
7376
importFrom(stats,resid)
7477
importFrom(stats,rnorm)
78+
importFrom(stats,runif)
7579
importFrom(stats,sd)
7680
importFrom(stats,sigma)
7781
importFrom(stats,var)

R/cpp11.R

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,10 +252,34 @@ forest_container_from_json_string_cpp <- function(json_string, forest_label) {
252252
.Call(`_stochtree_forest_container_from_json_string_cpp`, json_string, forest_label)
253253
}
254254

255+
forest_merge_cpp <- function(inbound_forest_ptr, outbound_forest_ptr) {
256+
invisible(.Call(`_stochtree_forest_merge_cpp`, inbound_forest_ptr, outbound_forest_ptr))
257+
}
258+
259+
forest_add_constant_cpp <- function(forest_ptr, constant_value) {
260+
invisible(.Call(`_stochtree_forest_add_constant_cpp`, forest_ptr, constant_value))
261+
}
262+
263+
forest_multiply_constant_cpp <- function(forest_ptr, constant_multiple) {
264+
invisible(.Call(`_stochtree_forest_multiply_constant_cpp`, forest_ptr, constant_multiple))
265+
}
266+
255267
forest_container_append_from_json_string_cpp <- function(forest_sample_ptr, json_string, forest_label) {
256268
invisible(.Call(`_stochtree_forest_container_append_from_json_string_cpp`, forest_sample_ptr, json_string, forest_label))
257269
}
258270

271+
combine_forests_forest_container_cpp <- function(forest_samples, forest_inds) {
272+
invisible(.Call(`_stochtree_combine_forests_forest_container_cpp`, forest_samples, forest_inds))
273+
}
274+
275+
add_to_forest_forest_container_cpp <- function(forest_samples, forest_index, constant_value) {
276+
invisible(.Call(`_stochtree_add_to_forest_forest_container_cpp`, forest_samples, forest_index, constant_value))
277+
}
278+
279+
multiply_forest_forest_container_cpp <- function(forest_samples, forest_index, constant_multiple) {
280+
invisible(.Call(`_stochtree_multiply_forest_forest_container_cpp`, forest_samples, forest_index, constant_multiple))
281+
}
282+
259283
num_samples_forest_container_cpp <- function(forest_samples) {
260284
.Call(`_stochtree_num_samples_forest_container_cpp`, forest_samples)
261285
}

R/forest.R

Lines changed: 92 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,73 @@ ForestSamples <- R6::R6Class(
2222
self$forest_container_ptr <- forest_container_cpp(num_trees, leaf_dimension, is_leaf_constant, is_exponentiated)
2323
},
2424

25+
#' @description
26+
#' Collapse forests in this container by a pre-specified batch size.
27+
#' For example, if we have a container of twenty 10-tree forests, and we
28+
#' specify a `batch_size` of 5, then this method will yield four 50-tree
29+
#' forests. "Excess" forests remaining after the size of a forest container
30+
#' is divided by `batch_size` will be pruned from the beginning of the
31+
#' container (i.e. earlier sampled forests will be deleted). This method
32+
#' has no effect if `batch_size` is larger than the number of forests
33+
#' in a container.
34+
#' @param batch_size Number of forests to be collapsed into a single forest
35+
collapse = function(batch_size) {
36+
container_size <- self$num_samples()
37+
if ((batch_size <= container_size) && (batch_size > 1)) {
38+
reverse_container_inds <- seq(container_size, 1, -1)
39+
num_clean_batches <- container_size %/% batch_size
40+
batch_inds <- (reverse_container_inds - (container_size - (container_size %/% num_clean_batches) * num_clean_batches) - 1) %/% batch_size
41+
for (batch_ind in unique(batch_inds[batch_inds >= 0])) {
42+
merge_forest_inds <- sort(reverse_container_inds[batch_inds == batch_ind] - 1)
43+
num_merge_forests <- length(merge_forest_inds)
44+
self$combine_forests(merge_forest_inds)
45+
for (i in num_merge_forests:2) {
46+
self$delete_sample(merge_forest_inds[i])
47+
}
48+
forest_scale_factor <- 1.0 / num_merge_forests
49+
self$multiply_forest(merge_forest_inds[1], forest_scale_factor)
50+
}
51+
if (min(batch_inds) < 0) {
52+
delete_forest_inds <- sort(reverse_container_inds[batch_inds < 0] - 1)
53+
for (i in length(delete_forest_inds):1) {
54+
self$delete_sample(delete_forest_inds[i])
55+
}
56+
}
57+
}
58+
},
59+
60+
#' @description
61+
#' Merge specified forests into a single forest
62+
#' @param forest_inds Indices of forests to be combined (0-indexed)
63+
combine_forests = function(forest_inds) {
64+
stopifnot(max(forest_inds) < self$num_samples())
65+
stopifnot(min(forest_inds) >= 0)
66+
stopifnot(length(forest_inds) > 1)
67+
stopifnot(all(as.integer(forest_inds) == forest_inds))
68+
forest_inds_sorted <- as.integer(sort(forest_inds))
69+
combine_forests_forest_container_cpp(self$forest_container_ptr, forest_inds_sorted)
70+
},
71+
72+
#' @description
73+
#' Add a constant value to every leaf of every tree of a given forest
74+
#' @param forest_index Index of forest whose leaves will be modified (0-indexed)
75+
#' @param constant_value Value to add to every leaf of every tree of the forest at `forest_index`
76+
add_to_forest = function(forest_index, constant_value) {
77+
stopifnot(forest_index < self$num_samples())
78+
stopifnot(forest_index >= 0)
79+
add_to_forest_forest_container_cpp(self$forest_container_ptr, forest_index, constant_value)
80+
},
81+
82+
#' @description
83+
#' Multiply every leaf of every tree of a given forest by constant value
84+
#' @param forest_index Index of forest whose leaves will be modified (0-indexed)
85+
#' @param constant_multiple Value to multiply through by every leaf of every tree of the forest at `forest_index`
86+
multiply_forest = function(forest_index, constant_multiple) {
87+
stopifnot(forest_index < self$num_samples())
88+
stopifnot(forest_index >= 0)
89+
multiply_forest_forest_container_cpp(self$forest_container_ptr, forest_index, constant_multiple)
90+
},
91+
2592
#' @description
2693
#' Create a new `ForestContainer` object from a json object
2794
#' @param json_object Object of class `CppJson`
@@ -573,6 +640,30 @@ Forest <- R6::R6Class(
573640
self$internal_forest_is_empty <- TRUE
574641
},
575642

643+
#' @description
644+
#' Create a larger forest by merging the trees of this forest with those of another forest
645+
#' @param forest Forest to be merged into this forest
646+
merge_forest = function(forest) {
647+
stopifnot(self$leaf_dimension() == forest$leaf_dimension())
648+
stopifnot(self$is_constant_leaf() == forest$is_constant_leaf())
649+
stopifnot(self$is_exponentiated() == forest$is_exponentiated())
650+
forest_merge_cpp(self$forest_ptr, forest$forest_ptr)
651+
},
652+
653+
#' @description
654+
#' Add a constant value to every leaf of every tree in an ensemble. If leaves are multi-dimensional, `constant_value` will be added to every dimension of the leaves.
655+
#' @param constant_value Value that will be added to every leaf of every tree
656+
add_constant = function(constant_value) {
657+
forest_add_constant_cpp(self$forest_ptr, constant_value)
658+
},
659+
660+
#' @description
661+
#' Multiply every leaf of every tree by a constant value. If leaves are multi-dimensional, `constant_multiple` will be multiplied through every dimension of the leaves.
662+
#' @param constant_multiple Value that will be multiplied by every leaf of every tree
663+
multiply_constant = function(constant_multiple) {
664+
forest_multiply_constant_cpp(self$forest_ptr, constant_multiple)
665+
},
666+
576667
#' @description
577668
#' Predict forest on every sample in `forest_dataset`
578669
#' @param forest_dataset `ForestDataset` R class
@@ -694,7 +785,7 @@ Forest <- R6::R6Class(
694785
#' Return constant leaf status of trees in a `Forest` object
695786
#' @return `TRUE` if leaves are constant, `FALSE` otherwise
696787
is_constant_leaf = function() {
697-
return(is_constant_leaf_active_forest_cpp(self$forest_ptr))
788+
return(is_leaf_constant_forest_container_cpp(self$forest_ptr))
698789
},
699790

700791
#' @description

include/stochtree/container.h

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,33 @@ class ForestContainer {
4747
*/
4848
ForestContainer(int num_samples, int num_trees, int output_dimension = 1, bool is_leaf_constant = true, bool is_exponentiated = false);
4949
~ForestContainer() {}
50+
/*!
51+
* \brief Combine two forests into a single forest by merging their trees
52+
*
53+
* \param inbound_forest_index Index of the forest that will be appended to
54+
* \param outbound_forest_index Index of the forest that will be appended
55+
*/
56+
void MergeForests(int inbound_forest_index, int outbound_forest_index) {
57+
forests_[inbound_forest_index]->MergeForest(*forests_[outbound_forest_index]);
58+
}
59+
/*!
60+
* \brief Add a constant value to every leaf of every tree of a specified forest
61+
*
62+
* \param forest_index Index of forest whose leaves will be modified
63+
* \param constant_value Value to add to every leaf of every tree of the forest at `forest_index`
64+
*/
65+
void AddToForest(int forest_index, double constant_value) {
66+
forests_[forest_index]->AddValueToLeaves(constant_value);
67+
}
68+
/*!
69+
* \brief Multiply every leaf of every tree of a specified forest by a constant value
70+
*
71+
* \param forest_index Index of forest whose leaves will be modified
72+
* \param constant_multiple Value to multiply through by every leaf of every tree of the forest at `forest_index`
73+
*/
74+
void MultiplyForest(int forest_index, double constant_multiple) {
75+
forests_[forest_index]->MultiplyLeavesByValue(constant_multiple);
76+
}
5077
/*!
5178
* \brief Remove a forest from a container of forest samples and delete the corresponding object, freeing its memory.
5279
*

include/stochtree/ensemble.h

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,54 @@ class TreeEnsemble {
8383

8484
~TreeEnsemble() {}
8585

86+
/*!
87+
* \brief Combine two forests into a single forest by merging their trees
88+
*
89+
* \param ensemble Reference to another `TreeEnsemble` that will be merged into the current ensemble
90+
*/
91+
void MergeForest(TreeEnsemble& ensemble) {
92+
// Unpack ensemble configurations
93+
int old_num_trees = num_trees_;
94+
num_trees_ += ensemble.num_trees_;
95+
CHECK_EQ(output_dimension_, ensemble.output_dimension_);
96+
CHECK_EQ(is_leaf_constant_, ensemble.is_leaf_constant_);
97+
CHECK_EQ(is_exponentiated_, ensemble.is_exponentiated_);
98+
// Resize tree vector and reset new trees
99+
trees_.resize(num_trees_);
100+
for (int i = old_num_trees; i < num_trees_; i++) {
101+
trees_[i].reset(new Tree());
102+
}
103+
// Clone trees in the input ensemble
104+
for (int j = 0; j < ensemble.num_trees_; j++) {
105+
Tree* tree = ensemble.GetTree(j);
106+
this->CloneFromExistingTree(old_num_trees + j, tree);
107+
}
108+
}
109+
110+
/*!
111+
* \brief Add a constant value to every leaf of every tree in an ensemble. If leaves are multi-dimensional, `constant_value` will be added to every dimension of the leaves.
112+
*
113+
* \param constant_value Value that will be added to every leaf of every tree
114+
*/
115+
void AddValueToLeaves(double constant_value) {
116+
for (int j = 0; j < num_trees_; j++) {
117+
Tree* tree = GetTree(j);
118+
tree->AddValueToLeaves(constant_value);
119+
}
120+
}
121+
122+
/*!
123+
* \brief Multiply every leaf of every tree by a constant value. If leaves are multi-dimensional, `constant_multiple` will be multiplied through every dimension of the leaves.
124+
*
125+
* \param constant_multiple Value that will be multiplied by every leaf of every tree
126+
*/
127+
void MultiplyLeavesByValue(double constant_multiple) {
128+
for (int j = 0; j < num_trees_; j++) {
129+
Tree* tree = GetTree(j);
130+
tree->MultiplyLeavesByValue(constant_multiple);
131+
}
132+
}
133+
86134
/*!
87135
* \brief Return a pointer to a tree in the forest
88136
*

include/stochtree/tree.h

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,40 @@ class Tree {
201201
this->ChangeToLeaf(nid, value_vector);
202202
}
203203

204+
/*!
205+
* \brief Add a constant value to every leaf of a tree. If leaves are multi-dimensional, `constant_value` will be added to every dimension of the leaves.
206+
*
207+
* \param constant_value Value that will be added to every leaf of a tree
208+
*/
209+
void AddValueToLeaves(double constant_value) {
210+
if (output_dimension_ == 1) {
211+
for (int j = 0; j < leaf_value_.size(); j++) {
212+
leaf_value_[j] += constant_value;
213+
}
214+
} else {
215+
for (int j = 0; j < leaf_vector_.size(); j++) {
216+
leaf_vector_[j] += constant_value;
217+
}
218+
}
219+
}
220+
221+
/*!
222+
* \brief Multiply every leaf of a tree by a constant value. If leaves are multi-dimensional, `constant_value` will be multiplied through every dimension of the leaves.
223+
*
224+
* \param constant_multiple Value that will be multiplied by every leaf of a tree
225+
*/
226+
void MultiplyLeavesByValue(double constant_multiple) {
227+
if (output_dimension_ == 1) {
228+
for (int j = 0; j < leaf_value_.size(); j++) {
229+
leaf_value_[j] *= constant_multiple;
230+
}
231+
} else {
232+
for (int j = 0; j < leaf_vector_.size(); j++) {
233+
leaf_vector_[j] *= constant_multiple;
234+
}
235+
}
236+
}
237+
204238
/*!
205239
* \brief Iterate through all nodes in this tree.
206240
*

man/Forest.Rd

Lines changed: 54 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)