Skip to content

Commit 742ba57

Browse files
committed
Added forest container methods to python interface as well
1 parent 6b516f3 commit 742ba57

File tree

6 files changed

+158
-6
lines changed

6 files changed

+158
-6
lines changed

R/forest.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ ForestSamples <- R6::R6Class(
3636

3737
#' @description
3838
#' Add a constant value to every leaf of every tree of a given forest
39-
#' @param forest_index Index of forest whos leaves will be modified (0-indexed)
39+
#' @param forest_index Index of forest whose leaves will be modified (0-indexed)
4040
#' @param constant_value Value to add to every leaf of every tree of the forest at `forest_index`
4141
add_to_forest = function(forest_index, constant_value) {
4242
stopifnot(forest_index < self$num_samples())
@@ -46,7 +46,7 @@ ForestSamples <- R6::R6Class(
4646

4747
#' @description
4848
#' Multiply every leaf of every tree of a given forest by constant value
49-
#' @param forest_index Index of forest whos leaves will be modified (0-indexed)
49+
#' @param forest_index Index of forest whose leaves will be modified (0-indexed)
5050
#' @param constant_multiple Value to multiply through by every leaf of every tree of the forest at `forest_index`
5151
multiply_forest = function(forest_index, constant_multiple) {
5252
stopifnot(forest_index < self$num_samples())

include/stochtree/container.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ class ForestContainer {
5959
/*!
6060
* \brief Add a constant value to every leaf of every tree of a specified forest
6161
*
62-
* \param forest_index Index of forest whos leaves will be modified
62+
* \param forest_index Index of forest whose leaves will be modified
6363
* \param constant_value Value to add to every leaf of every tree of the forest at `forest_index`
6464
*/
6565
void AddToForest(int forest_index, double constant_value) {
@@ -68,7 +68,7 @@ class ForestContainer {
6868
/*!
6969
* \brief Multiply every leaf of every tree of a specified forest by a constant value
7070
*
71-
* \param forest_index Index of forest whos leaves will be modified
71+
* \param forest_index Index of forest whose leaves will be modified
7272
* \param constant_multiple Value to multiply through by every leaf of every tree of the forest at `forest_index`
7373
*/
7474
void MultiplyForest(int forest_index, double constant_multiple) {

src/py_stochtree.cpp

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,21 @@ class ForestContainerCpp {
173173
}
174174
~ForestContainerCpp() {}
175175

176+
void CombineForests(py::array_t<int> forest_inds) {
177+
int num_forests = forest_inds.size();
178+
for (int j = 1; j < num_forests; j++) {
179+
forest_samples_->MergeForests(forest_inds.at(0), forest_inds.at(j));
180+
}
181+
}
182+
183+
void AddToForest(int forest_index, double constant_value) {
184+
forest_samples_->AddToForest(forest_index, constant_value);
185+
}
186+
187+
void MultiplyForest(int forest_index, double constant_multiple) {
188+
forest_samples_->MultiplyForest(forest_index, constant_multiple);
189+
}
190+
176191
int OutputDimension() {
177192
return forest_samples_->OutputDimension();
178193
}
@@ -2023,9 +2038,12 @@ PYBIND11_MODULE(stochtree_cpp, m) {
20232038

20242039
py::class_<RngCpp>(m, "RngCpp")
20252040
.def(py::init<int>());
2026-
2041+
20272042
py::class_<ForestContainerCpp>(m, "ForestContainerCpp")
20282043
.def(py::init<int,int,bool,bool>())
2044+
.def("CombineForests", &ForestContainerCpp::CombineForests)
2045+
.def("AddToForest", &ForestContainerCpp::AddToForest)
2046+
.def("MultiplyForest", &ForestContainerCpp::MultiplyForest)
20292047
.def("OutputDimension", &ForestContainerCpp::OutputDimension)
20302048
.def("NumTrees", &ForestContainerCpp::NumTrees)
20312049
.def("NumSamples", &ForestContainerCpp::NumSamples)

stochtree/forest.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,63 @@ def set_root_leaves(
159159
else:
160160
self.forest_container_cpp.SetRootValue(forest_num, leaf_value)
161161

162+
def combine_forests(
163+
self, forest_inds: np.array
164+
) -> None:
165+
"""
166+
Collapse specified forests into a single forest
167+
168+
Parameters
169+
----------
170+
forest_inds : np.array
171+
Indices of forests to be combined (0-indexed).
172+
"""
173+
if not isinstance(forest_inds, np.ndarray):
174+
raise ValueError("forest_inds must be either a np.array")
175+
if not np.issubdtype(forest_inds.dtype, np.number):
176+
raise ValueError("forest_inds must be an integer-convertible np.array")
177+
forest_inds_sorted = np.sort(forest_inds)
178+
forest_inds_sorted = np.astype(forest_inds_sorted, int)
179+
self.forest_container_cpp.CombineForests(forest_inds_sorted)
180+
181+
def add_to_forest(
182+
self, forest_index: int, constant_value : float
183+
) -> None:
184+
"""
185+
Add a constant value to every leaf of every tree of a given forest
186+
187+
Parameters
188+
----------
189+
forest_index : int
190+
Index of forest whose leaves will be modified (0-indexed)
191+
constant_value : float
192+
Value to add to every leaf of every tree of the forest at `forest_index`
193+
"""
194+
if not isinstance(forest_index, int) and not isinstance(constant_value, (int, float)):
195+
raise ValueError("forest_index must be an integer and constant_multiple must be a float or int")
196+
if not forest_index >= 0 or not forest_index < self.forest_container_cpp.NumSamples():
197+
raise ValueError("forest_index must be >= 0 and less than the total number of samples in a forest container")
198+
self.forest_container_cpp.AddToForest(forest_index, constant_value)
199+
200+
def multiply_forest(
201+
self, forest_index: int, constant_multiple : float
202+
) -> None:
203+
"""
204+
Multiply every leaf of every tree of a given forest by constant value
205+
206+
Parameters
207+
----------
208+
forest_index : int
209+
Index of forest whose leaves will be modified (0-indexed)
210+
constant_multiple : float
211+
Value to multiply through by every leaf of every tree of the forest at `forest_index`
212+
"""
213+
if not isinstance(forest_index, int) and not isinstance(constant_multiple, (int, float)):
214+
raise ValueError("forest_index must be an integer and constant_multiple must be a float or int")
215+
if not forest_index >= 0 or not forest_index < self.forest_container_cpp.NumSamples():
216+
raise ValueError("forest_index must be >= 0 and less than the total number of samples in a forest container")
217+
self.forest_container_cpp.MultiplyForest(forest_index, constant_multiple)
218+
162219
def save_to_json_file(self, json_filename: str) -> None:
163220
"""
164221
Save the forests in the container to a JSON file.

test/R/testthat/test-forest-container.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ test_that("Univariate constant forest container", {
5757
# Add 1.0 to every tree in first forest
5858
forest_samples$add_to_forest(0, 1.0)
5959

60-
# Check that predictions are + num_trees
60+
# Check that predictions are += num_trees
6161
pred_expected <- pred + num_trees
6262
pred <- forest_samples$predict(forest_dataset)
6363

test/python/test_forest_container.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
import numpy as np
2+
3+
from stochtree import Dataset, ForestContainer
4+
5+
6+
class TestPredict:
7+
def test_constant_leaf_prediction(self):
8+
# Create dataset
9+
X = np.array(
10+
[[1.5, 8.7, 1.2],
11+
[2.7, 3.4, 5.4],
12+
[3.6, 1.2, 9.3],
13+
[4.4, 5.4, 10.4],
14+
[5.3, 9.3, 3.6],
15+
[6.1, 10.4, 4.4]]
16+
)
17+
n, p = X.shape
18+
num_trees = 10
19+
output_dim = 1
20+
forest_dataset = Dataset()
21+
forest_dataset.add_covariates(X)
22+
forest_samples = ForestContainer(num_trees, output_dim, True, False)
23+
24+
# Initialize a forest with constant root predictions
25+
forest_samples.add_sample(0.)
26+
27+
# Split the root of the first tree in the ensemble at X[,1] > 4.0
28+
# and then split the left leaf of the first tree in the ensemble at X[,2] > 4.0
29+
forest_samples.add_numeric_split(0, 0, 0, 0, 4.0, -5., 5.)
30+
forest_samples.add_numeric_split(0, 0, 1, 1, 4.0, -7.5, -2.5)
31+
32+
# Store the predictions of the "original" forest before modifications
33+
pred_orig = forest_samples.predict(forest_dataset)
34+
35+
# Multiply first forest by 2.0
36+
forest_samples.multiply_forest(0, 2.0)
37+
38+
# Check that predictions are all double
39+
pred = forest_samples.predict(forest_dataset)
40+
pred_expected = pred_orig * 2.0
41+
42+
# Assertion
43+
np.testing.assert_almost_equal(pred, pred_expected)
44+
45+
# Add 1.0 to every tree in first forest
46+
forest_samples.add_to_forest(0, 1.0)
47+
48+
# Check that predictions are += num_trees
49+
pred_expected = pred + num_trees
50+
pred = forest_samples.predict(forest_dataset)
51+
52+
# Assertion
53+
np.testing.assert_almost_equal(pred, pred_expected)
54+
55+
# Initialize a new forest with constant root predictions
56+
forest_samples.add_sample(0.)
57+
58+
# Split the second forest as the first forest was split
59+
forest_samples.add_numeric_split(1, 0, 0, 0, 4.0, -5., 5.)
60+
forest_samples.add_numeric_split(1, 0, 1, 1, 4.0, -7.5, -2.5)
61+
62+
# Check that predictions are as expected
63+
pred_expected_new = np.c_[pred_expected, pred_orig]
64+
pred = forest_samples.predict(forest_dataset)
65+
66+
# Assertion
67+
np.testing.assert_almost_equal(pred, pred_expected_new)
68+
69+
# Combine second forest with the first forest
70+
forest_samples.combine_forests(np.array([0,1]))
71+
72+
# Check that predictions are as expected
73+
pred_expected_new = np.c_[pred_expected + pred_orig, pred_orig]
74+
pred = forest_samples.predict(forest_dataset)
75+
76+
# Assertion
77+
np.testing.assert_almost_equal(pred, pred_expected_new)

0 commit comments

Comments
 (0)